import random
from queue import Queue
from threading import Thread

import gradio as gr

# Import our new modules
import config
import backend


# --- HELPER FUNCTIONS ---

def _default_models(domain):
    """Return ``(choices, default_base, default_finetuned)`` for *domain*.

    The defaults are the first model whose name contains "Base" /
    "Finetuned"; if none matches, the first listed model is used, and if the
    domain has no models at all the defaults are ``None`` (instead of raising).
    """
    models = list(config.ALL_MODELS[domain].keys())
    first = models[0] if models else None
    def_base = next((m for m in models if "Base" in m), first)
    def_ft = next((m for m in models if "Finetuned" in m), first)
    return models, def_base, def_ft


def get_random_question(domain):
    """Pick a random sample from *domain*'s dataset.

    Returns ``(question_markdown, ground_truth_answer)``. When the dataset is
    missing/empty, or the domain has no question formatter here, a
    human-readable message is returned with ``"N/A"`` as the answer.
    """
    data_conf = config.DATASET_CONFIG[domain]
    dataset = data_conf["dataset"]
    if not dataset:
        return "Failed to load dataset.", "N/A"

    sample = dataset[random.randrange(len(dataset))]

    if domain == "Math":
        question = sample[data_conf["question_col"]]
    elif domain == "Bio":
        instruction = sample[data_conf["instruction_col"]]
        bio_input = sample[data_conf["input_col"]]
        # Only show the "Input" section when the sample actually has one.
        if bio_input and bio_input.strip():
            question = f"**Instruction:**\n{instruction}\n\n**Input:**\n{bio_input}"
        else:
            question = instruction
    else:
        # Without this branch `question`/`answer` were unbound (UnboundLocalError).
        return f"Unsupported domain: {domain}", "N/A"

    answer = sample[data_conf["answer_col"]]
    return question, answer


def update_domain_settings(domain):
    """Refresh both model dropdowns and load a new question for *domain*.

    Returns updates for ``[model_1_dd, model_2_dd, question_box,
    hidden_answer_state, answer_display_box]``; the answer box is hidden.
    """
    models, def_base, def_ft = _default_models(domain)
    q, a = get_random_question(domain)
    return [
        gr.Dropdown(choices=models, value=def_base),
        gr.Dropdown(choices=models, value=def_ft),
        gr.Textbox(value=q),
        a,
        gr.Markdown(visible=False, value=""),
    ]


def load_next_question(domain):
    """Load a new random question; hide and clear the ground-truth box."""
    q, a = get_random_question(domain)
    return [gr.Textbox(value=q), a, gr.Markdown(visible=False, value="")]


def reveal_answer(hidden_answer):
    """Show the stored ground-truth answer in the answer box."""
    return gr.Markdown(value=f"**Ground Truth Answer:**\n\n{hidden_answer}", visible=True)


# --- CORE LOGIC (PARALLEL STREAMING) ---

def stream_to_queue(model_id, prompt, lane, queue, key):
    """Worker run in a background thread: forward streamed tokens to *queue*.

    Each chunk from ``backend.call_modal_api`` is put as ``(key, token)``.
    Exceptions are reported in-band as an error string, and a final
    ``(key, None)`` sentinel is always put (``finally``) so the consumer knows
    this stream has ended.
    """
    try:
        # call_modal_api is a generator of text chunks.
        for token in backend.call_modal_api(model_id, prompt, lane):
            if token is None:
                # None is reserved as the end-of-stream sentinel; never forward
                # it, or the consumer would stop reading this stream early.
                continue
            queue.put((key, token))
    except Exception as e:
        queue.put((key, f"\n\nTHREAD ERROR: {e}"))
    finally:
        queue.put((key, None))


