bebechien commited on
Commit
6bd22b5
Β·
verified Β·
1 Parent(s): 06a5269

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. app.py +2 -525
  2. src/session_manager.py +306 -0
  3. src/ui.py +220 -0
  4. src/vibe_logic.py +1 -1
app.py CHANGED
@@ -1,530 +1,7 @@
1
- import gradio as gr
2
- import os
3
- import shutil
4
- import time
5
- import csv
6
- import uuid
7
- from itertools import cycle
8
- from typing import List, Iterable, Tuple, Optional, Callable
9
- from datetime import datetime
10
-
11
- # Import modules
12
- from src.data_fetcher import read_hacker_news_rss, format_published_time
13
- from src.model_trainer import (
14
- authenticate_hf,
15
- train_with_dataset,
16
- get_top_hits,
17
- load_embedding_model,
18
- upload_model_to_hub
19
- )
20
- from src.config import AppConfig
21
- from src.vibe_logic import VibeChecker
22
- from sentence_transformers import SentenceTransformer
23
-
24
- # --- Main Application Class (Session Scoped) ---
25
-
26
- class HackerNewsFineTuner:
27
- """
28
- Encapsulates all application logic and state for a single user session.
29
- """
30
-
31
- def __init__(self, config: AppConfig = AppConfig):
32
- # --- Dependencies ---
33
- self.config = config
34
-
35
- # --- Session Identification ---
36
- self.session_id = str(uuid.uuid4())
37
-
38
- # Define session-specific paths to allow simultaneous training
39
- self.session_root = self.config.ARTIFACTS_DIR / self.session_id
40
- self.output_dir = self.session_root / "embedding_gemma_finetuned"
41
- self.dataset_export_file = self.session_root / "training_dataset.csv"
42
-
43
- # Setup directories
44
- os.makedirs(self.output_dir, exist_ok=True)
45
- print(f"[{self.session_id}] New session started. Artifacts: {self.session_root}")
46
-
47
- # --- Application State ---
48
- self.model: Optional[SentenceTransformer] = None
49
- self.vibe_checker: Optional[VibeChecker] = None
50
- self.titles: List[str] = []
51
- self.target_titles: List[str] = []
52
- self.number_list: List[int] = []
53
- self.last_hn_dataset: List[List[str]] = []
54
- self.imported_dataset: List[List[str]] = []
55
-
56
- # Authenticate once (global)
57
- authenticate_hf(self.config.HF_TOKEN)
58
-
59
- def _update_vibe_checker(self):
60
- """Initializes or updates the VibeChecker with the current model state."""
61
- if self.model:
62
- self.vibe_checker = VibeChecker(
63
- model=self.model,
64
- query_anchor=self.config.QUERY_ANCHOR,
65
- task_name=self.config.TASK_NAME
66
- )
67
- else:
68
- self.vibe_checker = None
69
-
70
- ## Data and Model Management ##
71
-
72
- def refresh_data_and_model(self) -> Tuple[gr.update, gr.update]:
73
- """
74
- Reloads model and fetches data.
75
- """
76
- print(f"[{self.session_id}] Reloading model and data...")
77
-
78
- self.last_hn_dataset = []
79
- self.imported_dataset = []
80
-
81
- # 1. Reload the base embedding model
82
- try:
83
- self.model = load_embedding_model(self.config.MODEL_NAME)
84
- self._update_vibe_checker()
85
- except Exception as e:
86
- error_msg = f"CRITICAL ERROR: Model failed to load. {e}"
87
- print(error_msg)
88
- self.model = None
89
- self._update_vibe_checker()
90
- return (
91
- gr.update(choices=[], label="Model Load Failed"),
92
- gr.update(value=error_msg)
93
- )
94
-
95
- # 2. Fetch fresh news data
96
- news_feed, status_msg = read_hacker_news_rss(self.config)
97
- titles_out, target_titles_out = [], []
98
- status_value: str = f"Ready. Session ID: {self.session_id[:8]}... | Status: {status_msg}"
99
-
100
- if news_feed is not None and news_feed.entries:
101
- titles_out = [item.title for item in news_feed.entries[:self.config.TOP_TITLES_COUNT]]
102
- target_titles_out = [item.title for item in news_feed.entries[self.config.TOP_TITLES_COUNT:]]
103
- else:
104
- titles_out = ["Error fetching news.", "Check console."]
105
- gr.Warning(f"Data reload failed. {status_msg}")
106
-
107
- self.titles = titles_out
108
- self.target_titles = target_titles_out
109
- self.number_list = list(range(len(self.titles)))
110
-
111
- return (
112
- gr.update(
113
- choices=self.titles,
114
- label=f"Hacker News Top {len(self.titles)} (Select your favorites)"
115
- ),
116
- gr.update(value=status_value)
117
- )
118
-
119
- # --- Import Dataset/Export ---
120
- def import_additional_dataset(self, file_path: str) -> str:
121
- if not file_path:
122
- return "Please upload a CSV file."
123
- new_dataset, num_imported = [], 0
124
- try:
125
- with open(file_path, 'r', newline='', encoding='utf-8') as f:
126
- reader = csv.reader(f)
127
- try:
128
- header = next(reader)
129
- if not (header and header[0].lower().strip() == 'anchor'):
130
- f.seek(0)
131
- except StopIteration:
132
- return "Error: Uploaded file is empty."
133
-
134
- for row in reader:
135
- if len(row) == 3:
136
- new_dataset.append([s.strip() for s in row])
137
- num_imported += 1
138
- if num_imported == 0:
139
- raise ValueError("No valid rows found.")
140
- self.imported_dataset = new_dataset
141
- return f"Imported {num_imported} triplets."
142
- except Exception as e:
143
- return f"Import failed: {e}"
144
-
145
- def export_dataset(self) -> Optional[str]:
146
- if not self.last_hn_dataset:
147
- gr.Warning("No dataset generated yet.")
148
- return None
149
-
150
- file_path = self.dataset_export_file
151
- try:
152
- with open(file_path, 'w', newline='', encoding='utf-8') as f:
153
- writer = csv.writer(f)
154
- writer.writerow(['Anchor', 'Positive', 'Negative'])
155
- writer.writerows(self.last_hn_dataset)
156
- gr.Info(f"Dataset exported.")
157
- return str(file_path)
158
- except Exception as e:
159
- gr.Error(f"Export failed: {e}")
160
- return None
161
-
162
- def download_model(self) -> Optional[str]:
163
- if not os.path.exists(self.output_dir):
164
- gr.Warning("No model trained yet.")
165
- return None
166
-
167
- timestamp = int(time.time())
168
- try:
169
- base_name = self.session_root / f"model_finetuned_{timestamp}"
170
- archive_path = shutil.make_archive(
171
- base_name=str(base_name),
172
- format='zip',
173
- root_dir=self.output_dir,
174
- )
175
- gr.Info(f"Model zipped.")
176
- return archive_path
177
- except Exception as e:
178
- gr.Error(f"Zip failed: {e}")
179
- return None
180
-
181
- def upload_model(self, repo_name: str, oauth_token_str: str) -> str:
182
- """
183
- Calls the model trainer upload function using the session's output directory.
184
- """
185
- if not os.path.exists(self.output_dir):
186
- return "❌ Error: No trained model found in this session. Run training first."
187
- if not repo_name.strip():
188
- return "❌ Error: Please specify a repository name."
189
-
190
- return upload_model_to_hub(self.output_dir, repo_name, oauth_token_str)
191
-
192
-
193
- ## Training Logic ##
194
- def _create_hn_dataset(self, selected_ids: List[int]) -> Tuple[List[List[str]], str, str]:
195
- total_ids, selected_ids = set(self.number_list), set(selected_ids)
196
- non_selected_ids = total_ids - selected_ids
197
- is_minority = len(selected_ids) < (len(total_ids) / 2)
198
-
199
- anchor_ids, pool_ids = (non_selected_ids, list(selected_ids)) if is_minority else (selected_ids, list(non_selected_ids))
200
-
201
- def get_titles(anchor_id, pool_id):
202
- return (self.titles[pool_id], self.titles[anchor_id]) if is_minority else (self.titles[anchor_id], self.titles[pool_id])
203
-
204
- if not pool_ids or not anchor_ids:
205
- return [], "", ""
206
-
207
- fav_idx = pool_ids[0] if is_minority else list(anchor_ids)[0]
208
- non_fav_idx = list(anchor_ids)[0] if is_minority else pool_ids[0]
209
-
210
- hn_dataset = []
211
- pool_cycler = cycle(pool_ids)
212
- for anchor_id in sorted(list(anchor_ids)):
213
- fav, non_fav = get_titles(anchor_id, next(pool_cycler))
214
- hn_dataset.append([self.config.QUERY_ANCHOR, fav, non_fav])
215
-
216
- return hn_dataset, self.titles[fav_idx], self.titles[non_fav_idx]
217
-
218
- def training(self, selected_ids: List[int]) -> str:
219
- if self.model is None:
220
- raise gr.Error("Model not loaded.")
221
- if not selected_ids:
222
- raise gr.Error("Select at least one title.")
223
- if len(selected_ids) == len(self.number_list):
224
- raise gr.Error("Cannot select all titles.")
225
-
226
- hn_dataset, _, _ = self._create_hn_dataset(selected_ids)
227
- self.last_hn_dataset = hn_dataset
228
- final_dataset = self.last_hn_dataset + self.imported_dataset
229
-
230
- if not final_dataset:
231
- raise gr.Error("Dataset is empty.")
232
-
233
- def semantic_search_fn() -> str:
234
- return get_top_hits(model=self.model, target_titles=self.target_titles, task_name=self.config.TASK_NAME, query=self.config.QUERY_ANCHOR)
235
-
236
- result = "### Search (Before):\n" + f"{semantic_search_fn()}\n\n"
237
- print(f"[{self.session_id}] Starting Training...")
238
-
239
- train_with_dataset(
240
- model=self.model,
241
- dataset=final_dataset,
242
- output_dir=self.output_dir,
243
- task_name=self.config.TASK_NAME,
244
- search_fn=semantic_search_fn
245
- )
246
-
247
- self._update_vibe_checker()
248
- print(f"[{self.session_id}] Training Complete.")
249
-
250
- result += "### Search (After):\n" + f"{semantic_search_fn()}"
251
- return result
252
-
253
- ## Vibe Check Logic ##
254
- def get_vibe_check(self, news_text: str) -> Tuple[str, str, gr.update]:
255
- info_text = f"**Session:** {self.session_id[:6]}<br>**Model:** `{self.config.MODEL_NAME}`{' - Fine-tuned' if self.last_hn_dataset else ''}"
256
-
257
- if not self.vibe_checker:
258
- return "N/A", "Model Loading...", gr.update(value=self._generate_vibe_html("gray")), info_text
259
- if not news_text or len(news_text.split()) < 3:
260
- return "N/A", "Text too short", gr.update(value=self._generate_vibe_html("white")), info_text
261
-
262
- try:
263
- vibe_result = self.vibe_checker.check(news_text)
264
- status = vibe_result.status_html.split('>')[1].split('<')[0]
265
- return f"{vibe_result.raw_score:.4f}", status, gr.update(value=self._generate_vibe_html(vibe_result.color_hsl)), info_text
266
- except Exception as e:
267
- return "N/A", f"Error: {e}", gr.update(value=self._generate_vibe_html("gray")), info_text
268
-
269
- def _generate_vibe_html(self, color: str) -> str:
270
- return f'<div style="background-color: {color}; height: 100px; border-radius: 12px; border: 2px solid #ccc;"></div>'
271
-
272
- ## Mood Reader Logic ##
273
- def fetch_and_display_mood_feed(self) -> str:
274
- if not self.vibe_checker:
275
- return "Model not ready. Please wait or reload."
276
-
277
- feed, status = read_hacker_news_rss(self.config)
278
- if not feed or not feed.entries:
279
- return f"**Feed Error:** {status}"
280
-
281
- scored_entries = []
282
- for entry in feed.entries:
283
- title = entry.get('title')
284
- if not title: continue
285
-
286
- vibe_result = self.vibe_checker.check(title)
287
- scored_entries.append({
288
- "title": title,
289
- "link": entry.get('link', '#'),
290
- "comments": entry.get('comments', '#'),
291
- "published": format_published_time(entry.published_parsed),
292
- "mood": vibe_result
293
- })
294
-
295
- scored_entries.sort(key=lambda x: x["mood"].raw_score, reverse=True)
296
-
297
- md = (f"## Hacker News Top Stories\n"
298
- f"**Session:** {self.session_id[:6]}<br>"
299
- f"**Model:** `{self.config.MODEL_NAME}`{' - Fine-tuned' if self.last_hn_dataset else ''}<br>"
300
- f"**Updated:** {datetime.now().strftime('%H:%M:%S')}\n\n"
301
- "| Vibe | Score | Title | Comments | Published |\n|---|---|---|---|---|\n")
302
-
303
- for item in scored_entries:
304
- md += (f"| {item['mood'].status_html} "
305
- f"| {item['mood'].raw_score:.4f} "
306
- f"| [{item['title']}]({item['link']}) "
307
- f"| [Comments]({item['comments']}) "
308
- f"| {item['published']} |\n")
309
- return md
310
-
311
-
312
- # --- Session Wrappers ---
313
-
314
- def refresh_wrapper(app):
315
- """
316
- Initializes the session if it's not already created, then runs the refresh.
317
- Returns the app instance to update the State.
318
- """
319
- if app is None or callable(app) or isinstance(app, type):
320
- print("Initializing new HackerNewsFineTuner session...")
321
- app = HackerNewsFineTuner(AppConfig)
322
-
323
- # Run the refresh logic
324
- update1, update2 = app.refresh_data_and_model()
325
-
326
- # Return 3 items: The App Instance (for State), Choice Update, Text Update
327
- return app, update1, update2
328
-
329
- def import_wrapper(app, file):
330
- return app.import_additional_dataset(file)
331
-
332
- def export_wrapper(app):
333
- return app.export_dataset()
334
-
335
- def download_model_wrapper(app):
336
- return app.download_model()
337
-
338
- def push_to_hub_wrapper(app, repo_name, oauth_token: Optional[gr.OAuthToken]):
339
- """
340
- Wrapper for pushing the model to the Hugging Face Hub.
341
- Gradio automatically injects 'oauth_token' if the user is logged in via LoginButton.
342
- """
343
- if oauth_token is None:
344
- return "⚠️ You must be logged in to push to the Hub. Please sign in above."
345
-
346
- # Extract the token string from the OAuthToken object
347
- token_str = oauth_token.token
348
- return app.upload_model(repo_name, token_str)
349
-
350
- def training_wrapper(app, selected_ids):
351
- return app.training(selected_ids)
352
-
353
- def vibe_check_wrapper(app, text):
354
- return app.get_vibe_check(text)
355
-
356
- def mood_feed_wrapper(app):
357
- return app.fetch_and_display_mood_feed()
358
-
359
-
360
- # --- Interface Setup ---
361
-
362
- def build_interface() -> gr.Blocks:
363
- with gr.Blocks(title="EmbeddingGemma Modkit") as demo:
364
- # Initialize state as None. It will be populated by refresh_wrapper on load.
365
- session_state = gr.State()
366
-
367
- with gr.Column():
368
- gr.Markdown("# πŸ€– EmbeddingGemma Modkit: Fine-Tuning and Mood Reader")
369
- gr.Markdown("This project provides a set of tools to fine-tune [EmbeddingGemma](https://huggingface.co/google/embeddinggemma-300m) to understand your personal taste in Hacker News titles and then use it to score and rank new articles based on their \"vibe\". The core idea is to measure the \"vibe\" of a news title by calculating the semantic similarity between its embedding and the embedding of a fixed anchor phrase, **`MY_FAVORITE_NEWS`**.<br>See [README](https://huggingface.co/spaces/google/embeddinggemma-modkit/blob/main/README.md) for more details.")
370
- gr.LoginButton(value="(Optional) Sign in to Hugging Face, if you want to push fine-tuned model to your repo.")
371
-
372
- with gr.Tab("πŸš€ Fine-Tuning & Evaluation"):
373
- with gr.Column():
374
- gr.Markdown("## Fine-Tuning & Semantic Search\nSelect titles to fine-tune the model towards making them more similar to **`MY_FAVORITE_NEWS`**.")
375
- with gr.Row():
376
- favorite_list = gr.CheckboxGroup(choices=[], type="index", label="Hacker News Top Stories", show_select_all=True)
377
- output = gr.Textbox(lines=14, label="Training and Search Results", value="Loading data...")
378
-
379
- with gr.Row():
380
- clear_reload_btn = gr.Button("Clear & Reload")
381
- run_training_btn = gr.Button("πŸš€ Run Fine-Tuning", variant="primary")
382
-
383
- gr.Markdown("--- \n ## Dataset & Model Management")
384
- gr.Markdown("To train on your own data, upload a CSV file with the following columns (no header required, or header ignored if present):\n1. **Anchor**: A fixed anchor phrase, `MY_FAVORITE_NEWS`.\n2. **Positive**: A title or contents that you like.\n3. **Negative**: A title or contents that you don't like.\n\nExample CSV Row:\n```\nMY_FAVORITE_NEWS,What is machine learning?,How to write a compiler from scratch.\n```")
385
- import_file = gr.File(label="Upload Additional Dataset (.csv)", file_types=[".csv"], height=50)
386
-
387
- with gr.Row():
388
- download_dataset_btn = gr.Button("πŸ’Ύ Export Dataset")
389
- download_model_btn = gr.Button("⬇️ Download Fine-Tuned Model")
390
-
391
- download_status = gr.Markdown("Ready.")
392
-
393
- with gr.Row():
394
- dataset_output = gr.File(label="Download Dataset CSV", height=50, visible=False, interactive=False)
395
- model_output = gr.File(label="Download Model ZIP", height=50, visible=False, interactive=False)
396
-
397
- gr.Markdown("### ☁️ Publish to Hugging Face Hub")
398
- with gr.Row():
399
- repo_name_input = gr.Textbox(label="Target Repository Name", placeholder="e.g., my-news-vibe-model")
400
- push_to_hub_btn = gr.Button("Push to Hub", variant="secondary")
401
-
402
- push_status = gr.Markdown("")
403
-
404
- # --- Interactions ---
405
-
406
- # 1. Initial Load: Initialize State and Load Data
407
- demo.load(
408
- fn=refresh_wrapper,
409
- inputs=[session_state],
410
- outputs=[session_state, favorite_list, output]
411
- )
412
-
413
- buttons_to_lock = [
414
- clear_reload_btn,
415
- run_training_btn,
416
- download_dataset_btn,
417
- download_model_btn,
418
- push_to_hub_btn
419
- ]
420
-
421
- # 2. Buttons
422
- clear_reload_btn.click(
423
- fn=lambda: [gr.update(interactive=False)]*len(buttons_to_lock),
424
- outputs=buttons_to_lock
425
- ).then(
426
- fn=refresh_wrapper,
427
- inputs=[session_state],
428
- outputs=[session_state, favorite_list, output]
429
- ).then(
430
- fn=lambda: [gr.update(interactive=True)]*len(buttons_to_lock),
431
- outputs=buttons_to_lock
432
- )
433
-
434
- run_training_btn.click(
435
- fn=lambda: [gr.update(interactive=False)]*len(buttons_to_lock),
436
- outputs=buttons_to_lock
437
- ).then(
438
- fn=training_wrapper,
439
- inputs=[session_state, favorite_list],
440
- outputs=[output]
441
- ).then(
442
- fn=lambda: [gr.update(interactive=True)]*len(buttons_to_lock),
443
- outputs=buttons_to_lock
444
- )
445
-
446
- import_file.change(
447
- fn=import_wrapper,
448
- inputs=[session_state, import_file],
449
- outputs=[download_status]
450
- )
451
-
452
- download_dataset_btn.click(
453
- fn=export_wrapper,
454
- inputs=[session_state],
455
- outputs=[dataset_output]
456
- ).then(
457
- lambda p: gr.update(visible=True) if p else gr.update(), inputs=[dataset_output], outputs=[dataset_output]
458
- )
459
-
460
- download_model_btn.click(
461
- fn=lambda: [gr.update(interactive=False)]*len(buttons_to_lock),
462
- outputs=buttons_to_lock
463
- ).then(
464
- lambda: [gr.update(value=None, visible=False), "Zipping..."], None, [model_output, download_status], queue=False
465
- ).then(
466
- fn=download_model_wrapper,
467
- inputs=[session_state],
468
- outputs=[model_output]
469
- ).then(
470
- lambda p: [gr.update(visible=p is not None, value=p), "ZIP ready." if p else "Zipping failed."], [model_output], [model_output, download_status]
471
- ).then(
472
- fn=lambda: [gr.update(interactive=True)]*len(buttons_to_lock),
473
- outputs=buttons_to_lock
474
- )
475
-
476
- # Push to Hub Interaction
477
- push_to_hub_btn.click(
478
- fn=push_to_hub_wrapper,
479
- inputs=[session_state, repo_name_input],
480
- outputs=[push_status]
481
- )
482
-
483
- with gr.Tab("πŸ“° Hacker News Similarity Check"):
484
- with gr.Column():
485
- gr.Markdown(f"## Live Hacker News Feed Vibe")
486
- gr.Markdown(f"This feed uses the current model (base or fine-tuned) to score the vibe of live Hacker News stories against **`{AppConfig.QUERY_ANCHOR}`**.")
487
- feed_output = gr.Markdown(value="Click 'Refresh Feed' to load stories.", label="Latest Stories")
488
- refresh_button = gr.Button("Refresh Feed πŸ”„", size="lg", variant="primary")
489
- refresh_button.click(fn=mood_feed_wrapper, inputs=[session_state], outputs=feed_output)
490
-
491
- with gr.Tab("πŸ’‘ Similarity Lamp"):
492
- with gr.Column():
493
- gr.Markdown(f"## News Similarity Check")
494
- gr.Markdown(f"Enter text to see its similarity to **`{AppConfig.QUERY_ANCHOR}`**.\n**Vibe Key:** Green = High, Red = Low")
495
- news_input = gr.Textbox(label="Enter News Title or Summary", lines=3)
496
- vibe_check_btn = gr.Button("Check Similarity", variant="primary")
497
-
498
- gr.Examples(
499
- examples=[
500
- "Global Markets Rally as Inflation Data Shows Unexpected Drop for Third Consecutive Month",
501
- "Astronomers Detect Strong Oxygen Signature on Potentially Habitable Exoplanet",
502
- "City Council Approves Controversial Plan to Ban Cars from Downtown District by 2027",
503
- "Tech Giant Unveils Prototype for \"Invisible\" AR Glasses, Promising a Screen-Free Future",
504
- "Local Library Receives Overdue Book Checked Out in 1948 With An Anonymous Apology Note"
505
- ],
506
- inputs=news_input,
507
- label="Try these examples"
508
- )
509
-
510
- session_info_display = gr.Markdown()
511
-
512
- with gr.Row():
513
- vibe_color_block = gr.HTML(value='<div style="background-color: gray; height: 100px;"></div>', label="Mood Lamp")
514
- with gr.Column():
515
- vibe_score = gr.Textbox(label="Score", value="N/A", interactive=False)
516
- vibe_status = gr.Textbox(label="Status", value="...", interactive=False)
517
-
518
- vibe_check_btn.click(
519
- fn=vibe_check_wrapper,
520
- inputs=[session_state, news_input],
521
- outputs=[vibe_score, vibe_status, vibe_color_block, session_info_display]
522
- )
523
-
524
- return demo
525
 
