File size: 19,512 Bytes
6bd22b5
ad95ef1
6bd22b5
 
 
 
 
ad95ef1
 
 
 
 
6bd22b5
 
 
 
 
 
 
 
ad95ef1
 
 
 
 
 
 
 
fd406c7
 
275b4f2
fd406c7
275b4f2
fd406c7
ad95ef1
fd406c7
ad95ef1
 
fd406c7
 
ad95ef1
fd406c7
 
 
 
 
ad95ef1
e6cb750
 
 
 
6bd22b5
ad95ef1
e6cb750
6bd22b5
 
 
 
 
 
 
 
 
 
e6cb750
6bd22b5
 
 
e6cb750
6bd22b5
ad95ef1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6bd22b5
 
 
 
 
 
 
 
 
 
 
d6c6a2d
6bd22b5
ad95ef1
 
 
 
 
 
6bd22b5
 
d6c6a2d
fc5f2ab
6bd22b5
7a33ddf
ad95ef1
 
 
 
 
 
 
 
 
 
 
 
6bd22b5
e6cb750
ad95ef1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5a87a2
ad95ef1
 
8758212
a5a87a2
ad95ef1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc5f2ab
ad95ef1
 
 
 
 
6bd22b5
 
ad95ef1
 
6bd22b5
ad95ef1
 
 
 
 
6bd22b5
 
e6cb750
 
ad95ef1
72b6692
e6cb750
ad95ef1
 
 
 
 
 
 
 
 
6bd22b5
 
 
 
 
 
 
ad95ef1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e6cb750
 
 
72b6692
e6cb750
72b6692
ad95ef1
e6cb750
ad95ef1
 
 
275b4f2
 
ad95ef1
 
 
 
 
 
275b4f2
 
 
 
ad95ef1
 
 
 
 
 
6bd22b5
ad95ef1
 
 
 
 
 
e6cb750
ad95ef1
 
 
 
 
 
 
e6cb750
ad95ef1
 
 
 
 
 
 
 
 
 
fd406c7
 
 
 
ad95ef1
 
e6cb750
ad95ef1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e24705b
 
ad95ef1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e24705b
 
 
 
ad95ef1
 
e6cb750
ad95ef1
 
 
e6cb750
ad95ef1
 
 
 
e24705b
 
 
 
ad95ef1
e6cb750
ad95ef1
e24705b
 
 
 
 
 
ad95ef1
6bd22b5
7a33ddf
6bd22b5
 
 
 
 
 
 
7a33ddf
6bd22b5
 
69c5213
 
 
 
6bd22b5
 
 
 
 
 
 
 
 
 
 
69c5213
 
 
 
6bd22b5
 
69c5213
 
 
 
 
6bd22b5
 
 
 
69c5213
6bd22b5
 
e6cb750
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
import gradio as gr
from typing import Optional, Dict, List
from datetime import datetime

from .config import AppConfig
from .session_manager import HackerNewsFineTuner

# --- Constants for Labels ---
# The three radio choices a user can give each story, in display order:
# favorite, neutral, dislike. training_wrapper compares against these.
LABEL_FAV, LABEL_NEU, LABEL_DIS = "👍", "😐", "👎"

# --- Session Wrappers ---

def refresh_wrapper(app):
    """Ensure a session object exists, then reload the feed data and model.

    A missing session (``None``), or a placeholder that is a class/other
    callable instead of an instance, is replaced by a new
    ``HackerNewsFineTuner(AppConfig)``.

    Returns:
        ``(app, choices_list, labels, log_update)``. ``choices_list`` and
        ``log_update`` come from ``app.refresh_data_and_model()``; ``labels``
        is always a fresh empty dict, discarding any previous user ratings.
    """
    needs_new_session = app is None or callable(app) or isinstance(app, type)
    if needs_new_session:
        print("Initializing new HackerNewsFineTuner session...")
        app = HackerNewsFineTuner(AppConfig)

    # choices_list is a plain list of story titles.
    titles, log_update = app.refresh_data_and_model()
    return app, titles, {}, log_update

