# Arena_KeyDub / app.py
# Last commit: cebd7b7 "Update app.py" by toninio19 (verified)
import argparse
import hmac
import json
import os
import random
import tempfile
from datetime import datetime
from pathlib import Path
from typing import Dict, List, Tuple

import gradio as gr
import pandas as pd
from filelock import FileLock, Timeout
class VideoArenaManager:
    """Own the model list, the video pool and the persisted Elo/vote data.

    Ratings and comparison records live in ``<data_dir>/arena_data_new.json``.
    Every read-modify-write of that file happens while holding a ``FileLock``
    on ``<data_dir>/arena_data_new.lock``.
    """

    # Rating every model starts at, and the fallback for models with no stored rating.
    DEFAULT_RATING = 1000.0

    def __init__(self, base_dir: str = "videos", data_dir: str = "/data"):
        """Discover models and videos, then load (or create) the results file.

        Args:
            base_dir: Root directory containing one sub-directory per model.
            data_dir: Directory for the results JSON file and its lock file.
        """
        self.base_dir = base_dir
        self.data_dir = data_dir
        self.models = self._load_models()
        self.videos = self._load_videos()
        self.data_file = os.path.join(self.data_dir, "arena_data_new.json")
        self.data_lock = FileLock(os.path.join(self.data_dir, "arena_data_new.lock"))
        self.data = self._load_data()
        # Usernames handed out by this instance; only used to avoid duplicates.
        self.usernames = set()

    def _load_models(self) -> List[Dict[str, str]]:
        """Return the compared models as ``{"name", "directory"}`` dicts.

        ``directory`` is the sub-directory of ``base_dir`` holding that model's videos.
        """
        return [
            {"name": "Model 1", "directory": "model_1"},
            {"name": "Model 2", "directory": "model_2"},
            {"name": "Model 3", "directory": "model_3"},
            {"name": "Model 4", "directory": "model_4"},
            {"name": "Model 5", "directory": "model_5"},
            {"name": "Model 6", "directory": "model_6"},
            # Add more models as needed
        ]

    def get_data_files(self) -> List[Dict]:
        """List the regular files directly inside ``data_dir``.

        Returns:
            One dict per file with ``name`` (str), ``size`` (int, bytes) and ``path`` (str).
        """
        data_path = Path(self.data_dir)
        return [
            {"name": file.name, "size": file.stat().st_size, "path": str(file)}
            for file in data_path.glob("*")
            if file.is_file()
        ]

    def read_data_file(self, file_path: str) -> bytes:
        """Return the raw bytes of *file_path*."""
        with open(file_path, "rb") as f:
            return f.read()

    def _load_videos(self) -> List[str]:
        """Return the video file names (.mp4/.avi/.mov) in the first model's directory.

        ``get_video_paths`` joins these names onto every model's directory, so the
        same file names are assumed to exist for all models.
        """
        base_path = os.path.join(self.base_dir, self.models[0]["directory"])
        return [f for f in os.listdir(base_path) if f.endswith((".mp4", ".avi", ".mov"))]

    def _default_data(self) -> Dict:
        """Return fresh results data: every model at DEFAULT_RATING, no comparisons."""
        return {
            "elo_ratings": {model["name"]: self.DEFAULT_RATING for model in self.models},
            "results": {"comparisons": []},
        }

    def _with_all_models(self, data: Dict) -> Dict:
        """Give every configured model a rating in *data* (in place) and return *data*.

        A model added to ``_load_models`` after the results file was created has no
        stored rating. Without this, ``get_rankings``/``update_elo_ratings`` would
        raise KeyError for it.
        """
        ratings = data.setdefault("elo_ratings", {})
        for model in self.models:
            ratings.setdefault(model["name"], self.DEFAULT_RATING)
        data.setdefault("results", {}).setdefault("comparisons", [])
        return data

    def _write_data(self, data: Dict) -> None:
        """Write *data* to ``data_file`` atomically. The caller must hold ``data_lock``.

        The JSON is written to a temp file in the same directory and then moved over
        the target with ``os.replace``. A crash mid-write therefore cannot leave a
        truncated results file behind.
        """
        target_dir = os.path.dirname(self.data_file) or "."
        fd, tmp_path = tempfile.mkstemp(dir=target_dir, prefix=".arena_data_", suffix=".tmp")
        try:
            with os.fdopen(fd, "w") as f:
                json.dump(data, f, indent=2)
            os.replace(tmp_path, self.data_file)
        except BaseException:
            # Don't leave a stray temp file behind, then let the error propagate.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise

    def _load_data(self) -> Dict:
        """Load existing Elo ratings and comparison results, creating the file if missing.

        If the lock cannot be acquired within 10 s, a message is printed and
        in-memory default data is returned without touching the file.
        """
        try:
            with self.data_lock.acquire(timeout=10):
                if os.path.exists(self.data_file):
                    with open(self.data_file, "r") as f:
                        return self._with_all_models(json.load(f))
                # Initialize data if file does not exist
                data = self._default_data()
                self._write_data(data)
                return data
        except Timeout:
            print(f"Could not acquire lock on {self.data_file}")
            return self._default_data()

    def save_comparison(self, video_name: str, winner: str, loser: str, username: str):
        """Record one vote and update the Elo ratings in the results file.

        The file is re-read under the lock, so entries written by other lock
        holders are not overwritten. If the lock cannot be acquired within 10 s,
        the vote is dropped and only a message is printed.
        """
        try:
            with self.data_lock.acquire(timeout=10):
                # Reload data to get the latest information
                if os.path.exists(self.data_file):
                    with open(self.data_file, "r") as f:
                        self.data = self._with_all_models(json.load(f))
                else:
                    self.data = self._default_data()
                self.data["results"]["comparisons"].append(
                    {
                        "timestamp": datetime.now().isoformat(),
                        "video": video_name,
                        "winner": winner,
                        "loser": loser,
                        "username": username,
                    }
                )
                self.update_elo_ratings(winner, loser)
                self._write_data(self.data)
        except Timeout:
            print(f"Could not acquire lock on {self.data_file}")

    def update_elo_ratings(self, winner: str, loser: str, k: float = 32):
        """Apply one Elo update for a single win of *winner* over *loser*.

        Only ``self.data`` is modified; callers that persist it should hold
        ``data_lock``. The update is zero-sum: the winner gains
        ``k * (1 - E_winner)`` and the loser loses the same amount. A name with
        no stored rating starts at DEFAULT_RATING.
        """
        ratings = self.data["elo_ratings"]
        winner_rating = ratings.setdefault(winner, self.DEFAULT_RATING)
        loser_rating = ratings.setdefault(loser, self.DEFAULT_RATING)
        # Expected score of the winner under the standard logistic Elo model.
        expected_winner = 1 / (1 + 10 ** ((loser_rating - winner_rating) / 400))
        delta = k * (1 - expected_winner)
        ratings[winner] = winner_rating + delta
        ratings[loser] = loser_rating - delta

    def get_random_pair(self) -> Tuple[Dict[str, str], Dict[str, str]]:
        """Return two distinct models chosen uniformly at random."""
        return tuple(random.sample(self.models, 2))

    def get_rankings(self) -> pd.DataFrame:
        """Return a ``Model``/``Elo Rating`` table sorted by rating, highest first.

        Models with no stored rating are listed at DEFAULT_RATING.
        """
        ratings = self.data["elo_ratings"]
        rows = [
            {"Model": model["name"], "Elo Rating": ratings.get(model["name"], self.DEFAULT_RATING)}
            for model in self.models
        ]
        df = pd.DataFrame(rows)
        return df.sort_values(by="Elo Rating", ascending=False).reset_index(drop=True)

    def get_video_paths(self, video_name: str, model_pair: Tuple[Dict[str, str]]) -> List[str]:
        """Return ``base_dir/<model directory>/video_name`` for each model in *model_pair*."""
        return [os.path.join(self.base_dir, model["directory"], video_name) for model in model_pair]

    def generate_username(self) -> str:
        """Return a random Adjective+Animal+3-digit name not yet issued by this instance.

        Retries until it finds an unused name (7 * 7 * 900 combinations exist).
        """
        adjectives = ["Happy", "Quick", "Clever", "Brave", "Wise", "Kind", "Swift"]
        animals = ["Panda", "Tiger", "Eagle", "Dolphin", "Fox", "Owl", "Bear"]
        while True:
            username = f"{random.choice(adjectives)}{random.choice(animals)}{random.randint(100, 999)}"
            if username not in self.usernames:
                self.usernames.add(username)
                return username