526
  if __name__ == "__main__":
527
  app_demo = build_interface()
528
  print("Starting Multi-User Gradio App...")
529
  app_demo.queue()
530
- app_demo.launch()
 
1
+ from src.ui import build_interface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  if __name__ == "__main__":
4
  app_demo = build_interface()
5
  print("Starting Multi-User Gradio App...")
6
  app_demo.queue()
7
+ app_demo.launch()
src/session_manager.py ADDED
@@ -0,0 +1,306 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import shutil
3
+ import time
4
+ import csv
5
+ import uuid
6
+ from itertools import cycle
7
+ from typing import List, Tuple, Optional
8
+ from datetime import datetime
9
+ import gradio as gr # Needed for gr.update, gr.Warning, gr.Info, gr.Error
10
+
11
+ from .data_fetcher import read_hacker_news_rss, format_published_time
12
+ from .model_trainer import (
13
+ authenticate_hf,
14
+ train_with_dataset,
15
+ get_top_hits,
16
+ load_embedding_model,
17
+ upload_model_to_hub
18
+ )
19
+ from .config import AppConfig
20
+ from .vibe_logic import VibeChecker
21
+ from sentence_transformers import SentenceTransformer
22
+
23
+ class HackerNewsFineTuner:
24
+ """
25
+ Encapsulates all application logic and state for a single user session.
26
+ """
27
+
28
+ def __init__(self, config: AppConfig = AppConfig):
29
+ # --- Dependencies ---
30
+ self.config = config
31
+
32
+ # --- Session Identification ---
33
+ self.session_id = str(uuid.uuid4())
34
+
35
+ # Define session-specific paths to allow simultaneous training
36
+ self.session_root = self.config.ARTIFACTS_DIR / self.session_id
37
+ self.output_dir = self.session_root / "embedding_gemma_finetuned"
38
+ self.dataset_export_file = self.session_root / "training_dataset.csv"
39
+
40
+ # Setup directories
41
+ os.makedirs(self.output_dir, exist_ok=True)
42
+ print(f"[{self.session_id}] New session started. Artifacts: {self.session_root}")
43
+
44
+ # --- Application State ---
45
+ self.model: Optional[SentenceTransformer] = None
46
+ self.vibe_checker: Optional[VibeChecker] = None
47
+ self.titles: List[str] = []
48
+ self.target_titles: List[str] = []
49
+ self.number_list: List[int] = []
50
+ self.last_hn_dataset: List[List[str]] = []
51
+ self.imported_dataset: List[List[str]] = []
52
+
53
+ # Authenticate once (global)
54
+ authenticate_hf(self.config.HF_TOKEN)
55
+
56
+ def _update_vibe_checker(self):
57
+ """Initializes or updates the VibeChecker with the current model state."""
58
+ if self.model:
59
+ self.vibe_checker = VibeChecker(
60
+ model=self.model,
61
+ query_anchor=self.config.QUERY_ANCHOR,
62
+ task_name=self.config.TASK_NAME
63
+ )
64
+ else:
65
+ self.vibe_checker = None
66
+
67
+ ## Data and Model Management ##
68
+
69
+ def refresh_data_and_model(self) -> Tuple[gr.update, gr.update]:
70
+ """
71
+ Reloads model and fetches data.
72
+ """
73
+ print(f"[{self.session_id}] Reloading model and data...")
74
+
75
+ self.last_hn_dataset = []
76
+ self.imported_dataset = []
77
+
78
+ # 1. Reload the base embedding model
79
+ try:
80
+ self.model = load_embedding_model(self.config.MODEL_NAME)
81
+ self._update_vibe_checker()
82
+ except Exception as e:
83
+ error_msg = f"CRITICAL ERROR: Model failed to load. {e}"
84
+ print(error_msg)
85
+ self.model = None
86
+ self._update_vibe_checker()
87
+ return (
88
+ gr.update(choices=[], label="Model Load Failed"),
89
+ gr.update(value=error_msg)
90
+ )
91
+
92
+ # 2. Fetch fresh news data
93
+ news_feed, status_msg = read_hacker_news_rss(self.config)
94
+ titles_out, target_titles_out = [], []
95
+ status_value: str = f"Ready. Session ID: {self.session_id[:8]}... | Status: {status_msg}"
96
+
97
+ if news_feed is not None and news_feed.entries:
98
+ titles_out = [item.title for item in news_feed.entries[:self.config.TOP_TITLES_COUNT]]
99
+ target_titles_out = [item.title for item in news_feed.entries[self.config.TOP_TITLES_COUNT:]]
100
+ else:
101
+ titles_out = ["Error fetching news.", "Check console."]
102
+ gr.Warning(f"Data reload failed. {status_msg}")
103
+
104
+ self.titles = titles_out
105
+ self.target_titles = target_titles_out
106
+ self.number_list = list(range(len(self.titles)))
107
+
108
+ return (
109
+ gr.update(
110
+ choices=self.titles,
111
+ label=f"Hacker News Top {len(self.titles)} (Select your favorites)"
112
+ ),
113
+ gr.update(value=status_value)
114
+ )
115
+
116
+ # --- Import Dataset/Export ---
117
+ def import_additional_dataset(self, file_path: str) -> str:
118
+ if not file_path:
119
+ return "Please upload a CSV file."
120
+ new_dataset, num_imported = [], 0
121
+ try:
122
+ with open(file_path, 'r', newline='', encoding='utf-8') as f:
123
+ reader = csv.reader(f)
124
+ try:
125
+ header = next(reader)
126
+ if not (header and header[0].lower().strip() == 'anchor'):
127
+ f.seek(0)
128
+ except StopIteration:
129
+ return "Error: Uploaded file is empty."
130
+
131
+ for row in reader:
132
+ if len(row) == 3:
133
+ new_dataset.append([s.strip() for s in row])
134
+ num_imported += 1
135
+ if num_imported == 0:
136
+ raise ValueError("No valid rows found.")
137
+ self.imported_dataset = new_dataset
138
+ return f"Imported {num_imported} triplets."
139
+ except Exception as e:
140
+ return f"Import failed: {e}"
141
+
142
+ def export_dataset(self) -> Optional[str]:
143
+ if not self.last_hn_dataset:
144
+ gr.Warning("No dataset generated yet.")
145
+ return None
146
+
147
+ file_path = self.dataset_export_file
148
+ try:
149
+ with open(file_path, 'w', newline='', encoding='utf-8') as f:
150
+ writer = csv.writer(f)
151
+ writer.writerow(['Anchor', 'Positive', 'Negative'])
152
+ writer.writerows(self.last_hn_dataset)
153
+ gr.Info(f"Dataset exported.")
154
+ return str(file_path)
155
+ except Exception as e:
156
+ gr.Error(f"Export failed: {e}")
157
+ return None
158
+
159
+ def download_model(self) -> Optional[str]:
160
+ if not os.path.exists(self.output_dir):
161
+ gr.Warning("No model trained yet.")
162
+ return None
163
+
164
+ timestamp = int(time.time())
165
+ try:
166
+ base_name = self.session_root / f"model_finetuned_{timestamp}"
167
+ archive_path = shutil.make_archive(
168
+ base_name=str(base_name),
169
+ format='zip',
170
+ root_dir=self.output_dir,
171
+ )
172
+ gr.Info(f"Model zipped.")
173
+ return archive_path
174
+ except Exception as e:
175
+ gr.Error(f"Zip failed: {e}")
176
+ return None
177
+
178
+ def upload_model(self, repo_name: str, oauth_token_str: str) -> str:
179
+ """
180
+ Calls the model trainer upload function using the session's output directory.
181
+ """
182
+ if not os.path.exists(self.output_dir):
183
+ return "❌ Error: No trained model found in this session. Run training first."
184
+ if not repo_name.strip():
185
+ return "❌ Error: Please specify a repository name."
186
+
187
+ return upload_model_to_hub(self.output_dir, repo_name, oauth_token_str)
188
+
189
+
190
+ ## Training Logic ##
191
+ def _create_hn_dataset(self, selected_ids: List[int]) -> Tuple[List[List[str]], str, str]:
192
+ total_ids, selected_ids = set(self.number_list), set(selected_ids)
193
+ non_selected_ids = total_ids - selected_ids
194
+ is_minority = len(selected_ids) < (len(total_ids) / 2)
195
+
196
+ anchor_ids, pool_ids = (non_selected_ids, list(selected_ids)) if is_minority else (selected_ids, list(non_selected_ids))
197
+
198
+ def get_titles(anchor_id, pool_id):
199
+ return (self.titles[pool_id], self.titles[anchor_id]) if is_minority else (self.titles[anchor_id], self.titles[pool_id])
200
+
201
+ if not pool_ids or not anchor_ids:
202
+ return [], "", ""
203
+
204
+ fav_idx = pool_ids[0] if is_minority else list(anchor_ids)[0]
205
+ non_fav_idx = list(anchor_ids)[0] if is_minority else pool_ids[0]
206
+
207
+ hn_dataset = []
208
+ pool_cycler = cycle(pool_ids)
209
+ for anchor_id in sorted(list(anchor_ids)):
210
+ fav, non_fav = get_titles(anchor_id, next(pool_cycler))
211
+ hn_dataset.append([self.config.QUERY_ANCHOR, fav, non_fav])
212
+
213
+ return hn_dataset, self.titles[fav_idx], self.titles[non_fav_idx]
214
+
215
+ def training(self, selected_ids: List[int]) -> str:
216
+ if self.model is None:
217
+ raise gr.Error("Model not loaded.")
218
+ if not selected_ids:
219
+ raise gr.Error("Select at least one title.")
220
+ if len(selected_ids) == len(self.number_list):
221
+ raise gr.Error("Cannot select all titles.")
222
+
223
+ hn_dataset, _, _ = self._create_hn_dataset(selected_ids)
224
+ self.last_hn_dataset = hn_dataset
225
+ final_dataset = self.last_hn_dataset + self.imported_dataset
226
+
227
+ if not final_dataset:
228
+ raise gr.Error("Dataset is empty.")
229
+
230
+ def semantic_search_fn() -> str:
231
+ return get_top_hits(model=self.model, target_titles=self.target_titles, task_name=self.config.TASK_NAME, query=self.config.QUERY_ANCHOR)
232
+
233
+ result = "### Search (Before):\n" + f"{semantic_search_fn()}\n\n"
234
+ print(f"[{self.session_id}] Starting Training...")
235
+
236
+ train_with_dataset(
237
+ model=self.model,
238
+ dataset=final_dataset,
239
+ output_dir=self.output_dir,
240
+ task_name=self.config.TASK_NAME,
241
+ search_fn=semantic_search_fn
242
+ )
243
+
244
+ self._update_vibe_checker()
245
+ print(f"[{self.session_id}] Training Complete.")
246
+
247
+ result += "### Search (After):\n" + f"{semantic_search_fn()}"
248
+ return result
249
+
250
+ ## Vibe Check Logic ##
251
+ def get_vibe_check(self, news_text: str) -> Tuple[str, str, gr.update]:
252
+ info_text = f"**Session:** {self.session_id[:6]}<br>**Model:** `{self.config.MODEL_NAME}`{' - Fine-tuned' if self.last_hn_dataset else ''}"
253
+
254
+ if not self.vibe_checker:
255
+ return "N/A", "Model Loading...", gr.update(value=self._generate_vibe_html("gray")), info_text
256
+ if not news_text or len(news_text.split()) < 3:
257
+ return "N/A", "Text too short", gr.update(value=self._generate_vibe_html("white")), info_text
258
+
259
+ try:
260
+ vibe_result = self.vibe_checker.check(news_text)
261
+ status = vibe_result.status_html.split('>')[1].split('<')[0]
262
+ return f"{vibe_result.raw_score:.4f}", status, gr.update(value=self._generate_vibe_html(vibe_result.color_hsl)), info_text
263
+ except Exception as e:
264
+ return "N/A", f"Error: {e}", gr.update(value=self._generate_vibe_html("gray")), info_text
265
+
266
+ def _generate_vibe_html(self, color: str) -> str:
267
+ return f'<div style="background-color: {color}; height: 100px; border-radius: 12px; border: 2px solid #ccc;"></div>'
268
+
269
+ ## Mood Reader Logic ##
270
+ def fetch_and_display_mood_feed(self) -> str:
271
+ if not self.vibe_checker:
272
+ return "Model not ready. Please wait or reload."
273
+
274
+ feed, status = read_hacker_news_rss(self.config)
275
+ if not feed or not feed.entries:
276
+ return f"**Feed Error:** {status}"
277
+
278
+ scored_entries = []
279
+ for entry in feed.entries:
280
+ title = entry.get('title')
281
+ if not title: continue
282
+
283
+ vibe_result = self.vibe_checker.check(title)
284
+ scored_entries.append({
285
+ "title": title,
286
+ "link": entry.get('link', '#'),
287
+ "comments": entry.get('comments', '#'),
288
+ "published": format_published_time(entry.published_parsed),
289
+ "mood": vibe_result
290
+ })
291
+
292
+ scored_entries.sort(key=lambda x: x["mood"].raw_score, reverse=True)
293
+
294
+ md = (f"## Hacker News Top Stories\n"
295
+ f"**Session:** {self.session_id[:6]}<br>"
296
+ f"**Model:** `{self.config.MODEL_NAME}`{' - Fine-tuned' if self.last_hn_dataset else ''}<br>"
297
+ f"**Updated:** {datetime.now().strftime('%H:%M:%S')}\n\n"
298
+ "| Vibe | Score | Title | Comments | Published |\n|---|---|---|---|---|\n")
299
+
300
+ for item in scored_entries:
301
+ md += (f"| {item['mood'].status_html} "
302
+ f"| {item['mood'].raw_score:.4f} "
303
+ f"| [{item['title']}]({item['link']}) "
304
+ f"| [Comments]({item['comments']}) "
305
+ f"| {item['published']} |\n")
306
+ return md
src/ui.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from typing import Optional
3
+ from datetime import datetime
4
+
5
+ from .config import AppConfig
6
+ from .session_manager import HackerNewsFineTuner
7
+
8
+ # --- Session Wrappers ---
9
+
10
def refresh_wrapper(app):
    """
    Lazily create the per-user session object on first call, then refresh it.

    Gradio may hand us an uninitialized State value (None, or the class/factory
    itself rather than an instance); in that case a fresh session is built.
    Returns a 3-tuple (session instance, choices update, text update) so the
    gr.State component is populated alongside the visible UI outputs.
    """
    needs_session = app is None or callable(app) or isinstance(app, type)
    if needs_session:
        print("Initializing new HackerNewsFineTuner session...")
        app = HackerNewsFineTuner(AppConfig)

    # Run the refresh logic
    choices_update, text_update = app.refresh_data_and_model()

    # Return 3 items: The App Instance (for State), Choice Update, Text Update
    return app, choices_update, text_update
