bebechien commited on
Commit
bb18a0f
·
verified ·
1 Parent(s): 96475d9

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +217 -171
app.py CHANGED
@@ -3,6 +3,7 @@ import os
3
  import shutil
4
  import time
5
  import csv
 
6
  from itertools import cycle
7
  from typing import List, Iterable, Tuple, Optional, Callable
8
  from datetime import datetime
@@ -19,46 +20,47 @@ from config import AppConfig
19
  from vibe_logic import VibeChecker
20
  from sentence_transformers import SentenceTransformer
21
 
22
- # --- Main Application Class ---
23
 
24
  class HackerNewsFineTuner:
25
  """
26
- Encapsulates all application logic and state for the Gradio interface.
27
- Manages the embedding model, news data, and training datasets.
28
  """
29
 
30
  def __init__(self, config: AppConfig = AppConfig):
31
  # --- Dependencies ---
32
  self.config = config
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
  # --- Application State ---
35
  self.model: Optional[SentenceTransformer] = None
36
  self.vibe_checker: Optional[VibeChecker] = None
37
- self.titles: List[str] = [] # Top titles for user selection
38
- self.target_titles: List[str] = [] # Remaining titles for semantic search target pool
39
- self.number_list: List[int] = [] # [0, 1, 2, ...] for checkbox indexing
40
- self.last_hn_dataset: List[List[str]] = [] # Last generated dataset from HN selection
41
- self.imported_dataset: List[List[str]] = [] # Manually imported dataset
42
-
43
- # Setup
44
- os.makedirs(self.config.ARTIFACTS_DIR, exist_ok=True)
45
- print(f"Created artifact directory: {self.config.ARTIFACTS_DIR}")
46
-
47
- authenticate_hf(self.config.HF_TOKEN)
48
 
49
- # Load initial data on startup
50
- self._initial_load()
51
 
52
- def _initial_load(self):
53
- """Helper to run the refresh function once at startup."""
54
- print("--- Running Initial Data Load ---")
55
- self.refresh_data_and_model()
56
- print("--- Initial Load Complete ---")
57
 
58
  def _update_vibe_checker(self):
59
  """Initializes or updates the VibeChecker with the current model state."""
60
  if self.model:
