Spaces:
Sleeping
Sleeping
import logging
import os
import random

import gradio as gr
import requests  # HTTP client for the vote / leaderboard backend API
from dotenv import load_dotenv

load_dotenv()
# Model (variant) names that can be paired against each other. Each name is
# also a folder under "3d/" holding one rendered image per object. The list
# is built family by family; its order matters because vote indices
# (current_images) point into it.
# NOTE(review): the "sa30" vs "sd30" prefixes look inconsistent, but the
# names must match folder names on disk, so confirm before renaming.
model_names = (
    [f"dalle_desc_{n}" for n in (25, 50, 100, 150, 250)]
    + [
        f"desc_{desc}_threshold_{thr}"
        for desc in (25, 250)
        for thr in (250, 500, 1000)
    ]
    + [f"jpeg_scale_{s}" for s in (2, 4, 8, 16, 32)]
    + [f"sa30_desc_{n}" for n in (50, 100, 150, 250)]
    + ["sd30_desc_25"]
    + [f"sd35_desc_{n}" for n in (25, 50, 100, 150, 250)]
)

# One path template per model; the "OBJ" placeholder is swapped for an
# object name when a round is drawn.
images = [f"3d/{name}/OBJ.png" for name in model_names]
# Indices into model_names of the pair on screen: [left, right].
current_images = [0, 0]
# Object used in the current voting round; None until the first round.
current_obj = None

# Backend base URL and bearer token, read from the environment (a .env file
# is loaded by load_dotenv() at the top of the file). Either may be None.
BACK_HOST, ACCESS_KEY = (os.getenv(var) for var in ("BACK_HOST", "ACCESS_KEY"))

# Objects a round can be drawn for; each needs 3d/original/<obj>.png and
# 3d/<model>/<obj>.png.
objs = [
    "axe", "barrel", "bed", "bottle", "canon",
    "car", "chair", "chair2", "chair3", "chair4",
]
| def get_new_images(): | |
| global current_images, current_obj | |
| random.seed() | |
| idx1, idx2 = random.sample(range(len(images)), 2) | |
| current_images = [idx1, idx2] | |
| obj = random.choice(objs) | |
| current_obj = obj # store the object for the current round | |
| new_images = [img.replace('OBJ', obj) for img in images] | |
| original = f"3d/original/{obj}.png" | |
| return { | |
| "original": original, | |
| "image1": new_images[idx1], | |
| "image2": new_images[idx2], | |
| "label1": "Left", | |
| "label2": "Right", | |
| "obj": obj # return the object in case it is needed | |
| } | |
def vote_and_randomize(choice, timeout=10):
    """Record a vote for the current pair, then advance to a new round.

    Args:
        choice: ``"left"`` votes for the left image; any other value is
            treated as a vote for the right image.
        timeout: Seconds to wait for the vote API. Without a timeout, a
            stalled backend would block this UI handler indefinitely.

    Returns:
        A 7-tuple in the order of the vote buttons' ``outputs``: the result
        message, the original image path, the left and right image paths,
        the left and right labels, and the refreshed leaderboard rows.

    NOTE(review): the round being voted on lives in module globals
    (``current_images`` / ``current_obj``) that every browser session shares.
    Concurrent users can therefore overwrite each other's round before
    voting. Per-session ``gr.State`` would avoid this, but it requires
    changing the event wiring.
    """
    global current_images, current_obj
    if choice == "left":
        winner_index, loser_index = current_images
    else:
        loser_index, winner_index = current_images
    winner_model = model_names[winner_index]
    loser_model = model_names[loser_index]

    # Use the object drawn for the round the user is currently looking at.
    payload = {
        "winner": winner_model,
        "loser": loser_model,
        "object": current_obj,
    }
    headers = {
        "Authorization": f"Bearer {ACCESS_KEY}",
        "Content-Type": "application/json",
    }

    message = "Error recording vote. Please try again."
    try:
        response = requests.post(
            f"{BACK_HOST}/vote", headers=headers, json=payload, timeout=timeout
        )
        resp_json = response.json()
    except Exception:
        # UI boundary: keep the app responsive, but leave a trace instead
        # of silently discarding the failure.
        logging.exception("Failed to record vote %s > %s", winner_model, loser_model)
    else:
        if (
            isinstance(resp_json, dict)
            and resp_json.get("message") == "Vote recorded successfully"
        ):
            message = f"Thanks for voting for {winner_model}!"
        else:
            logging.warning(
                "Vote not recorded (HTTP %s): %r", response.status_code, resp_json
            )

    new_state = get_new_images()
    updated_leaderboard = get_leaderboard_data()  # refresh leaderboard from API
    return (
        message,
        new_state["original"],
        new_state["image1"],
        new_state["image2"],
        new_state["label1"],
        new_state["label2"],
        updated_leaderboard,
    )