24
+
25
def import_wrapper(app, file):
    """Forward an uploaded CSV file to the session's dataset importer."""
    importer = app.import_additional_dataset
    return importer(file)
27
+
28
def export_wrapper(app):
    """Delegate to the session's dataset export (Gradio button callback)."""
    exporter = app.export_dataset
    return exporter()
30
+
31
def download_model_wrapper(app):
    """Delegate to the session's model download/zip step (Gradio button callback)."""
    downloader = app.download_model
    return downloader()
33
+
34
def push_to_hub_wrapper(app, repo_name, oauth_token: Optional[gr.OAuthToken]):
    """
    Push the session's fine-tuned model to the Hugging Face Hub.

    Gradio injects `oauth_token` automatically when the user signed in via
    gr.LoginButton; without it we refuse with a friendly warning instead of
    failing deep inside the upload call.
    """
    if oauth_token is None:
        return "⚠️ You must be logged in to push to the Hub. Please sign in above."

    # Unwrap the raw token string from the OAuthToken object before delegating.
    return app.upload_model(repo_name, oauth_token.token)
45
+
46
def training_wrapper(app, selected_ids):
    """Run the session's fine-tuning on the checkbox indices the user selected."""
    run_training = app.training
    return run_training(selected_ids)
48
+
49
def vibe_check_wrapper(app, text):
    """Score the given text via the session's vibe checker (Gradio callback)."""
    check = app.get_vibe_check
    return check(text)