def update_hub_interactive(app, username: Optional[str] = None):
    """Compute interactivity updates for the Hub repo-name box and push button.

    The repo-name textbox is enabled whenever ``username`` is not ``None``.
    The push button additionally requires a session whose
    ``last_hn_dataset`` is truthy.

    Returns:
        ``(textbox_update, button_update)`` as two ``gr.update`` objects.
    """
    signed_in = username is not None
    # Evaluated unconditionally (as before), independent of sign-in state.
    has_data = False if app is None else bool(app.last_hn_dataset)

    textbox_update = gr.update(interactive=signed_in)
    button_update = gr.update(interactive=signed_in and has_data)
    return textbox_update, button_update

def on_app_load(app, profile: Optional[gr.OAuthProfile] = None):
    """Page-load handler: create/refresh the session and set up Hub controls.

    Args:
        app: Current session value (may be uninitialized).
        profile: OAuth profile of the signed-in user, or ``None``.

    Returns:
        Seven values: the session, stories, labels, log text, repo-name
        textbox update, push-button update, and the username (or ``None``).
    """
    app, stories, labels, text_update = refresh_wrapper(app)

    username = None
    if profile:
        username = profile.username

    repo_update, push_update = update_hub_interactive(app, username)
    return (
        app,
        stories,
        labels,
        text_update,
        repo_update,
        push_update,
        username,
    )

def update_repo_preview(username, repo_name):
    """Return the Markdown preview of the target Hub repository path.

    Args:
        username: Signed-in Hugging Face username, or a falsy value when
            nobody is signed in.
        repo_name: Repository name typed by the user; may be ``None``, empty,
            or padded with whitespace.

    Returns:
        A sign-in warning when ``username`` is falsy. Otherwise
        ``Target Repository: **`username/repo`**``, where ``repo`` is the
        stripped name, or ``...`` if nothing remains after stripping.
    """
    if not username:
        return "⚠️ Sign in to see the target repository path."

    # Strip first, then fall back: a whitespace-only name must not render as
    # "username/" with an empty repo segment.
    clean_repo = (repo_name or "").strip() or "..."
    return f"Target Repository: **`{username}/{clean_repo}`**"

def import_wrapper(app, file):
    """Forward the value of the CSV upload component to the session's importer.

    The return value of ``app.import_additional_dataset`` is passed through
    unchanged (the UI shows it in the download status area).
    """
    status = app.import_additional_dataset(file)
    return status

def export_wrapper(app):
    """Delegate to ``app.export_dataset()`` and return its result unchanged."""
    exported = app.export_dataset()
    return exported

def download_model_wrapper(app):
    """Delegate to ``app.download_model()`` and return its result unchanged."""
    archive = app.download_model()
    return archive

def push_to_hub_wrapper(app, repo_name, oauth_token: Optional[gr.OAuthToken]):
    """Upload the session's model to the Hub under the signed-in account.

    Args:
        app: The session object; its ``upload_model(repo_name, token)`` is used.
        repo_name: Repository name from the textbox; surrounding whitespace
            is removed before use.
        oauth_token: OAuth token injected by Gradio (via the type hint), or
            ``None`` when the user is not signed in.

    Returns:
        A Markdown status string: a warning if not signed in or the name is
        empty, otherwise whatever ``app.upload_model`` returns.
    """
    if oauth_token is None:
        return "⚠️ You must be logged in to push to the Hub. Please sign in above."

    # Normalize like update_repo_preview does, so the pushed repo matches the
    # previewed path and an empty/whitespace name is rejected up front.
    clean_repo = (repo_name or "").strip()
    if not clean_repo:
        return "⚠️ Please enter a target repository name."

    return app.upload_model(clean_repo, oauth_token.token)