def start_voting():
    """Reveal the voting UI and fill it with the first round.

    Returns:
        A 7-tuple in the order of ``start_btn.click``'s outputs: an update
        hiding the start button, an update showing the voting container,
        then the original, left and right image paths and the left and
        right labels.
    """
    first_round = get_new_images()
    hide_start_button = gr.update(visible=False)
    show_voting_area = gr.update(visible=True)
    round_fields = ("original", "image1", "image2", "label1", "label2")
    return (
        hide_start_button,
        show_voting_area,
        *(first_round[field] for field in round_fields),
    )
def get_leaderboard_data(timeout=10):
    """Fetch leaderboard data from the API and transform it for display.

    Args:
        timeout: Seconds to wait for the backend. This function also runs
            while the UI is being built, so a hang here would block startup.

    Returns:
        One ``[name, elo, ""]`` row per entry of the JSON object returned by
        ``GET {BACK_HOST}/get``, in the order the API sent them. The empty
        third cell fills the table's "Description" column. Returns ``[]`` if
        the request fails, the status is not 200, or the payload is not a
        JSON object.
    """
    headers = {"Authorization": f"Bearer {ACCESS_KEY}"}
    try:
        response = requests.get(f"{BACK_HOST}/get", headers=headers, timeout=timeout)
        if response.status_code != 200:
            logging.warning(
                "Leaderboard request returned HTTP %s", response.status_code
            )
            return []
        data = response.json()
    except Exception:
        # Best effort: an empty table is preferable to crashing the UI.
        logging.exception("Failed to fetch leaderboard")
        return []

    if not isinstance(data, dict):
        logging.warning("Unexpected leaderboard payload: %r", data)
        return []
    logging.debug("leaderboard %s", data)
    return [[name, elo, ""] for name, elo in data.items()]
def refresh_leaderboard():
    """Return freshly fetched leaderboard rows (event-handler wrapper)."""
    rows = get_leaderboard_data()
    return rows
# UI definition. Components render in the order they are created inside the
# nested context managers, and the event listeners below must be registered
# inside the `gr.Blocks` context.
with gr.Blocks(css="""
#main-image {
margin: auto; /* Center the image */
display: block;
}
""") as demo:
    with gr.Tabs() as tabs:  # `tabs` is not referenced anywhere below
        # Tab 1: Voting
        with gr.Tab("Voting"):
            gr.Markdown("### Vote for your favorite option!")
            # Start button.
            # NOTE(review): "start-container" has no rule in the CSS above,
            # so nothing here actually centers it; confirm the intended styling.
            # NOTE(review): Gradio documents `scale` as an integer; 0.5 / 0.25
            # below may trigger a warning. Confirm against the Gradio version used.
            with gr.Column(elem_id="start-container"):
                start_btn = gr.Button("Start!", scale=0.5)
            # Voting interface, hidden until start_voting() makes it visible.
            with gr.Column(visible=False) as voting_container:
                # Reference image on its own row; the #main-image CSS rule
                # above centers it.
                with gr.Row(equal_height=True):
                    main_image = gr.Image(value=None, label="Original", interactive=False, show_download_button=True, elem_id="main-image", scale=0.25)
                # The two candidate renderings side by side.
                with gr.Row():
                    left_image = gr.Image(value=None, label="Left Option", interactive=False, show_download_button=False)
                    right_image = gr.Image(value=None, label="Right Option", interactive=False, show_download_button=False)
                # Vote buttons start with empty captions; start_voting() and
                # vote_and_randomize() return "Left" / "Right" as their values.
                with gr.Row():
                    vote_1 = gr.Button(value="")
                    vote_2 = gr.Button(value="")
                output = gr.Textbox(label="Vote Result", interactive=False)
        # Tab 2: Leaderboard
        with gr.Tab("Leaderboard") as leaderboard_tab:
            gr.Markdown("### Leaderboard")
            # The initial value is fetched from the backend while the UI is
            # being built, i.e. at import time.
            leaderboard_table = gr.DataFrame(
                headers=["Name", "Elo", "Description"],
                value=get_leaderboard_data(),
                interactive=False
            )
            # Manual refresh button
            refresh_btn = gr.Button("Refresh Leaderboard")
    # Start: hide the start button, show the voting UI, load the first round.
    start_btn.click(
        fn=start_voting,
        outputs=[
            start_btn,
            voting_container,
            main_image,
            left_image,
            right_image,
            vote_1,
            vote_2
        ]
    )
    # Vote buttons: record the vote, draw a new round, refresh the leaderboard.
    # The output order must match vote_and_randomize()'s return tuple.
    vote_1.click(
        fn=lambda: vote_and_randomize("left"),
        outputs=[output, main_image, left_image, right_image, vote_1, vote_2, leaderboard_table]
    )
    vote_2.click(
        fn=lambda: vote_and_randomize("right"),
        outputs=[output, main_image, left_image, right_image, vote_1, vote_2, leaderboard_table]
    )
    # Re-fetch the leaderboard when the refresh button is clicked...
    refresh_btn.click(
        fn=refresh_leaderboard,
        outputs=leaderboard_table
    )
    # ...and whenever the Leaderboard tab is selected.
    leaderboard_tab.select(
        fn=refresh_leaderboard,
        outputs=leaderboard_table
    )
# Launch the server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()