51
+
52
def mood_feed_wrapper(app):
    """Render the session's scored Hacker News feed (Gradio callback)."""
    fetch_feed = app.fetch_and_display_mood_feed
    return fetch_feed()
54
+
55
+
56
+ # --- Interface Setup ---
57
+
58
def build_interface() -> gr.Blocks:
    """Assemble and return the full Gradio Blocks UI.

    Layout: a shared header (title + optional HF login), then three tabs:
    fine-tuning & dataset management, a live scored Hacker News feed, and a
    single-text "similarity lamp". Every callback receives the per-user
    session object through the `session_state` gr.State component.
    """
    with gr.Blocks(title="EmbeddingGemma Modkit") as demo:
        # Initialize state as None. It will be populated by refresh_wrapper on load.
        session_state = gr.State()

        with gr.Column():
            gr.Markdown("# πŸ€– EmbeddingGemma Modkit: Fine-Tuning and Mood Reader")
            gr.Markdown("This project provides a set of tools to fine-tune [EmbeddingGemma](https://huggingface.co/google/embeddinggemma-300m) to understand your personal taste in Hacker News titles and then use it to score and rank new articles based on their \"vibe\". The core idea is to measure the \"vibe\" of a news title by calculating the semantic similarity between its embedding and the embedding of a fixed anchor phrase, **`MY_FAVORITE_NEWS`**.<br>See [README](https://huggingface.co/spaces/google/embeddinggemma-modkit/blob/main/README.md) for more details.")
            gr.LoginButton(value="(Optional) Sign in to Hugging Face, if you want to push fine-tuned model to your repo.")

        with gr.Tab("πŸš€ Fine-Tuning & Evaluation"):
            with gr.Column():
                gr.Markdown("## Fine-Tuning & Semantic Search\nSelect titles to fine-tune the model towards making them more similar to **`MY_FAVORITE_NEWS`**.")
                with gr.Row():
                    # type="index": callbacks receive selected positions, not labels.
                    favorite_list = gr.CheckboxGroup(choices=[], type="index", label="Hacker News Top Stories", show_select_all=True)
                    output = gr.Textbox(lines=14, label="Training and Search Results", value="Loading data...")

                with gr.Row():
                    clear_reload_btn = gr.Button("Clear & Reload")
                    run_training_btn = gr.Button("πŸš€ Run Fine-Tuning", variant="primary")

                gr.Markdown("--- \n ## Dataset & Model Management")
                gr.Markdown("To train on your own data, upload a CSV file with the following columns (no header required, or header ignored if present):\n1. **Anchor**: A fixed anchor phrase, `MY_FAVORITE_NEWS`.\n2. **Positive**: A title or contents that you like.\n3. **Negative**: A title or contents that you don't like.\n\nExample CSV Row:\n```\nMY_FAVORITE_NEWS,What is machine learning?,How to write a compiler from scratch.\n```")
                import_file = gr.File(label="Upload Additional Dataset (.csv)", file_types=[".csv"], height=50)

                with gr.Row():
                    download_dataset_btn = gr.Button("πŸ’Ύ Export Dataset")
                    download_model_btn = gr.Button("⬇️ Download Fine-Tuned Model")

                download_status = gr.Markdown("Ready.")

                with gr.Row():
                    # Hidden until an artifact actually exists to download.
                    dataset_output = gr.File(label="Download Dataset CSV", height=50, visible=False, interactive=False)
                    model_output = gr.File(label="Download Model ZIP", height=50, visible=False, interactive=False)

                gr.Markdown("### ☁️ Publish to Hugging Face Hub")
                with gr.Row():
                    repo_name_input = gr.Textbox(label="Target Repository Name", placeholder="e.g., my-news-vibe-model")
                    push_to_hub_btn = gr.Button("Push to Hub", variant="secondary")

                push_status = gr.Markdown("")

                # --- Interactions ---

                # 1. Initial Load: Initialize State and Load Data
                demo.load(
                    fn=refresh_wrapper,
                    inputs=[session_state],
                    outputs=[session_state, favorite_list, output]
                )

                # Buttons disabled for the duration of long-running jobs so a
                # user cannot start overlapping reload/train/zip/push operations.
                buttons_to_lock = [
                    clear_reload_btn,
                    run_training_btn,
                    download_dataset_btn,
                    download_model_btn,
                    push_to_hub_btn
                ]

                # 2. Buttons
                # Pattern for each long-running action: lock buttons -> do the
                # work -> unlock buttons, chained via .then().
                clear_reload_btn.click(
                    fn=lambda: [gr.update(interactive=False)]*len(buttons_to_lock),
                    outputs=buttons_to_lock
                ).then(
                    fn=refresh_wrapper,
                    inputs=[session_state],
                    outputs=[session_state, favorite_list, output]
                ).then(
                    fn=lambda: [gr.update(interactive=True)]*len(buttons_to_lock),
                    outputs=buttons_to_lock
                )

                run_training_btn.click(
                    fn=lambda: [gr.update(interactive=False)]*len(buttons_to_lock),
                    outputs=buttons_to_lock
                ).then(
                    fn=training_wrapper,
                    inputs=[session_state, favorite_list],
                    outputs=[output]
                ).then(
                    fn=lambda: [gr.update(interactive=True)]*len(buttons_to_lock),
                    outputs=buttons_to_lock
                )

                import_file.change(
                    fn=import_wrapper,
                    inputs=[session_state, import_file],
                    outputs=[download_status]
                )

                download_dataset_btn.click(
                    fn=export_wrapper,
                    inputs=[session_state],
                    outputs=[dataset_output]
                ).then(
                    # Reveal the file component only when an export path was produced.
                    lambda p: gr.update(visible=True) if p else gr.update(), inputs=[dataset_output], outputs=[dataset_output]
                )

                download_model_btn.click(
                    fn=lambda: [gr.update(interactive=False)]*len(buttons_to_lock),
                    outputs=buttons_to_lock
                ).then(
                    # queue=False so the stale download link clears immediately
                    # while the (slow) zipping step runs next in the chain.
                    lambda: [gr.update(value=None, visible=False), "Zipping..."], None, [model_output, download_status], queue=False
                ).then(
                    fn=download_model_wrapper,
                    inputs=[session_state],
                    outputs=[model_output]
                ).then(
                    lambda p: [gr.update(visible=p is not None, value=p), "ZIP ready." if p else "Zipping failed."], [model_output], [model_output, download_status]
                ).then(
                    fn=lambda: [gr.update(interactive=True)]*len(buttons_to_lock),
                    outputs=buttons_to_lock
                )

                # Push to Hub Interaction
                push_to_hub_btn.click(
                    fn=push_to_hub_wrapper,
                    inputs=[session_state, repo_name_input],
                    outputs=[push_status]
                )

        with gr.Tab("πŸ“° Hacker News Similarity Check"):
            with gr.Column():
                gr.Markdown(f"## Live Hacker News Feed Vibe")
                gr.Markdown(f"This feed uses the current model (base or fine-tuned) to score the vibe of live Hacker News stories against **`{AppConfig.QUERY_ANCHOR}`**.")
                feed_output = gr.Markdown(value="Click 'Refresh Feed' to load stories.", label="Latest Stories")
                refresh_button = gr.Button("Refresh Feed πŸ”„", size="lg", variant="primary")
                refresh_button.click(fn=mood_feed_wrapper, inputs=[session_state], outputs=feed_output)

        with gr.Tab("πŸ’‘ Similarity Lamp"):
            with gr.Column():
                gr.Markdown(f"## News Similarity Check")
                gr.Markdown(f"Enter text to see its similarity to **`{AppConfig.QUERY_ANCHOR}`**.\n**Vibe Key:** Green = High, Red = Low")
                news_input = gr.Textbox(label="Enter News Title or Summary", lines=3)
                vibe_check_btn = gr.Button("Check Similarity", variant="primary")

                gr.Examples(
                    examples=[
                        "Global Markets Rally as Inflation Data Shows Unexpected Drop for Third Consecutive Month",
                        "Astronomers Detect Strong Oxygen Signature on Potentially Habitable Exoplanet",
                        "City Council Approves Controversial Plan to Ban Cars from Downtown District by 2027",
                        "Tech Giant Unveils Prototype for \"Invisible\" AR Glasses, Promising a Screen-Free Future",
                        "Local Library Receives Overdue Book Checked Out in 1948 With An Anonymous Apology Note"
                    ],
                    inputs=news_input,
                    label="Try these examples"
                )

                session_info_display = gr.Markdown()

                with gr.Row():
                    vibe_color_block = gr.HTML(value='<div style="background-color: gray; height: 100px;"></div>', label="Mood Lamp")
                    with gr.Column():
                        vibe_score = gr.Textbox(label="Score", value="N/A", interactive=False)
                        vibe_status = gr.Textbox(label="Status", value="...", interactive=False)

                vibe_check_btn.click(
                    fn=vibe_check_wrapper,
                    inputs=[session_state, news_input],
                    outputs=[vibe_score, vibe_status, vibe_color_block, session_info_display]
                )

    return demo
src/vibe_logic.py CHANGED
@@ -23,7 +23,7 @@ VIBE_THRESHOLDS: List[VibeThreshold] = [
23
  VibeThreshold(score=0.8, status="✨ VIBE:HIGH"),
24
  VibeThreshold(score=0.5, status="πŸ‘ VIBE:GOOD"),
25
  VibeThreshold(score=0.2, status="😐 VIBE:FLAT"),
26
- VibeThreshold(score=0.0, status="πŸ‘Ž VIBE:LOW&nbsp;"), # Base case for scores < 0.2
27
  ]
28
 
29
  # --- Utility Functions ---
 
23
  VibeThreshold(score=0.8, status="✨ VIBE:HIGH"),
24
  VibeThreshold(score=0.5, status="πŸ‘ VIBE:GOOD"),
25
  VibeThreshold(score=0.2, status="😐 VIBE:FLAT"),
26
+ VibeThreshold(score=0.0, status="πŸ‘Ž VIBE:LOW"), # Base case for scores < 0.2
27
  ]
28
 
29
  # --- Utility Functions ---