def create_arena_interface(data_dir: str = None):
    """Build the Gradio app for blind pairwise comparison of dubbed videos.

    Each session is shown one video rendered by two randomly chosen models and
    votes for the better one. Votes are stored through
    ``VideoArenaManager.save_comparison``. The rankings table, results download
    and data reset are only revealed, and only act, when the hidden textbox
    matches the non-empty ``PASSKEY`` environment variable.

    Args:
        data_dir: Directory for the results file; ``None`` uses the manager's default.

    Returns:
        The configured (not yet launched) ``gr.Blocks`` app.
    """
    # Instantiate the manager with the provided data_dir if set
    if data_dir is not None:
        manager = VideoArenaManager(data_dir=data_dir)
    else:
        manager = VideoArenaManager()

    def _passkey_ok(passkey: str) -> bool:
        """Return True only when PASSKEY is set (non-empty) and *passkey* equals it.

        With PASSKEY unset, the old ``passkey == ""`` check accepted an empty box.
        That exposed the admin controls to anyone who typed into the box and cleared it.
        """
        expected = os.environ.get("PASSKEY", "")
        if not expected or not passkey:
            return False
        # Constant-time comparison; bytes because compare_digest rejects non-ASCII str.
        return hmac.compare_digest(passkey.encode("utf-8"), expected.encode("utf-8"))

    with gr.Blocks(
        title="Video Model Ranking Arena",
        css="""
.invisible-textbox {
position: fixed;
opacity: 0.1;
pointer-events: auto;
width: 100px;
height: 20px;
padding: 5px;
margin: 5px;
background: #f0f0f0;
border: 1px solid #ddd;
}
""",
    ) as demo:
        gr.Markdown(
            """### Welcome to the Dubbing Evaluation Arena!
In this study, the models modify only the lip region of the characters to better match the new dubbed audio, while the rest of the video remains unchanged.
Please compare the two videos and vote for the one you prefer based on the following criteria:
- **Lip Synchronization with Audio**: How well the character's lip movements align with the new speech.
- **Overall Coherence**: How seamlessly the modified lip movements integrate with the rest of the video.
- **Image Quality**: Clarity and visual appeal of the video.
Select either the left or right video as your preference. Thank you for your feedback!
(**Note**: If you are on a mobile phone, try turning the screen landscape for a better experience)"""
        )
        # Per-session state: the video file name and the [left, right] model names shown.
        current_video = gr.State()
        current_models = gr.State()
        # Filled per session by demo.load. Calling generate_username() here would run
        # once at build time and give every visitor the same username.
        username = gr.State()

        with gr.Row():
            # Display two videos side by side
            video_left = gr.Video(label="Model A", height=400)
            video_right = gr.Video(label="Model B", height=400)
        with gr.Row():
            # Buttons for voting
            left_button = gr.Button("👈 Left video looks better", size="lg")
            right_button = gr.Button("Right video looks better 👉", size="lg")

        # Add a hidden passkey input
        with gr.Row():
            with gr.Column(scale=1):
                passkey_input = gr.Textbox(
                    label="",
                    placeholder="",
                    show_label=False,
                    container=False,
                    scale=0.15,
                    min_width=100,
                    elem_classes="invisible-textbox",
                )

        # Rankings section (hidden by default)
        rankings_section = gr.Row(visible=False)
        with rankings_section:
            rankings_table = gr.DataFrame(
                manager.get_rankings(),
                label="Current Model Rankings",
                headers=["Model", "Elo Rating"],
                interactive=False,
            )

        # Hidden download section with file management
        with gr.Row(visible=False) as download_section:
            with gr.Column():
                gr.Markdown("## Download Results File")
                files = manager.get_data_files()
                file_names = [file["name"] for file in files]
                file_select = gr.Dropdown(
                    choices=file_names, label="Select Results File to Download", interactive=True
                )
                download_button = gr.Button("Download Selected File", size="sm")
                download_output = gr.File(label="Download", visible=False)
                gr.Markdown("## Reset Data")
                with gr.Row():
                    reset_button = gr.Button("Reset All Data", size="sm", variant="stop")
                    reset_confirm = gr.Button("Confirm Reset", size="sm", variant="stop", visible=False)
                reset_warning = gr.Markdown(
                    visible=False,
                    value="⚠️ **WARNING**: This will permanently delete all rankings and comparison data. Click 'Confirm Reset' to proceed.",
                )

        def show_reset_warning():
            """Swap the reset button for the warning and the confirm button."""
            return {
                reset_warning: gr.update(visible=True),
                reset_confirm: gr.update(visible=True),
                reset_button: gr.update(visible=False),
            }

        def reset_data(passkey: str):
            """Overwrite the results file with default ratings and no comparisons.

            The passkey is re-checked here: hidden components' events can still be
            triggered directly, so hiding the button alone does not protect it.
            """
            if not _passkey_ok(passkey):
                return {
                    reset_warning: gr.update(visible=False),
                    reset_confirm: gr.update(visible=False),
                    reset_button: gr.update(visible=True),
                }
            try:
                with manager.data_lock.acquire(timeout=10):
                    # Initialize fresh data
                    default_rating = 1000.0
                    elo_ratings = {model["name"]: default_rating for model in manager.models}
                    results = {"comparisons": []}
                    fresh_data = {"elo_ratings": elo_ratings, "results": results}
                    # Write fresh data to file
                    with open(manager.data_file, "w") as f:
                        json.dump(fresh_data, f, indent=2)
                    # Reset manager's data
                    manager.data = fresh_data
                return {
                    reset_warning: gr.update(visible=False),
                    reset_confirm: gr.update(visible=False),
                    reset_button: gr.update(visible=True),
                    rankings_table: manager.get_rankings(),
                }
            except Timeout:
                return {
                    reset_warning: gr.update(value="⚠️ Error: Could not acquire lock to reset data", visible=True),
                    reset_confirm: gr.update(visible=False),
                    reset_button: gr.update(visible=True),
                }

        reset_button.click(fn=show_reset_warning, inputs=[], outputs=[reset_warning, reset_confirm, reset_button])
        reset_confirm.click(
            fn=reset_data,
            inputs=[passkey_input],
            outputs=[reset_warning, reset_confirm, reset_button, rankings_table],
        )

        def check_passkey(passkey: str):
            """Show the admin sections only when the entered passkey is correct."""
            is_visible = _passkey_ok(passkey)
            return [
                gr.Row(visible=is_visible),  # download_section
                gr.File(visible=is_visible),  # download_output
                gr.Row(visible=is_visible),  # rankings_section
            ]

        passkey_input.change(
            fn=check_passkey,
            inputs=[passkey_input],
            outputs=[download_section, download_output, rankings_section],
        )

        def load_new_comparison():
            """Pick a random video and model pair for the comparison outputs.

            Returns:
                (left path, right path, video name, [left model, right model], rankings).
            """
            video_name = random.choice(manager.videos)
            model_pair = manager.get_random_pair()
            video_paths = manager.get_video_paths(video_name, model_pair)
            # Do not assign video_left/video_right.label here. That mutates the
            # component objects shared by every session without updating this
            # session's players. The neutral "Model A"/"Model B" labels also keep
            # the real model names out of view while voting.
            return (
                video_paths[0],  # video_left
                video_paths[1],  # video_right
                video_name,  # current_video
                [model["name"] for model in model_pair],  # current_models
                manager.get_rankings(),  # rankings_table
            )

        def handle_choice(choice_index: int, video_name: str, current_models_value: List[str], current_username: str):
            """Record the vote for side *choice_index* (0 = left, 1 = right), then load a new pair."""
            if not current_models_value or len(current_models_value) < 2:
                print("Error: current_models is invalid:", current_models_value)
                # Nothing valid to record. Load a fresh pair instead of hiding both
                # players, which previously left the session stuck.
                return load_new_comparison()
            winner = current_models_value[choice_index]
            loser = current_models_value[1 - choice_index]
            # Save the comparison result with username
            manager.save_comparison(video_name, winner, loser, current_username)
            return load_new_comparison()

        left_button.click(
            fn=lambda vid, models, user: handle_choice(0, vid, models, user),
            inputs=[current_video, current_models, username],
            outputs=[video_left, video_right, current_video, current_models, rankings_table],
        )
        right_button.click(
            fn=lambda vid, models, user: handle_choice(1, vid, models, user),
            inputs=[current_video, current_models, username],
            outputs=[video_left, video_right, current_video, current_models, rankings_table],
        )

        def start_session():
            """Load the first comparison and assign this session its own username."""
            return (*load_new_comparison(), manager.generate_username())

        demo.load(
            fn=start_session,
            inputs=[],
            outputs=[video_left, video_right, current_video, current_models, rankings_table, username],
        )

        def download_file(file_name: str, passkey: str):
            """Copy the selected data file to a temp file and return its path.

            Only works with the correct passkey, and only for a file currently
            listed in the data directory. This rejects path traversal such as
            "../app.py". Returns None on any refusal or error.
            """
            if not file_name or not _passkey_ok(passkey):
                return None
            if file_name not in {file["name"] for file in manager.get_data_files()}:
                return None
            file_path = os.path.join(manager.data_dir, file_name)
            try:
                with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file_name)[1]) as tmp_file:
                    with open(file_path, "rb") as f:
                        tmp_file.write(f.read())
                    return tmp_file.name
            except Exception as e:
                print(f"Error preparing download: {e}")
                return None

        download_button.click(
            fn=download_file,
            inputs=[file_select, passkey_input],
            outputs=download_output,
        )
    return demo
if __name__ == "__main__":
    # Replaces the hard-coded `local = True` toggle. The default ("./") keeps the
    # previous behaviour; pass e.g. --data-dir /data (the old non-local path) to
    # store results elsewhere.
    parser = argparse.ArgumentParser(description="Launch the Video Model Ranking Arena.")
    parser.add_argument(
        "--data-dir",
        default="./",
        help="Directory holding the results JSON file and its lock file (default: ./).",
    )
    # parse_known_args: tolerate extra flags appended by a launcher instead of exiting.
    args, _unknown = parser.parse_known_args()
    demo = create_arena_interface(data_dir=args.data_dir)
    demo.launch(share=True)