def run_comparison(domain, question, model_1_name, model_2_name):
    """Stream two models' answers to *question* side by side.

    Generator for Gradio: yields ``(text_model_1, text_model_2,
    answer_box_update)``, first with both panes cleared and then after every
    received token. The ground-truth box stays hidden while generating.
    """
    domain_models = config.ALL_MODELS[domain]
    id_1 = domain_models.get(model_1_name)
    id_2 = domain_models.get(model_2_name)

    if id_1 is None or id_2 is None:
        # A cleared/stale dropdown would otherwise send a None model id to the
        # router and backend; report it in the affected pane instead.
        yield (
            "" if id_1 is not None else f"⚠️ Unknown model: {model_1_name!r}",
            "" if id_2 is not None else f"⚠️ Unknown model: {model_2_name!r}",
            gr.Markdown(visible=False),
        )
        return

    # Ask the router which lane each model should be served on.
    lane_for_m1, lane_for_m2 = backend.router.get_routing_plan(id_1, id_2)

    # One producer thread per model, both feeding a single queue so tokens
    # from either model are rendered as soon as they arrive. daemon=True so a
    # stuck backend stream cannot keep the process alive at shutdown.
    q = Queue()
    for model_id, lane, key in ((id_1, lane_for_m1, "m1"), (id_2, lane_for_m2, "m2")):
        Thread(
            target=stream_to_queue,
            args=(model_id, question, lane, q, key),
            daemon=True,
        ).start()

    texts = {"m1": "", "m2": ""}
    pending = set(texts)

    # Clear both panes before the first token arrives.
    yield "", "", gr.Markdown(visible=False)

    while pending:
        # Blocking get: every worker is guaranteed to emit its sentinel.
        key, token = q.get()
        if token is None:
            pending.discard(key)
            continue
        texts[key] += token
        yield texts["m1"], texts["m2"], gr.Markdown(visible=False)


# --- UI BUILD ---

initial_question, initial_answer = get_random_question("Math")
initial_models, initial_base, initial_ft = _default_models("Math")

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🔬 LLM Finetuning Arena
        ### Comparing Finetuned vs. Base Models on Specialized Tasks
        """
    )

    hidden_answer_state = gr.State(value=initial_answer)

    with gr.Row():
        domain_radio = gr.Radio(
            ["Math", "Bio"], label="1. Select Domain", value="Math"
        )

    with gr.Row():
        question_box = gr.Textbox(
            label="2. Question Prompt (Editable)",
            value=initial_question,
            lines=5,
            scale=4,
        )
        next_btn = gr.Button("Load Random Question 🔄", scale=1, min_width=100)

    with gr.Row():
        model_1_dd = gr.Dropdown(
            label="3. Select Model 1 (Left)",
            choices=initial_models,
            value=initial_base,
        )
        model_2_dd = gr.Dropdown(
            label="4. Select Model 2 (Right)",
            choices=initial_models,
            value=initial_ft,
        )

    with gr.Row():
        run_btn = gr.Button("🚀 Run Comparison", variant="primary", scale=3)
        show_answer_btn = gr.Button("Show Ground Truth Answer", scale=1)

    answer_display_box = gr.Markdown(label="Ground Truth Answer", visible=False)

    gr.Markdown("---")

    with gr.Row():
        output_1_box = gr.Markdown(label="Output: Model 1")
        output_2_box = gr.Markdown(label="Output: Model 2")

    # --- EVENTS ---
    domain_radio.change(
        fn=update_domain_settings,
        inputs=[domain_radio],
        outputs=[model_1_dd, model_2_dd, question_box, hidden_answer_state, answer_display_box],
    )

    next_btn.click(
        fn=load_next_question,
        inputs=[domain_radio],
        outputs=[question_box, hidden_answer_state, answer_display_box],
    )

    show_answer_btn.click(
        fn=reveal_answer,
        inputs=[hidden_answer_state],
        outputs=[answer_display_box],
    )

    run_btn.click(
        fn=run_comparison,
        inputs=[domain_radio, question_box, model_1_dd, model_2_dd],
        outputs=[output_1_box, output_2_box, answer_display_box],
    )


if __name__ == "__main__":
    if not config.MY_AUTH_TOKEN:
        print("⚠️ WARNING: ARENA_AUTH_TOKEN is not set.")
    demo.launch()