61
- print("Updating VibeChecker instance with the current model.")
62
  self.vibe_checker = VibeChecker(
63
  model=self.model,
64
  query_anchor=self.config.QUERY_ANCHOR,
@@ -71,14 +73,10 @@ class HackerNewsFineTuner:
71
 
72
  def refresh_data_and_model(self) -> Tuple[gr.update, gr.update]:
73
  """
74
- 1. Reloads the embedding model to clear fine-tuning.
75
- 2. Fetches fresh news data (from cache or web).
76
- 3. Updates the class state and returns Gradio updates for the UI.
77
  """
78
- print("\n" + "=" * 50)
79
- print("RELOADING MODEL and RE-FETCHING DATA")
80
 
81
- # Reset dataset state
82
  self.last_hn_dataset = []
83
  self.imported_dataset = []
84
 
@@ -87,33 +85,32 @@ class HackerNewsFineTuner:
87
  self.model = load_embedding_model(self.config.MODEL_NAME)
88
  self._update_vibe_checker()
89
  except Exception as e:
90
- gr.Error(f"Model load failed: {e}")
 
91
  self.model = None
92
  self._update_vibe_checker()
93
  return (
94
  gr.update(choices=[], label="Model Load Failed"),
95
- gr.update(value=f"CRITICAL ERROR: Model failed to load. {e}")
96
  )
97
 
98
  # 2. Fetch fresh news data
 
99
  news_feed, status_msg = read_hacker_news_rss(self.config)
100
  titles_out, target_titles_out = [], []
101
- status_value: str = f"Model and data reloaded. Status: {status_msg}. Click 'Run Fine-Tuning' to begin."
102
 
103
  if news_feed is not None and news_feed.entries:
104
- # Use constant for clarity
105
  titles_out = [item.title for item in news_feed.entries[:self.config.TOP_TITLES_COUNT]]
106
  target_titles_out = [item.title for item in news_feed.entries[self.config.TOP_TITLES_COUNT:]]
107
- print(f"Data reloaded: {len(titles_out)} selection titles, {len(target_titles_out)} target titles.")
108
  else:
109
- titles_out = ["Error fetching news, check console.", "Could not load feed.", "No data available."]
110
- gr.Warning(f"Data reload failed. Using error placeholders. Details: {status_msg}")
111
 
112
  self.titles = titles_out
113
  self.target_titles = target_titles_out
114
  self.number_list = list(range(len(self.titles)))
115
 
116
- # Return Gradio updates for CheckboxGroup and Textbox
117
  return (
118
  gr.update(
119
  choices=self.titles,
@@ -142,55 +139,52 @@ class HackerNewsFineTuner:
142
  new_dataset.append([s.strip() for s in row])
143
  num_imported += 1
144
  if num_imported == 0:
145
- raise ValueError("No valid [Anchor, Positive, Negative] rows found in the CSV.")
146
  self.imported_dataset = new_dataset
147
- return f"Successfully imported {num_imported} additional training triplets."
148
  except Exception as e:
149
- gr.Error(f"Import failed. Ensure the CSV format is: [Anchor, Positive, Negative]. Error: {e}")
150
- return "Import failed. Check console for details."
151
 
152
  def export_dataset(self) -> Optional[str]:
153
  if not self.last_hn_dataset:
154
- gr.Warning("No dataset has been generated from current selection yet. Please run fine-tuning first.")
155
  return None
156
- file_path = self.config.DATASET_EXPORT_FILENAME
 
 
157
  try:
158
- print(f"Exporting dataset to {file_path}...")
159
  with open(file_path, 'w', newline='', encoding='utf-8') as f:
160
  writer = csv.writer(f)
161
  writer.writerow(['Anchor', 'Positive', 'Negative'])
162
  writer.writerows(self.last_hn_dataset)
163
- gr.Info(f"Dataset successfully exported to {file_path}")
164
  return str(file_path)
165
  except Exception as e:
166
- gr.Error(f"Failed to export the dataset to CSV. Error: {e}")
167
  return None
168
 
169
  def download_model(self) -> Optional[str]:
170
- if not os.path.exists(self.config.OUTPUT_DIR):
171
- gr.Warning(f"The model directory '{self.config.OUTPUT_DIR}' does not exist. Please run training first.")
172
  return None
 
173
  timestamp = int(time.time())
174
  try:
175
- base_name = os.path.join(self.config.ARTIFACTS_DIR, f"embedding_gemma_finetuned_{timestamp}")
 
176
  archive_path = shutil.make_archive(
177
- base_name=base_name,
178
  format='zip',
179
- root_dir=self.config.OUTPUT_DIR,
180
  )
181
- gr.Info(f"Model files successfully zipped to: {archive_path}")
182
  return archive_path
183
  except Exception as e:
184
- gr.Error(f"Failed to create the model ZIP file. Error: {e}")
185
  return None
186
 
187
  ## Training Logic ##
188
  def _create_hn_dataset(self, selected_ids: List[int]) -> Tuple[List[List[str]], str, str]:
189
- """
190
- Internal function to generate the [Anchor, Positive, Negative] triplets
191
- from the user's Hacker News title selection.
192
- Returns (dataset, favorite_title, non_favorite_title)
193
- """
194
  total_ids, selected_ids = set(self.number_list), set(selected_ids)
195
  non_selected_ids = total_ids - selected_ids
196
  is_minority = len(selected_ids) < (len(total_ids) / 2)
@@ -200,6 +194,9 @@ class HackerNewsFineTuner:
200
  def get_titles(anchor_id, pool_id):
201
  return (self.titles[pool_id], self.titles[anchor_id]) if is_minority else (self.titles[anchor_id], self.titles[pool_id])
202
 
 
 
 
203
  fav_idx = pool_ids[0] if is_minority else list(anchor_ids)[0]
204
  non_fav_idx = list(anchor_ids)[0] if is_minority else pool_ids[0]
205
 
@@ -212,63 +209,66 @@ class HackerNewsFineTuner:
212
  return hn_dataset, self.titles[fav_idx], self.titles[non_fav_idx]
213
 
214
  def training(self, selected_ids: List[int]) -> str:
215
- """
216
- Generates a training dataset from user selection and runs the fine-tuning process.
217
- """
218
  if self.model is None:
219
- raise gr.Error("Training failed: Embedding model is not loaded.")
220
  if not selected_ids:
221
- raise gr.Error("You must select at least one title.")
222
  if len(selected_ids) == len(self.number_list):
223
- raise gr.Error("You can't select all titles; a non-favorite is needed.")
224
 
225
- hn_dataset, example_fav, _ = self._create_hn_dataset(selected_ids)
226
  self.last_hn_dataset = hn_dataset
227
  final_dataset = self.last_hn_dataset + self.imported_dataset
 
228
  if not final_dataset:
229
- raise gr.Error("Training failed: Final dataset is empty.")
230
- print(f"Combined dataset size: {len(final_dataset)} triplets.")
231
 
232
  def semantic_search_fn() -> str:
233
  return get_top_hits(model=self.model, target_titles=self.target_titles, task_name=self.config.TASK_NAME, query=self.config.QUERY_ANCHOR)
234
 
235
- result = "### Semantic Search Results (Before Training):\n" + f"{semantic_search_fn()}\n\n"
236
- print("-" * 50 + "\nStarting Fine-tuning...")
237
- train_with_dataset(model=self.model, dataset=final_dataset, output_dir=self.config.OUTPUT_DIR, task_name=self.config.TASK_NAME, search_fn=semantic_search_fn)
 
 
 
 
 
 
 
 
 
238
  self._update_vibe_checker()
239
- print("Fine-tuning Complete.\n" + "-" * 50)
240
 
241
- result += "### Semantic Search Results (After Training):\n" + f"{semantic_search_fn()}"
242
  return result
243
 
244
- ## Vibe Check Logic (Tab 2) ##
245
  def get_vibe_check(self, news_text: str) -> Tuple[str, str, gr.update]:
246
  if not self.vibe_checker:
247
- gr.Error("Model/VibeChecker not loaded.")
248
- return "N/A", "Model Error", gr.update(value=self._generate_vibe_html("gray"))
249
  if not news_text or len(news_text.split()) < 3:
250
- gr.Warning("Please enter a longer text for a meaningful check.")
251
- return "N/A", "Please enter text", gr.update(value=self._generate_vibe_html("white"))
252
 
253
  try:
254
  vibe_result = self.vibe_checker.check(news_text)
255
- status = vibe_result.status_html.split('>')[1].split('<')[0] # Extract text from HTML
256
  return f"{vibe_result.raw_score:.4f}", status, gr.update(value=self._generate_vibe_html(vibe_result.color_hsl))
257
  except Exception as e:
258
- gr.Error(f"Vibe check failed. Error: {e}")
259
- return "N/A", f"Processing Error: {e}", gr.update(value=self._generate_vibe_html("gray"))
260
 
261
  def _generate_vibe_html(self, color: str) -> str:
262
  return f'<div style="background-color: {color}; height: 100px; border-radius: 12px; border: 2px solid #ccc;"></div>'
263
 
264
- ## Mood Reader Logic (Tab 3) ##
265
  def fetch_and_display_mood_feed(self) -> str:
266
  if not self.vibe_checker:
267
- return "**FATAL ERROR:** The Mood Reader failed to initialize. Check console."
268
 
269
  feed, status = read_hacker_news_rss(self.config)
270
  if not feed or not feed.entries:
271
- return f"**An error occurred while fetching the feed:** {status}"
272
 
273
  scored_entries = []
274
  for entry in feed.entries:
@@ -286,8 +286,8 @@ class HackerNewsFineTuner:
286
 
287
  scored_entries.sort(key=lambda x: x["mood"].raw_score, reverse=True)
288
 
289
- md = (f"## Hacker News Top Stories (Model: `{self.config.MODEL_NAME}`{' - Fine-tuned' if self.last_hn_dataset else ''}) ⬇️\n"
290
- f"**Last Updated:** {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n"
291
  "| Vibe | Score | Title | Comments | Published |\n|---|---|---|---|---|\n")
292
 
293
  for item in scored_entries:
@@ -297,94 +297,140 @@ class HackerNewsFineTuner:
297
  f"| [Comments]({item['comments']}) "
298
  f"| {item['published']} |\n")
299
  return md
300
- # 🤖 Embedding Gemma Modkit: Fine-Tuning and Mood Reader
301
-
302
- ## Gradio Interface Setup ##
303
- def build_interface(self) -> gr.Blocks:
304
- with gr.Blocks(title="EmbeddingGemma Modkit") as demo:
305
- gr.Markdown("# 🤖 EmbeddingGemma Modkit: Fine-Tuning and Mood Reader")
306
- gr.Markdown("This project provides a set of tools to fine-tune [EmbeddingGemma](https://huggingface.co/google/embeddinggemma-300m) to understand your personal taste in Hacker News titles and then use it to score and rank new articles based on their \"vibe\". The core idea is to measure the \"vibe\" of a news title by calculating the semantic similarity between its embedding and the embedding of a fixed anchor phrase, **`MY_FAVORITE_NEWS`**.<br>See [README](https://huggingface.co/spaces/google/embeddinggemma-modkit/blob/main/README.md) for more details.")
307
- with gr.Tab("🚀 Fine-Tuning & Evaluation"):
308
- self._build_training_interface()
309
- with gr.Tab("📰 Hacker News Mood Reader"):
310
- self._build_mood_reader_interface()
311
- with gr.Tab("💡 Similarity Check"):
312
- self._build_vibe_check_interface()
313
- return demo
314
-
315
- def _build_training_interface(self):
316
- with gr.Column():
317
- gr.Markdown("## Fine-Tuning & Semantic Search\nSelect titles to fine-tune the model towards making them more similar to **`MY_FAVORITE_NEWS`**.")
318
- with gr.Row():
319
- favorite_list = gr.CheckboxGroup(self.titles, type="index", label=f"Hacker News Top {len(self.titles)}", show_select_all=True)
320
- output = gr.Textbox(lines=14, label="Training and Search Results", value="Click 'Run Fine-Tuning' to begin.")
321
- with gr.Row():
322
- clear_reload_btn = gr.Button("Clear & Reload Model/Data")
323
- run_training_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary")
324
- gr.Markdown("--- \n ## Dataset & Model Management")
325
- gr.Markdown("To train on your own data, upload a CSV file with the following columns (no header required, or header ignored if present):\n1. **Anchor**: A fixed anchor phrase, `MY_FAVORITE_NEWS`.\n2. **Positive**: A title or contents that you like.\n3. **Negative**: A title or contents that you don't like.\n\nExample CSV Row:\n```\nMY_FAVORITE_NEWS,What is machine learning?,How to write a compiler from scratch.\n```")
326
- import_file = gr.File(label="Upload Additional Dataset (.csv)", file_types=[".csv"], height=50)
327
- with gr.Row():
328
- download_dataset_btn = gr.Button("💾 Export Last HN Dataset")
329
- download_model_btn = gr.Button("⬇️ Download Fine-Tuned Model")
330
- download_status = gr.Markdown("Ready.")
331
- with gr.Row():
332
- dataset_output = gr.File(label="Download Dataset CSV", height=50, visible=False, interactive=False)
333
- model_output = gr.File(label="Download Model ZIP", height=50, visible=False, interactive=False)
334
-
335
- buttons_to_lock = [
336
- clear_reload_btn,
337
- run_training_btn,
338
- download_dataset_btn,
339
- download_model_btn
340
- ]
341
-
342
- run_training_btn.click(
343
- fn=lambda: [gr.update(interactive=False)]*len(buttons_to_lock),
344
- outputs=buttons_to_lock
345
- ).then(
346
- fn=self.training, inputs=favorite_list, outputs=output
347
- ).then(
348
- fn=lambda: [gr.update(interactive=True)]*len(buttons_to_lock),
349
- outputs=buttons_to_lock
350
- )
351
- clear_reload_btn.click(fn=self.refresh_data_and_model, inputs=None, outputs=[favorite_list, output], queue=False)
352
- import_file.change(fn=self.import_additional_dataset, inputs=[import_file], outputs=download_status)
353
- download_dataset_btn.click(lambda: [gr.update(value=None, visible=False), "Generating..."], None, [dataset_output, download_status], queue=False).then(self.export_dataset, None, dataset_output).then(lambda p: [gr.update(visible=p is not None, value=p), "CSV ready." if p else "Export failed."], [dataset_output], [dataset_output, download_status])
354
- download_model_btn.click(
355
- fn=lambda: [gr.update(interactive=False)]*len(buttons_to_lock),
356
- outputs=buttons_to_lock
357
- ).then(
358
- lambda: [gr.update(value=None, visible=False), "Zipping..."], None, [model_output, download_status], queue=False
359
- ).then(self.download_model, None, model_output).then(lambda p: [gr.update(visible=p is not None, value=p), "ZIP ready." if p else "Zipping failed."], [model_output], [model_output, download_status]
360
- ).then(
361
- fn=lambda: [gr.update(interactive=True)]*len(buttons_to_lock),
362
- outputs=buttons_to_lock
363
- )
364
 
365
- def _build_vibe_check_interface(self):
366
- with gr.Column():
367
- gr.Markdown(f"## News Vibe Check Mood Lamp\nEnter text to see its similarity to **`{self.config.QUERY_ANCHOR}`**.\n**Vibe Key:** Green = High, Red = Low")
368
- news_input = gr.Textbox(label="Enter News Title or Summary", lines=3)
369
- vibe_check_btn = gr.Button("Check Vibe", variant="primary")
370
- with gr.Row():
371
- vibe_color_block = gr.HTML(value=self._generate_vibe_html("white"), label="Mood Lamp")
372
- with gr.Column():
373
- vibe_score = gr.Textbox(label="Cosine Similarity Score", value="N/A", interactive=False)
374
- vibe_status = gr.Textbox(label="Vibe Status", value="Enter text and click 'Check Vibe'", interactive=False, lines=2)
375
- vibe_check_btn.click(fn=self.get_vibe_check, inputs=[news_input], outputs=[vibe_score, vibe_status, vibe_color_block])
376
-
377
- def _build_mood_reader_interface(self):
378
- with gr.Column():
379
- gr.Markdown(f"## Live Hacker News Feed Vibe\nThis feed uses the current model (base or fine-tuned) to score the vibe of live HN stories against **`{self.config.QUERY_ANCHOR}`**.")
380
- feed_output = gr.Markdown(value="Click 'Refresh Feed' to load stories.", label="Latest Stories")
381
- refresh_button = gr.Button("Refresh Feed 🔄", size="lg", variant="primary")
382
- refresh_button.click(fn=self.fetch_and_display_mood_feed, inputs=None, outputs=feed_output)
383
 
 
 
384
 
385
- if __name__ == "__main__":
386
- app = HackerNewsFineTuner(AppConfig)
387
- demo = app.build_interface()
388
- print("Starting Gradio App...")
389
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
390
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import shutil
4
  import time
5
  import csv
6
+ import uuid
7
  from itertools import cycle
8
  from typing import List, Iterable, Tuple, Optional, Callable
9
  from datetime import datetime
 
20
  from vibe_logic import VibeChecker
21
  from sentence_transformers import SentenceTransformer
22
 
23
+ # --- Main Application Class (Session Scoped) ---
24
 
25
  class HackerNewsFineTuner:
26
  """
27
+ Encapsulates all application logic and state for a single user session.
 
28
  """
29
 
30
  def __init__(self, config: AppConfig = AppConfig):
31
  # --- Dependencies ---
32
  self.config = config
33
+
34
+ # --- Session Identification ---
35
+ self.session_id = str(uuid.uuid4())
36
+
37
+ # Define session-specific paths to allow simultaneous training
38
+ self.session_root = self.config.ARTIFACTS_DIR / self.session_id
39
+ self.output_dir = self.session_root / "embedding_gemma_finetuned"
40
+ self.dataset_export_file = self.session_root / "training_dataset.csv"
41
+
42
+ # Setup directories
43
+ os.makedirs(self.output_dir, exist_ok=True)
44
+ print(f"[{self.session_id}] New session started. Artifacts: {self.session_root}")
45
 
46
  # --- Application State ---
47
  self.model: Optional[SentenceTransformer] = None
48
  self.vibe_checker: Optional[VibeChecker] = None
49
+ self.titles: List[str] = []
50
+ self.target_titles: List[str] = []
51
+ self.number_list: List[int] = []
52
+ self.last_hn_dataset: List[List[str]] = []
53
+ self.imported_dataset: List[List[str]] = []
 
 
 
 
 
 
54
 
55
+ # Authenticate once (global)
56
+ authenticate_hf(self.config.HF_TOKEN)
57
 
58
+ # Note: We do NOT load data here immediately to keep init fast.
59
+ # Data is loaded via the demo.load event.
 
 
 
60
 
61
  def _update_vibe_checker(self):
62
  """Initializes or updates the VibeChecker with the current model state."""
63
  if self.model:
 
64
  self.vibe_checker = VibeChecker(
65
  model=self.model,
66
  query_anchor=self.config.QUERY_ANCHOR,
 
73
 
74
  def refresh_data_and_model(self) -> Tuple[gr.update, gr.update]:
75
  """
76
+ Reloads model and fetches data.
 
 
77
  """
78
+ print(f"[{self.session_id}] Reloading model and data...")
 
79
 
 
80
  self.last_hn_dataset = []
81
  self.imported_dataset = []
82
 
 
85
  self.model = load_embedding_model(self.config.MODEL_NAME)
86
  self._update_vibe_checker()
87
  except Exception as e:
88
+ error_msg = f"CRITICAL ERROR: Model failed to load. {e}"
89
+ print(error_msg)
90
  self.model = None
91
  self._update_vibe_checker()
92
  return (
93
  gr.update(choices=[], label="Model Load Failed"),
94
+ gr.update(value=error_msg)
95
  )
96
 
97
  # 2. Fetch fresh news data
98
+ # Note: Cache file is shared (global), which is fine/desired for RSS data.
99
  news_feed, status_msg = read_hacker_news_rss(self.config)
100
  titles_out, target_titles_out = [], []
101
+ status_value: str = f"Ready. Session ID: {self.session_id[:8]}... | Status: {status_msg}"
102
 
103
  if news_feed is not None and news_feed.entries:
 
104
  titles_out = [item.title for item in news_feed.entries[:self.config.TOP_TITLES_COUNT]]
105
  target_titles_out = [item.title for item in news_feed.entries[self.config.TOP_TITLES_COUNT:]]
 
106
  else:
107
+ titles_out = ["Error fetching news.", "Check console."]
108
+ gr.Warning(f"Data reload failed. {status_msg}")
109
 
110
  self.titles = titles_out
111
  self.target_titles = target_titles_out
112
  self.number_list = list(range(len(self.titles)))
113
 
 
114
  return (
115
  gr.update(
116
  choices=self.titles,
 
139
  new_dataset.append([s.strip() for s in row])
140
  num_imported += 1
141
  if num_imported == 0:
142
+ raise ValueError("No valid rows found.")
143
  self.imported_dataset = new_dataset
144
+ return f"Imported {num_imported} triplets."
145
  except Exception as e:
146
+ return f"Import failed: {e}"
 
147
 
148
  def export_dataset(self) -> Optional[str]:
149
  if not self.last_hn_dataset:
150
+ gr.Warning("No dataset generated yet.")
151
  return None
152
+
153
+ # Use session-specific path
154
+ file_path = self.dataset_export_file
155
  try:
 
156
  with open(file_path, 'w', newline='', encoding='utf-8') as f:
157
  writer = csv.writer(f)
158
  writer.writerow(['Anchor', 'Positive', 'Negative'])
159
  writer.writerows(self.last_hn_dataset)
160
+ gr.Info(f"Dataset exported.")
161
  return str(file_path)
162
  except Exception as e:
163
+ gr.Error(f"Export failed: {e}")
164
  return None
165
 
166
  def download_model(self) -> Optional[str]:
167
+ if not os.path.exists(self.output_dir):
168
+ gr.Warning("No model trained yet.")
169
  return None
170
+
171
  timestamp = int(time.time())
172
  try:
173
+ # Create zip in the session folder
174
+ base_name = self.session_root / f"model_finetuned_{timestamp}"
175
  archive_path = shutil.make_archive(
176
+ base_name=str(base_name),
177
  format='zip',
178
+ root_dir=self.output_dir,
179
  )
180
+ gr.Info(f"Model zipped.")
181
  return archive_path
182
  except Exception as e:
183
+ gr.Error(f"Zip failed: {e}")
184
  return None
185
 
186
  ## Training Logic ##
187
  def _create_hn_dataset(self, selected_ids: List[int]) -> Tuple[List[List[str]], str, str]:
 
 
 
 
 
188
  total_ids, selected_ids = set(self.number_list), set(selected_ids)
189
  non_selected_ids = total_ids - selected_ids
190
  is_minority = len(selected_ids) < (len(total_ids) / 2)
 
194
  def get_titles(anchor_id, pool_id):
195
  return (self.titles[pool_id], self.titles[anchor_id]) if is_minority else (self.titles[anchor_id], self.titles[pool_id])
196
 
197
+ if not pool_ids or not anchor_ids:
198
+ return [], "", "" # Should be caught by validation
199
+
200
  fav_idx = pool_ids[0] if is_minority else list(anchor_ids)[0]
201
  non_fav_idx = list(anchor_ids)[0] if is_minority else pool_ids[0]
202
 
 
209
  return hn_dataset, self.titles[fav_idx], self.titles[non_fav_idx]
210
 
211
  def training(self, selected_ids: List[int]) -> str:
 
 
 
212
  if self.model is None:
213
+ raise gr.Error("Model not loaded.")
214
  if not selected_ids:
215
+ raise gr.Error("Select at least one title.")
216
  if len(selected_ids) == len(self.number_list):
217
+ raise gr.Error("Cannot select all titles.")
218
 
219
+ hn_dataset, _, _ = self._create_hn_dataset(selected_ids)
220
  self.last_hn_dataset = hn_dataset
221
  final_dataset = self.last_hn_dataset + self.imported_dataset
222
+
223
  if not final_dataset:
224
+ raise gr.Error("Dataset is empty.")
 
225
 
226
  def semantic_search_fn() -> str:
227
  return get_top_hits(model=self.model, target_titles=self.target_titles, task_name=self.config.TASK_NAME, query=self.config.QUERY_ANCHOR)
228
 
229
+ result = "### Search (Before):\n" + f"{semantic_search_fn()}\n\n"
230
+ print(f"[{self.session_id}] Starting Training...")
231
+
232
+ # Use session-specific output directory
233
+ train_with_dataset(
234
+ model=self.model,
235
+ dataset=final_dataset,
236
+ output_dir=self.output_dir,
237
+ task_name=self.config.TASK_NAME,
238
+ search_fn=semantic_search_fn
239
+ )
240
+
241
  self._update_vibe_checker()
242
+ print(f"[{self.session_id}] Training Complete.")
243
 
244
+ result += "### Search (After):\n" + f"{semantic_search_fn()}"
245
  return result
246
 
247
+ ## Vibe Check Logic ##
248
  def get_vibe_check(self, news_text: str) -> Tuple[str, str, gr.update]:
249
  if not self.vibe_checker:
250
+ return "N/A", "Model Loading...", gr.update(value=self._generate_vibe_html("gray"))
 
251
  if not news_text or len(news_text.split()) < 3:
252
+ return "N/A", "Text too short", gr.update(value=self._generate_vibe_html("white"))
 
253
 
254
  try:
255
  vibe_result = self.vibe_checker.check(news_text)
256
+ status = vibe_result.status_html.split('>')[1].split('<')[0]
257
  return f"{vibe_result.raw_score:.4f}", status, gr.update(value=self._generate_vibe_html(vibe_result.color_hsl))
258
  except Exception as e:
259
+ return "N/A", f"Error: {e}", gr.update(value=self._generate_vibe_html("gray"))
 
260
 
261
  def _generate_vibe_html(self, color: str) -> str:
262
  return f'<div style="background-color: {color}; height: 100px; border-radius: 12px; border: 2px solid #ccc;"></div>'
263
 
264
+ ## Mood Reader Logic ##
265
  def fetch_and_display_mood_feed(self) -> str:
266
  if not self.vibe_checker:
267
+ return "Model not ready. Please wait or reload."
268
 
269
  feed, status = read_hacker_news_rss(self.config)
270
  if not feed or not feed.entries:
271
+ return f"**Feed Error:** {status}"
272
 
273
  scored_entries = []
274
  for entry in feed.entries:
 
286
 
287
  scored_entries.sort(key=lambda x: x["mood"].raw_score, reverse=True)
288
 
289
+ md = (f"## Hacker News Mood (Session: {self.session_id[:6]})\n"
290
+ f"**Updated:** {datetime.now().strftime('%H:%M:%S')}\n\n"
291
  "| Vibe | Score | Title | Comments | Published |\n|---|---|---|---|---|\n")
292
 
293
  for item in scored_entries:
 
297
  f"| [Comments]({item['comments']}) "
298
  f"| {item['published']} |\n")
299
  return md
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
300
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
301
 
302
+ # --- Session Wrappers ---
303
+ # These functions act as bridges between Gradio inputs and the session object.
304
 
305
+ def create_session():
306
+ """Factory to create a new session object."""
307
+ return HackerNewsFineTuner(AppConfig)
308
+
309
+ def refresh_wrapper(app):
310
+ return app.refresh_data_and_model()
311
+
312
+ def import_wrapper(app, file):
313
+ return app.import_additional_dataset(file)
314
+
315
+ def export_wrapper(app):
316
+ return app.export_dataset()
317
+
318
+ def download_model_wrapper(app):
319
+ return app.download_model()
320
+
321
+ def training_wrapper(app, selected_ids):
322
+ return app.training(selected_ids)
323
 
324
+ def vibe_check_wrapper(app, text):
325
+ return app.get_vibe_check(text)
326
+
327
+ def mood_feed_wrapper(app):
328
+ return app.fetch_and_display_mood_feed()
329
+
330
+
331
+ # --- Interface Setup ---
332
+
333
+ def build_interface() -> gr.Blocks:
334
+ with gr.Blocks(title="EmbeddingGemma Modkit") as demo:
335
+ # State object holds the user-specific instance of HackerNewsFineTuner
336
+ session_state = gr.State(create_session)
337
+
338
+ gr.Markdown("# 🤖 EmbeddingGemma Modkit: Fine-Tuning and Mood Reader (Multi-User)")
339
+ gr.Markdown("Each browser tab creates a unique session with isolated training data and models.")
340
+
341
+ with gr.Tab("🚀 Fine-Tuning & Evaluation"):
342
+ with gr.Column():
343
+ gr.Markdown("## Fine-Tuning & Semantic Search")
344
+ with gr.Row():
345
+ # Choices are populated on load via refresh_wrapper
346
+ favorite_list = gr.CheckboxGroup(choices=[], type="index", label="Hacker News Top Stories", show_select_all=True)
347
+ output = gr.Textbox(lines=14, label="Training and Search Results", value="Loading data...")
348
+
349
+ with gr.Row():
350
+ clear_reload_btn = gr.Button("Clear & Reload")
351
+ run_training_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary")
352
+
353
+ gr.Markdown("--- \n ## Dataset & Model Management")
354
+ import_file = gr.File(label="Upload Additional Dataset (.csv)", file_types=[".csv"], height=50)
355
+
356
+ with gr.Row():
357
+ download_dataset_btn = gr.Button("💾 Export Dataset")
358
+ download_model_btn = gr.Button("⬇️ Download Model")
359
+
360
+ download_status = gr.Markdown("Ready.")
361
+
362
+ with gr.Row():
363
+ dataset_output = gr.File(label="Dataset CSV", height=50, visible=False, interactive=False)
364
+ model_output = gr.File(label="Model ZIP", height=50, visible=False, interactive=False)
365
+
366
+ # Interactions
367
+ # Note: We pass session_state as the first input to all wrappers
368
+
369
+ # 1. Initial Load
370
+ demo.load(fn=refresh_wrapper, inputs=[session_state], outputs=[favorite_list, output])
371
+
372
+ # 2. Buttons
373
+ clear_reload_btn.click(
374
+ fn=refresh_wrapper,
375
+ inputs=[session_state],
376
+ outputs=[favorite_list, output]
377
+ )
378
+
379
+ run_training_btn.click(
380
+ fn=training_wrapper,
381
+ inputs=[session_state, favorite_list],
382
+ outputs=[output]
383
+ )
384
+
385
+ import_file.change(
386
+ fn=import_wrapper,
387
+ inputs=[session_state, import_file],
388
+ outputs=[download_status]
389
+ )
390
+
391
+ download_dataset_btn.click(
392
+ fn=export_wrapper,
393
+ inputs=[session_state],
394
+ outputs=[dataset_output]
395
+ ).then(
396
+ lambda p: gr.update(visible=True) if p else gr.update(), inputs=[dataset_output], outputs=[dataset_output]
397
+ )
398
+
399
+ download_model_btn.click(
400
+ fn=download_model_wrapper,
401
+ inputs=[session_state],
402
+ outputs=[model_output]
403
+ ).then(
404
+ lambda p: gr.update(visible=True) if p else gr.update(), inputs=[model_output], outputs=[model_output]
405
+ )
406
+
407
+ with gr.Tab("📰 Hacker News Mood Reader"):
408
+ with gr.Column():
409
+ gr.Markdown(f"## Live Hacker News Feed Vibe")
410
+ feed_output = gr.Markdown(value="Click 'Refresh Feed'...", label="Latest Stories")
411
+ refresh_button = gr.Button("Refresh Feed 🔄", size="lg", variant="primary")
412
+ refresh_button.click(fn=mood_feed_wrapper, inputs=[session_state], outputs=feed_output)
413
+
414
+ with gr.Tab("💡 Similarity Check"):
415
+ with gr.Column():
416
+ gr.Markdown(f"## News Vibe Check Mood Lamp")
417
+ news_input = gr.Textbox(label="Enter News Title or Summary", lines=3)
418
+ vibe_check_btn = gr.Button("Check Vibe", variant="primary")
419
+ with gr.Row():
420
+ vibe_color_block = gr.HTML(value='<div style="background-color: gray; height: 100px;"></div>', label="Mood Lamp")
421
+ with gr.Column():
422
+ vibe_score = gr.Textbox(label="Score", value="N/A", interactive=False)
423
+ vibe_status = gr.Textbox(label="Status", value="...", interactive=False)
424
+
425
+ vibe_check_btn.click(
426
+ fn=vibe_check_wrapper,
427
+ inputs=[session_state, news_input],
428
+ outputs=[vibe_score, vibe_status, vibe_color_block]
429
+ )
430
+
431
+ return demo
432
+
433
+ if __name__ == "__main__":
434
+ app_demo = build_interface()
435
+ print("Starting Multi-User Gradio App...")
436
+ app_demo.launch()