def training_wrapper(app, stories: List[str], labels: Dict[int, str]):
    """Split rated stories into positive/negative indices and start training.

    Args:
        app: The session object; its ``training(pos_ids, neg_ids)`` is called.
        stories: List of story titles (``None`` is treated as empty).
        labels: ``{index: LABEL_FAV | LABEL_NEU | LABEL_DIS}``; stories with no
            entry count as neutral (``None`` is treated as empty).

    Returns:
        Whatever ``app.training`` returns.
    """
    stories = stories or []
    labels = labels or {}

    pos_ids: List[int] = []
    neg_ids: List[int] = []
    # Only indices of existing stories are considered; stray label keys beyond
    # the story list are ignored.
    for idx, _title in enumerate(stories):
        label = labels.get(idx, LABEL_NEU)
        if label == LABEL_FAV:
            pos_ids.append(idx)
        elif label == LABEL_DIS:
            neg_ids.append(idx)

    return app.training(pos_ids, neg_ids)

def vibe_check_wrapper(app, text):
    """Score ``text`` with ``app.get_vibe_check`` and return its result as-is.

    NOTE(review): the UI binds this to four outputs (score, status, style
    HTML, session info), so the result is presumably a 4-tuple — confirm.
    """
    result = app.get_vibe_check(text)
    return result

def mood_feed_wrapper(app):
    """Delegate to ``app.fetch_and_display_mood_feed()`` and return its result."""
    feed_markdown = app.fetch_and_display_mood_feed()
    return feed_markdown


# --- Interface Setup ---

