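"""
LLM Finetuning Arena: a Gradio app that compares two models (typically a
base model vs. a finetuned one) side-by-side on Math and Bio questions,
streaming both answers in parallel via worker threads and a shared queue.
"""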
import gradio as gr
import random
from threading import Thread
from queue import Queue

# Import our new modules
import config
import backend
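
# Interfaces assumed from the local modules, as used in this file:
#   config.DATASET_CONFIG[domain] -> dict with "dataset" plus column names
#       ("question_col", "answer_col", "instruction_col", "input_col")
#   config.ALL_MODELS[domain]     -> {display_name: model_id}
#   config.MY_AUTH_TOKEN          -> auth token string (may be unset)
#   backend.call_modal_api(model_id, prompt, lane) -> generator of tokens
#   backend.router.get_routing_plan(id_1, id_2)    -> (lane_for_m1, lane_for_m2)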

# --- HELPER FUNCTIONS (Unchanged) ---
def get_random_question(domain):
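    """Pick a random (question, answer) pair from the dataset configured for `domain`."""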
    data_conf = config.DATASET_CONFIG[domain]
    dataset = data_conf["dataset"]
    
    if not dataset:
        return "Failed to load dataset.", "N/A"
        
    random_index = random.randint(0, len(dataset) - 1)
    sample = dataset[random_index]
    
    if domain == "Math":
        question = sample[data_conf["question_col"]]
        answer = sample[data_conf["answer_col"]]
    elif domain == "Bio":
        instruction = sample[data_conf["instruction_col"]]
        bio_input = sample[data_conf["input_col"]]
        answer = sample[data_conf["answer_col"]]
        if bio_input and bio_input.strip():
            question = f"**Instruction:**\n{instruction}\n\n**Input:**\n{bio_input}"
        else:
            question = instruction
    else:
        raise ValueError(f"Unknown domain: {domain}")

    return question, answer

def update_domain_settings(domain):
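    """Reset the model dropdowns, load a fresh question, and hide the answer when the domain changes."""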
    models = list(config.ALL_MODELS[domain].keys())
    def_base = next((m for m in models if "Base" in m), models[0])
    def_ft = next((m for m in models if "Finetuned" in m), models[0])
    
    q, a = get_random_question(domain)
    return [
        gr.Dropdown(choices=models, value=def_base),
        gr.Dropdown(choices=models, value=def_ft),
        gr.Textbox(value=q),
        a,
        gr.Markdown(visible=False)
    ]

def load_next_question(domain):
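    """Fetch a new random question for the current domain and hide the previous answer."""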
    q, a = get_random_question(domain)
    return [gr.Textbox(value=q), a, gr.Markdown(visible=False, value="")]

def reveal_answer(hidden_answer):
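    """Reveal the stored ground-truth answer."""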
    return gr.Markdown(value=f"**Ground Truth Answer:**\n\n{hidden_answer}", visible=True)

# --- CORE LOGIC (REBUILT FOR TRUE PARALLEL STREAMING) ---

def stream_to_queue(model_id, prompt, lane, queue, key):
    """

    A worker function that runs in a thread.

    It calls the streaming API and puts tokens into the queue.

    """
    try:
        # call_modal_api is a generator
        for token in backend.call_modal_api(model_id, prompt, lane):
            queue.put((key, token))
    except Exception as e:
        queue.put((key, f"\n\nTHREAD ERROR: {e}"))
    finally:
        # When the stream is done, put a 'None' sentinel
        queue.put((key, None))

def run_comparison(domain, question, model_1_name, model_2_name):
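    """
    Generator that streams both models' answers in parallel.

    Two worker threads push (key, token) pairs onto a shared queue;
    this consumer drains the queue and yields the accumulated text for
    each pane after every token, so Gradio updates both boxes live.
    """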
    # 1. Get IDs
    id_1 = config.ALL_MODELS[domain].get(model_1_name)
    id_2 = config.ALL_MODELS[domain].get(model_2_name)
    
    # 2. Ask the Smart Router
    lane_for_m1, lane_for_m2 = backend.router.get_routing_plan(id_1, id_2)
    
    # 3. Create the Queue and Threads
    q = Queue()
    
    Thread(
        target=stream_to_queue,
        args=(id_1, question, lane_for_m1, q, 'm1'),
        daemon=True,  # don't block interpreter shutdown if a stream hangs
    ).start()

    Thread(
        target=stream_to_queue,
        args=(id_2, question, lane_for_m2, q, 'm2'),
        daemon=True,
    ).start()

    # 4. Listen to the Queue
    text1 = ""
    text2 = ""
    m1_done = False
    m2_done = False
    
    # Clear boxes and start
    yield "", "", gr.Markdown(visible=False)

    while not (m1_done and m2_done):
        # Block until the next token arrives from *either* thread.
        # q.get() with no timeout never raises, so no error handling is needed here.
        key, token = q.get()

        # Check for the 'None' sentinel
        if token is None:
            if key == 'm1':
                m1_done = True
            elif key == 'm2':
                m2_done = True
        else:
            # Append the new token
            if key == 'm1':
                text1 += token
            elif key == 'm2':
                text2 += token
        
        # Yield the updated full text
        yield text1, text2, gr.Markdown(visible=False)


# --- UI BUILD (Unchanged) ---
initial_question, initial_answer = get_random_question("Math")

with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # 🔬 LLM Finetuning Arena
        ### Comparing Finetuned vs. Base Models on Specialized Tasks
        """
    )
    
    hidden_answer_state = gr.State(value=initial_answer)
    
    with gr.Row():
        domain_radio = gr.Radio(
            ["Math", "Bio"], label="1. Select Domain", value="Math"
        )
    
    with gr.Row():
        question_box = gr.Textbox(
            label="2. Question Prompt (Editable)", 
            value=initial_question, lines=5, scale=4
        )
        next_btn = gr.Button("Load Random Question 🔄", scale=1, min_width=100)
        
    with gr.Row():
        model_1_dd = gr.Dropdown(
            label="3. Select Model 1 (Left)",
            choices=list(config.ALL_MODELS["Math"].keys()),
            # Fall back to the first model if no name contains "Base"
            value=next((m for m in config.ALL_MODELS["Math"] if "Base" in m),
                       next(iter(config.ALL_MODELS["Math"])))
        )
        model_2_dd = gr.Dropdown(
            label="4. Select Model 2 (Right)",
            choices=list(config.ALL_MODELS["Math"].keys()),
            value=next((m for m in config.ALL_MODELS["Math"] if "Finetuned" in m),
                       next(iter(config.ALL_MODELS["Math"])))
        )
        
    with gr.Row():
        run_btn = gr.Button("🚀 Run Comparison", variant="primary", scale=3)
        show_answer_btn = gr.Button("Show Ground Truth Answer", scale=1)

    answer_display_box = gr.Markdown(label="Ground Truth Answer", visible=False)
    
    gr.Markdown("---")
    
    with gr.Row():
        output_1_box = gr.Markdown(label="Output: Model 1")
        output_2_box = gr.Markdown(label="Output: Model 2")

    # --- EVENTS (Unchanged) ---
    domain_radio.change(
        fn=update_domain_settings,
        inputs=[domain_radio],
        outputs=[model_1_dd, model_2_dd, question_box, hidden_answer_state, answer_display_box]
    )
    
    next_btn.click(
        fn=load_next_question,
        inputs=[domain_radio],
        outputs=[question_box, hidden_answer_state, answer_display_box]
    )
    
    show_answer_btn.click(
        fn=reveal_answer,
        inputs=[hidden_answer_state],
        outputs=[answer_display_box]
    )
    
    run_btn.click(
        fn=run_comparison,
        inputs=[domain_radio, question_box, model_1_dd, model_2_dd],
        outputs=[output_1_box, output_2_box, answer_display_box]
    )

if __name__ == "__main__":
    if not config.MY_AUTH_TOKEN:
        print("⚠️ WARNING: ARENA_AUTH_TOKEN is not set.")
    demo.launch()