def build_interface() -> gr.Blocks:
    """Build the Gradio Blocks UI for the EmbeddingGemma Tuning Lab.

    Three tabs are created:

    * "Train & Export": sign-in, data selection (live HN feed or CSV upload),
      fine-tuning, pushing to the Hub and downloading artifacts.
    * "Live Ranked Feed": scores live Hacker News stories with the session model.
    * "Vibe Check Playground": scores free text against the anchor phrase.

    Per-session state is kept in ``gr.State`` objects: ``session_state`` holds
    the tuner session (created by ``refresh_wrapper`` when missing),
    ``username_state`` the signed-in username, ``stories_state`` the fetched
    titles, ``labels_state`` a ``{index: label}`` dict, and ``reset_counter``
    a counter whose only job is to force ``render_story_list`` to re-run.

    Returns:
        The configured (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(title="EmbeddingGemma Tuning Lab") as demo:
        session_state = gr.State()
        username_state = gr.State()

        # State variables for the Feed List and User Choices
        stories_state = gr.State([])
        labels_state = gr.State({})
        # Bumped whenever the radio list must be rebuilt with neutral values.
        reset_counter = gr.State(0)

        with gr.Column():
            gr.Markdown("# 🤖 EmbeddingGemma Tuning Lab: Fine-Tuning and Mood Reader")
            gr.Markdown("This project provides a set of tools to fine-tune [EmbeddingGemma](https://huggingface.co/google/embeddinggemma-300m) to understand your personal taste in Hacker News titles and then use it to score and rank new articles based on their \"vibe\". The core idea is to measure the \"vibe\" of a news title by calculating the semantic similarity between its embedding and the embedding of a fixed anchor phrase, **`MY_FAVORITE_NEWS`**.<br>See [README](https://huggingface.co/spaces/google/embeddinggemma-tuning-lab/blob/main/README.md) for more details.")

        with gr.Tab("⚙️ Train & Export"):

            # --- Model Indicator ---
            gr.Dropdown(
                choices=[f"{AppConfig.MODEL_NAME}"],
                value=f"{AppConfig.MODEL_NAME}",
                label="Base Model for Fine-tuning",
                interactive=False
            )

            # --- Step 0: Login ---
            with gr.Accordion("0️⃣ Step 0: Sign In (Optional)", open=True):
                gr.Markdown("Sign in to Hugging Face if you plan to push your fine-tuned model to the Hub later (Step 3).")
                with gr.Row():
                    gr.LoginButton(value="Sign in with Hugging Face")
                    with gr.Column(scale=3):
                        gr.Markdown("")

            # --- Step 1: Data Selection ---
            with gr.Accordion("1️⃣ Step 1: Select Data Source", open=True):
                gr.Markdown("Select titles from the live Hacker News feed **OR** upload your own CSV dataset to prepare your training data.")

                with gr.Column():
                    # Option A: Live Feed (Radio List)
                    with gr.Accordion("Option A: Live Hacker News Feed", open=True):
                        gr.Markdown("Rate the stories below to define your vibe.\n\n**⚠️ Note: You must select at least one Favorite and one Dislike to run training.**")

                        with gr.Row():
                            reset_all_btn = gr.Button("Reset Selection ↺", variant="secondary", scale=1)
                            with gr.Column(scale=3):
                                gr.Markdown("")

                        # Dynamic rendering of the story list. It re-runs when
                        # stories_state or reset_counter changes.
                        @gr.render(inputs=[stories_state, reset_counter])
                        def render_story_list(stories, _counter):
                            if not stories:
                                gr.Markdown("*No stories loaded. Click 'Reset Model & Fine-tuning state' to fetch data.*")
                                return

                            # Only the first 10 stories get a rating widget; stories
                            # without a label are treated as neutral during training.
                            for i, title in enumerate(stories[:10]):
                                with gr.Row(variant="compact", elem_id=f"story_row_{i}"):
                                    # Title
                                    with gr.Column(scale=2):
                                        gr.Markdown(f"{title}")

                                    # Radio Selection
                                    radio = gr.Radio(
                                        choices=[LABEL_FAV, LABEL_NEU, LABEL_DIS],
                                        value=LABEL_NEU,
                                        show_label=False,
                                        container=False,
                                        min_width=80,
                                        scale=1,
                                        interactive=True
                                    )

                                    # `idx=i` binds the current index now; a plain
                                    # closure would see the loop's last `i`.
                                    def update_label(new_val, current_labels, idx=i):
                                        # Build a new dict rather than mutating the
                                        # state value in place.
                                        updated = dict(current_labels or {})
                                        updated[idx] = new_val
                                        return updated

                                    radio.change(
                                        fn=update_label,
                                        inputs=[radio, labels_state],
                                        outputs=[labels_state]
                                    )

                    # Option B: Upload
                    with gr.Accordion("Option B: Upload Custom Dataset", open=False):
                        gr.Markdown("Upload a CSV file with columns (no header required, or header ignored if present): `Anchor`, `Positive`, `Negative`.")
                        gr.Markdown("See also: [example_training_dataset.csv](https://huggingface.co/spaces/google/embeddinggemma-tuning-lab/blob/main/example_training_dataset.csv)<br>Example:<br>`MY_FAVORITE_NEWS,Good Title,Bad Title`")
                        import_file = gr.File(label="Upload Additional Dataset (.csv)", file_types=[".csv"], height=100)

            # --- Step 2: Training ---
            with gr.Accordion("2️⃣ Step 2: Run Tuning", open=True):
                gr.Markdown("Fine-tune the model using the data selected or uploaded above.")

                with gr.Row():
                    run_training_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary", scale=1)
                    clear_reload_btn = gr.Button("Reset Model & Fine-tuning state", scale=1)

                output = gr.Textbox(lines=10, label="Training Logs & Search Results", value="Waiting to start...", autoscroll=True)

            # --- Step 3: Push to Hub ---
            with gr.Accordion("3️⃣ Step 3: Save to Hugging Face Hub (Optional)", open=False):
                gr.Markdown("Push your fine-tuned model to your personal Hugging Face account.")

                with gr.Row():
                    repo_name_input = gr.Textbox(label="Target Repository Name", value="my-embeddinggemma-news-vibe", placeholder="e.g., my-embeddinggemma-news-vibe", interactive=False)
                    push_to_hub_btn = gr.Button("Save to Hugging Face Hub", variant="secondary", interactive=False)

                repo_id_preview = gr.Markdown("Target Repository: (Waiting for input...)")

                push_status = gr.Markdown("")

            # --- Step 4: Downloads ---
            with gr.Accordion("4️⃣ Step 4: Download Artifacts", open=False):
                gr.Markdown("Export your combined dataset or download the fine-tuned model locally.")

                with gr.Row():
                    download_dataset_btn = gr.Button("💾 Export Dataset", interactive=False)
                    download_model_btn = gr.Button("⬇️ Download Model ZIP", interactive=False)

                download_status = gr.Markdown("Ready.")

                with gr.Row():
                    dataset_output = gr.File(label="Download Dataset CSV", height=50, visible=False, interactive=False)
                    model_output = gr.File(label="Download Model ZIP", height=50, visible=False, interactive=False)

            # --- Interaction Logic ---

            action_buttons = [
                clear_reload_btn,
                run_training_btn,
                download_dataset_btn,
                download_model_btn
            ]

            def set_interactivity(interactive: bool):
                """Helper to lock/unlock all main action buttons."""
                return [gr.update(interactive=interactive) for _ in action_buttons]

            def reset_all_selections(counter):
                """Bump the render counter (rebuilding radios at neutral) and clear labels.

                Returns: (incremented counter, empty dict for labels)
                """
                return counter + 1, {}

            # 1. App Startup
            # ----------------
            demo.load(
                fn=lambda: set_interactivity(False), outputs=action_buttons
            ).then(
                fn=on_app_load,
                inputs=[session_state],
                outputs=[session_state, stories_state, labels_state, output, repo_name_input, push_to_hub_btn, username_state]
            ).then(
                fn=update_repo_preview,
                inputs=[username_state, repo_name_input],
                outputs=[repo_id_preview]
            ).then(
                fn=lambda: [gr.update(interactive=True)]*2, outputs=[clear_reload_btn, run_training_btn]
            )

            # 2. Reset / Refresh / Clear Selections
            # ----------------
            clear_reload_btn.click(
                fn=lambda: set_interactivity(False), outputs=action_buttons
            ).then(
                fn=lambda: gr.update(interactive=False), outputs=push_to_hub_btn
            ).then(
                fn=refresh_wrapper,
                inputs=[session_state],
                outputs=[session_state, stories_state, labels_state, output]
            ).then(
                # labels_state is now empty, but if the refreshed titles equal the
                # previous ones the render may not re-run and the radios could keep
                # showing stale choices. Bumping the counter forces a rebuild.
                fn=reset_all_selections,
                inputs=[reset_counter],
                outputs=[reset_counter, labels_state]
            ).then(
                fn=lambda: [gr.update(interactive=True)]*2, outputs=[clear_reload_btn, run_training_btn]
            ).then(
                fn=update_hub_interactive,
                inputs=[session_state, username_state],
                outputs=[repo_name_input, push_to_hub_btn]
            )

            # Reset Selection Button Logic
            reset_all_btn.click(
                fn=reset_all_selections,
                inputs=[reset_counter],
                outputs=[reset_counter, labels_state]
            )

            # 3. Import Data
            # ----------------
            import_file.change(
                fn=import_wrapper,
                inputs=[session_state, import_file],
                outputs=[download_status]
            )

            # 4. Run Training
            # ----------------
            run_training_btn.click(
                fn=lambda: set_interactivity(False), outputs=action_buttons
            ).then(
                # Also lock the push button while training, like the other
                # long-running chains do.
                fn=lambda: gr.update(interactive=False), outputs=push_to_hub_btn
            ).then(
                fn=training_wrapper,
                inputs=[session_state, stories_state, labels_state],
                outputs=[output]
            ).then(
                # Unlock all buttons (including downloads now that we have a model)
                fn=lambda: set_interactivity(True), outputs=action_buttons
            ).then(
                fn=update_hub_interactive,
                inputs=[session_state, username_state],
                outputs=[repo_name_input, push_to_hub_btn]
            )

            # 5. Downloads
            # ----------------
            download_dataset_btn.click(
                fn=export_wrapper,
                inputs=[session_state],
                outputs=[dataset_output]
            ).then(
                # Just show the file output if it exists
                lambda p: gr.update(visible=True) if p else gr.update(),
                inputs=[dataset_output],
                outputs=[dataset_output]
            )

            download_model_btn.click(
                # Lock UI
                fn=lambda: set_interactivity(False), outputs=action_buttons
            ).then(
                fn=lambda: gr.update(interactive=False), outputs=push_to_hub_btn
            ).then(
                # Reset previous outputs and show "Zipping..."
                fn=lambda: [gr.update(value=None, visible=False), "⏳ Zipping model..."],
                outputs=[model_output, download_status]
            ).then(
                # Generate Zip
                fn=download_model_wrapper,
                inputs=[session_state],
                outputs=[model_output]
            ).then(
                # Update UI with result
                fn=lambda p: [gr.update(visible=p is not None, value=p), "✅ ZIP ready." if p else "❌ Zipping failed."],
                inputs=[model_output],
                outputs=[model_output, download_status]
            ).then(
                # Unlock UI
                fn=lambda: set_interactivity(True), outputs=action_buttons
            ).then(
                fn=update_hub_interactive,
                inputs=[session_state, username_state],
                outputs=[repo_name_input, push_to_hub_btn]
            )

            # 6. Push to Hub
            # ----------------
            repo_name_input.change(
                fn=update_repo_preview,
                inputs=[username_state, repo_name_input],
                outputs=[repo_id_preview]
            )

            push_to_hub_btn.click(
                fn=lambda: set_interactivity(False), outputs=action_buttons
            ).then(
                fn=lambda: gr.update(interactive=False), outputs=push_to_hub_btn
            ).then(
                # oauth_token is injected by Gradio from push_to_hub_wrapper's type hint.
                fn=push_to_hub_wrapper,
                inputs=[session_state, repo_name_input],
                outputs=[push_status]
            ).then(
                fn=lambda: set_interactivity(True), outputs=action_buttons
            ).then(
                fn=update_hub_interactive,
                inputs=[session_state, username_state],
                outputs=[repo_name_input, push_to_hub_btn]
            )

        with gr.Tab("📰 Live Ranked Feed"):
            with gr.Column():
                gr.Markdown("## Live Hacker News Feed Vibe")
                gr.Markdown(f"This feed uses the current model (base or fine-tuned) to score the vibe of live Hacker News stories against **`{AppConfig.QUERY_ANCHOR}`**.")
                feed_output = gr.Markdown(value="Click 'Refresh Feed' to load stories.", label="Latest Stories")
                refresh_button = gr.Button("Refresh Feed 🔄", size="lg", variant="primary")
                refresh_button.click(fn=mood_feed_wrapper, inputs=[session_state], outputs=feed_output)

        with gr.Tab("🧪 Vibe Check Playground"):
            with gr.Column():
                gr.Markdown("## News Similarity Check")
                gr.Markdown(f"Enter text to see its similarity to **`{AppConfig.QUERY_ANCHOR}`**.<br>**Vibe Key:** <span style='color:green'>Green = High</span>, <span style='color:yellow'>Yellow = Neutral</span>, <span style='color:red'>Red = Low</span>")

                # Created unrendered so gr.Examples can reference it before it is
                # placed in the layout below the examples.
                news_input = gr.Textbox(label="Enter News Title or Summary", lines=3, render=False)

                gr.Examples(
                    examples=[
                        "Global Markets Rally as Inflation Data Shows Unexpected Drop for Third Consecutive Month",
                        "Astronomers Detect Strong Oxygen Signature on Potentially Habitable Exoplanet",
                        "City Council Approves Controversial Plan to Ban Cars from Downtown District by 2027",
                        "Tech Giant Unveils Prototype for \"Invisible\" AR Glasses, Promising a Screen-Free Future",
                        "Local Library Receives Overdue Book Checked Out in 1948 With An Anonymous Apology Note"
                    ],
                    inputs=news_input,
                    label="Try these examples"
                )

                news_input.render()
                vibe_check_btn = gr.Button("Check Similarity", variant="primary")

                session_info_display = gr.Markdown()

                with gr.Column():
                    vibe_score = gr.Textbox(label="Score", value="N/A", interactive=False)
                    vibe_lamp = gr.Textbox(label="Mood Lamp", max_lines=1, elem_id="mood_lamp", interactive=False)
                    vibe_status = gr.Textbox(label="Status", value="...", interactive=False)
                    # Injected <style> that colors the #mood_lamp textbox; replaced
                    # by the vibe check result.
                    style_html = gr.HTML(value="<style>#mood_lamp input {background-color: gray;}</style>")

                vibe_check_btn.click(
                    fn=vibe_check_wrapper,
                    inputs=[session_state, news_input],
                    outputs=[vibe_score, vibe_status, style_html, session_info_display]
                )

    return demo