bebechien commited on
Commit
ad95ef1
·
verified ·
1 Parent(s): dad587a

Upload folder using huggingface_hub

Browse files
Files changed (4) hide show
  1. example_training_dataset.csv +4 -0
  2. src/config.py +0 -3
  3. src/session_manager.py +70 -55
  4. src/ui.py +282 -108
example_training_dataset.csv ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ Anchor,Positive,Negative
2
+ MY_FAVORITE_NEWS,Denial of service and source code exposure in React Server Components,An SVG is all you need
3
+ MY_FAVORITE_NEWS,The highest quality codebase,Litestream VFS
4
+
src/config.py CHANGED
@@ -45,9 +45,6 @@ class AppConfig:
45
  # Anchor text used for contrastive learning dataset generation
46
  QUERY_ANCHOR: Final[str] = "MY_FAVORITE_NEWS"
47
 
48
- # Number of titles shown for user selection in the Gradio interface
49
- TOP_TITLES_COUNT: Final[int] = 10
50
-
51
  # Default export path for the dataset CSV
52
  DATASET_EXPORT_FILENAME: Final[Path] = ARTIFACTS_DIR.joinpath("training_dataset.csv")
53
 
 
45
  # Anchor text used for contrastive learning dataset generation
46
  QUERY_ANCHOR: Final[str] = "MY_FAVORITE_NEWS"
47
 
 
 
 
48
  # Default export path for the dataset CSV
49
  DATASET_EXPORT_FILENAME: Final[Path] = ARTIFACTS_DIR.joinpath("training_dataset.csv")
50
 
src/session_manager.py CHANGED
@@ -45,8 +45,6 @@ class HackerNewsFineTuner:
45
  self.model: Optional[SentenceTransformer] = None
46
  self.vibe_checker: Optional[VibeChecker] = None
47
  self.titles: List[str] = []
48
- self.target_titles: List[str] = []
49
- self.number_list: List[int] = []
50
  self.last_hn_dataset: List[List[str]] = []
51
  self.imported_dataset: List[List[str]] = []
52
 
@@ -66,9 +64,12 @@ class HackerNewsFineTuner:
66
 
67
  ## Data and Model Management ##
68
 
69
- def refresh_data_and_model(self) -> Tuple[gr.update, gr.update]:
70
  """
71
  Reloads model and fetches data.
 
 
 
72
  """
73
  print(f"[{self.session_id}] Reloading model and data...")
74
 
@@ -84,34 +85,23 @@ class HackerNewsFineTuner:
84
  print(error_msg)
85
  self.model = None
86
  self._update_vibe_checker()
87
- return (
88
- gr.update(choices=[], label="Model Load Failed"),
89
- gr.update(value=error_msg)
90
- )
91
 
92
  # 2. Fetch fresh news data
93
  news_feed, status_msg = read_hacker_news_rss(self.config)
94
- titles_out, target_titles_out = [], []
95
  status_value: str = f"Ready. Session ID: {self.session_id[:8]}... | Status: {status_msg}"
96
 
97
  if news_feed is not None and news_feed.entries:
98
- titles_out = [item.title for item in news_feed.entries[:self.config.TOP_TITLES_COUNT]]
99
- target_titles_out = [item.title for item in news_feed.entries[self.config.TOP_TITLES_COUNT:]]
100
  else:
101
- titles_out = ["Error fetching news.", "Check console."]
102
  gr.Warning(f"Data reload failed. {status_msg}")
103
 
104
  self.titles = titles_out
105
- self.target_titles = target_titles_out
106
- self.number_list = list(range(len(self.titles)))
107
-
108
- return (
109
- gr.update(
110
- choices=self.titles,
111
- label=f"Hacker News Top {len(self.titles)} (Select your favorites)"
112
- ),
113
- gr.update(value=status_value)
114
- )
115
 
116
  # --- Import Dataset/Export ---
117
  def import_additional_dataset(self, file_path: str) -> str:
@@ -123,6 +113,7 @@ class HackerNewsFineTuner:
123
  reader = csv.reader(f)
124
  try:
125
  header = next(reader)
 
126
  if not (header and header[0].lower().strip() == 'anchor'):
127
  f.seek(0)
128
  except StopIteration:
@@ -188,54 +179,75 @@ class HackerNewsFineTuner:
188
 
189
 
190
  ## Training Logic ##
191
- def _create_hn_dataset(self, selected_ids: List[int]) -> Tuple[List[List[str]], str, str]:
192
- total_ids, selected_ids = set(self.number_list), set(selected_ids)
193
- non_selected_ids = total_ids - selected_ids
194
- is_minority = len(selected_ids) < (len(total_ids) / 2)
195
-
196
- anchor_ids, pool_ids = (non_selected_ids, list(selected_ids)) if is_minority else (selected_ids, list(non_selected_ids))
197
-
198
- def get_titles(anchor_id, pool_id):
199
- return (self.titles[pool_id], self.titles[anchor_id]) if is_minority else (self.titles[anchor_id], self.titles[pool_id])
200
 
201
- if not pool_ids or not anchor_ids:
202
- return [], "", ""
 
203
 
204
- fav_idx = pool_ids[0] if is_minority else list(anchor_ids)[0]
205
- non_fav_idx = list(anchor_ids)[0] if is_minority else pool_ids[0]
206
 
207
- hn_dataset = []
208
- pool_cycler = cycle(pool_ids)
209
- for anchor_id in sorted(list(anchor_ids)):
210
- fav, non_fav = get_titles(anchor_id, next(pool_cycler))
211
- hn_dataset.append([self.config.QUERY_ANCHOR, fav, non_fav])
 
 
 
 
 
 
 
 
 
212
 
213
- return hn_dataset, self.titles[fav_idx], self.titles[non_fav_idx]
214
 
215
- def training(self, selected_ids: List[int]) -> str:
 
 
 
 
 
 
216
  if self.model is None:
217
  raise gr.Error("Model not loaded.")
218
- if not selected_ids:
219
- raise gr.Error("Select at least one title.")
220
- if len(selected_ids) == len(self.number_list):
221
- raise gr.Error("Cannot select all titles.")
222
-
223
- hn_dataset, _, _ = self._create_hn_dataset(selected_ids)
224
- self.last_hn_dataset = hn_dataset
225
- final_dataset = self.last_hn_dataset + self.imported_dataset
226
 
227
- if not final_dataset:
228
- raise gr.Error("Dataset is empty.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
 
230
  def semantic_search_fn() -> str:
231
- return get_top_hits(model=self.model, target_titles=self.target_titles, task_name=self.config.TASK_NAME, query=self.config.QUERY_ANCHOR)
232
 
233
  result = "### Search (Before):\n" + f"{semantic_search_fn()}\n\n"
234
- print(f"[{self.session_id}] Starting Training...")
235
 
236
  train_with_dataset(
237
  model=self.model,
238
- dataset=final_dataset,
239
  output_dir=self.output_dir,
240
  task_name=self.config.TASK_NAME,
241
  search_fn=semantic_search_fn
@@ -247,6 +259,9 @@ class HackerNewsFineTuner:
247
  result += "### Search (After):\n" + f"{semantic_search_fn()}"
248
  return result
249
 
 
 
 
250
  ## Vibe Check Logic ##
251
  def get_vibe_check(self, news_text: str) -> Tuple[str, str, gr.update]:
252
  info_text = f"**Session:** {self.session_id[:6]}<br>**Model:** `{self.config.MODEL_NAME}`{' - Fine-tuned' if self.last_hn_dataset else ''}"
@@ -303,4 +318,4 @@ class HackerNewsFineTuner:
303
  f"| [{item['title']}]({item['link']}) "
304
  f"| [Comments]({item['comments']}) "
305
  f"| {item['published']} |\n")
306
- return md
 
45
  self.model: Optional[SentenceTransformer] = None
46
  self.vibe_checker: Optional[VibeChecker] = None
47
  self.titles: List[str] = []
 
 
48
  self.last_hn_dataset: List[List[str]] = []
49
  self.imported_dataset: List[List[str]] = []
50
 
 
64
 
65
  ## Data and Model Management ##
66
 
67
+ def refresh_data_and_model(self) -> Tuple[List[str], str]:
68
  """
69
  Reloads model and fetches data.
70
+ Returns:
71
+ - List of titles (for the UI)
72
+ - Status message string
73
  """
74
  print(f"[{self.session_id}] Reloading model and data...")
75
 
 
85
  print(error_msg)
86
  self.model = None
87
  self._update_vibe_checker()
88
+ return [], error_msg
 
 
 
89
 
90
  # 2. Fetch fresh news data
91
  news_feed, status_msg = read_hacker_news_rss(self.config)
92
+ titles_out = []
93
  status_value: str = f"Ready. Session ID: {self.session_id[:8]}... | Status: {status_msg}"
94
 
95
  if news_feed is not None and news_feed.entries:
96
+ titles_out = [item.title for item in news_feed.entries]
 
97
  else:
98
+ titles_out = ["Error fetching news."]
99
  gr.Warning(f"Data reload failed. {status_msg}")
100
 
101
  self.titles = titles_out
102
+
103
+ # Return raw list of titles + status text
104
+ return self.titles, status_value
 
 
 
 
 
 
 
105
 
106
  # --- Import Dataset/Export ---
107
  def import_additional_dataset(self, file_path: str) -> str:
 
113
  reader = csv.reader(f)
114
  try:
115
  header = next(reader)
116
+ # Simple heuristic to detect if header exists
117
  if not (header and header[0].lower().strip() == 'anchor'):
118
  f.seek(0)
119
  except StopIteration:
 
179
 
180
 
181
  ## Training Logic ##
182
+ def _create_hn_dataset(self, pos_ids: List[int], neg_ids: List[int]) -> List[List[str]]:
183
+ """
184
+ Creates triplets (Anchor, Positive, Negative) from the selected indices.
185
+ Uses cycling to balance the dataset if the number of positives != negatives.
186
+ """
187
+ if not pos_ids or not neg_ids:
188
+ return []
 
 
189
 
190
+ # Convert indices to actual title strings
191
+ pos_titles = [self.titles[i] for i in pos_ids]
192
+ neg_titles = [self.titles[i] for i in neg_ids]
193
 
194
+ dataset = []
 
195
 
196
+ # We need to pair every Positive with a Negative.
197
+ # Strategy: Iterate over the longer list and cycle through the shorter list
198
+ # to ensure every selected item is used at least once and the dataset is balanced.
199
+
200
+ if len(pos_titles) >= len(neg_titles):
201
+ # More positives than negatives: Iterate positives, reuse negatives
202
+ neg_cycle = cycle(neg_titles)
203
+ for p_title in pos_titles:
204
+ dataset.append([self.config.QUERY_ANCHOR, p_title, next(neg_cycle)])
205
+ else:
206
+ # More negatives than positives: Iterate negatives, reuse positives
207
+ pos_cycle = cycle(pos_titles)
208
+ for n_title in neg_titles:
209
+ dataset.append([self.config.QUERY_ANCHOR, next(pos_cycle), n_title])
210
 
211
+ return dataset
212
 
213
+ def training(self, pos_ids: List[int], neg_ids: List[int]) -> str:
214
+ """
215
+ Main training entry point.
216
+ Args:
217
+ pos_ids: Indices of stories marked as "Favorite"
218
+ neg_ids: Indices of stories marked as "Dislike"
219
+ """
220
  if self.model is None:
221
  raise gr.Error("Model not loaded.")
 
 
 
 
 
 
 
 
222
 
223
+ # Validation
224
+ if not pos_ids:
225
+ raise gr.Error("Please select at least one 'Favorite' story.")
226
+ if not neg_ids:
227
+ raise gr.Error("Please select at least one 'Dislike' story.")
228
+
229
+ # Generate Dataset
230
+ hn_dataset = self._create_hn_dataset(pos_ids, neg_ids)
231
+
232
+ # Merge with imported dataset if it exists
233
+ if self.imported_dataset:
234
+ # If we have both, combine them
235
+ self.last_hn_dataset = hn_dataset + self.imported_dataset
236
+ else:
237
+ self.last_hn_dataset = hn_dataset
238
+
239
+ if not self.last_hn_dataset:
240
+ raise gr.Error("Dataset generation failed (Empty dataset).")
241
 
242
  def semantic_search_fn() -> str:
243
+ return get_top_hits(model=self.model, target_titles=self.titles, task_name=self.config.TASK_NAME, query=self.config.QUERY_ANCHOR)
244
 
245
  result = "### Search (Before):\n" + f"{semantic_search_fn()}\n\n"
246
+ print(f"[{self.session_id}] Starting Training with {len(self.last_hn_dataset)} examples...")
247
 
248
  train_with_dataset(
249
  model=self.model,
250
+ dataset=self.last_hn_dataset,
251
  output_dir=self.output_dir,
252
  task_name=self.config.TASK_NAME,
253
  search_fn=semantic_search_fn
 
259
  result += "### Search (After):\n" + f"{semantic_search_fn()}"
260
  return result
261
 
262
+ def is_model_tuned(self) -> bool:
263
+ return True if self.last_hn_dataset else False
264
+
265
  ## Vibe Check Logic ##
266
  def get_vibe_check(self, news_text: str) -> Tuple[str, str, gr.update]:
267
  info_text = f"**Session:** {self.session_id[:6]}<br>**Model:** `{self.config.MODEL_NAME}`{' - Fine-tuned' if self.last_hn_dataset else ''}"
 
318
  f"| [{item['title']}]({item['link']}) "
319
  f"| [Comments]({item['comments']}) "
320
  f"| {item['published']} |\n")
321
+ return md
src/ui.py CHANGED
@@ -1,26 +1,64 @@
1
  import gradio as gr
2
- from typing import Optional
3
  from datetime import datetime
4
 
5
  from .config import AppConfig
6
  from .session_manager import HackerNewsFineTuner
7
 
 
 
 
 
 
8
  # --- Session Wrappers ---
9
 
10
  def refresh_wrapper(app):
11
  """
12
  Initializes the session if it's not already created, then runs the refresh.
13
- Returns the app instance to update the State.
 
 
 
 
14
  """
15
  if app is None or callable(app) or isinstance(app, type):
16
  print("Initializing new HackerNewsFineTuner session...")
17
  app = HackerNewsFineTuner(AppConfig)
18
 
19
  # Run the refresh logic
20
- update1, update2 = app.refresh_data_and_model()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- # Return 3 items: The App Instance (for State), Choice Update, Text Update
23
- return app, update1, update2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def import_wrapper(app, file):
26
  return app.import_additional_dataset(file)
@@ -32,19 +70,31 @@ def download_model_wrapper(app):
32
  return app.download_model()
33
 
34
  def push_to_hub_wrapper(app, repo_name, oauth_token: Optional[gr.OAuthToken]):
35
- """
36
- Wrapper for pushing the model to the Hugging Face Hub.
37
- Gradio automatically injects 'oauth_token' if the user is logged in via LoginButton.
38
- """
39
  if oauth_token is None:
40
  return "⚠️ You must be logged in to push to the Hub. Please sign in above."
41
-
42
- # Extract the token string from the OAuthToken object
43
  token_str = oauth_token.token
44
  return app.upload_model(repo_name, token_str)
45
 
46
- def training_wrapper(app, selected_ids):
47
- return app.training(selected_ids)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
 
49
  def vibe_check_wrapper(app, text):
50
  return app.get_vibe_check(text)
@@ -57,124 +107,248 @@ def mood_feed_wrapper(app):
57
 
58
  def build_interface() -> gr.Blocks:
59
  with gr.Blocks(title="EmbeddingGemma Modkit") as demo:
60
- # Initialize state as None. It will be populated by refresh_wrapper on load.
61
  session_state = gr.State()
 
 
 
 
 
 
62
 
63
  with gr.Column():
64
  gr.Markdown("# 🤖 EmbeddingGemma Modkit: Fine-Tuning and Mood Reader")
65
  gr.Markdown("This project provides a set of tools to fine-tune [EmbeddingGemma](https://huggingface.co/google/embeddinggemma-300m) to understand your personal taste in Hacker News titles and then use it to score and rank new articles based on their \"vibe\". The core idea is to measure the \"vibe\" of a news title by calculating the semantic similarity between its embedding and the embedding of a fixed anchor phrase, **`MY_FAVORITE_NEWS`**.<br>See [README](https://huggingface.co/spaces/google/embeddinggemma-modkit/blob/main/README.md) for more details.")
66
- gr.LoginButton(value="(Optional) Sign in to Hugging Face, if you want to push fine-tuned model to your repo.")
67
 
68
  with gr.Tab("🚀 Fine-Tuning & Evaluation"):
69
- with gr.Column():
70
- gr.Markdown("## Fine-Tuning & Semantic Search\nSelect titles to fine-tune the model towards making them more similar to **`MY_FAVORITE_NEWS`**.")
71
- with gr.Row():
72
- favorite_list = gr.CheckboxGroup(choices=[], type="index", label="Hacker News Top Stories", show_select_all=True)
73
- output = gr.Textbox(lines=14, label="Training and Search Results", value="Loading data...")
74
-
 
 
 
 
 
 
75
  with gr.Row():
76
- clear_reload_btn = gr.Button("Clear & Reload")
77
- run_training_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary")
 
 
 
 
 
78
 
79
- gr.Markdown("--- \n ## Dataset & Model Management")
80
- gr.Markdown("To train on your own data, upload a CSV file with the following columns (no header required, or header ignored if present):\n1. **Anchor**: A fixed anchor phrase, `MY_FAVORITE_NEWS`.\n2. **Positive**: A title or contents that you like.\n3. **Negative**: A title or contents that you don't like.\n\nExample CSV Row:\n```\nMY_FAVORITE_NEWS,What is machine learning?,How to write a compiler from scratch.\n```")
81
- import_file = gr.File(label="Upload Additional Dataset (.csv)", file_types=[".csv"], height=50)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
  with gr.Row():
84
- download_dataset_btn = gr.Button("💾 Export Dataset")
85
- download_model_btn = gr.Button("⬇️ Download Fine-Tuned Model")
86
-
87
- download_status = gr.Markdown("Ready.")
88
 
89
- with gr.Row():
90
- dataset_output = gr.File(label="Download Dataset CSV", height=50, visible=False, interactive=False)
91
- model_output = gr.File(label="Download Model ZIP", height=50, visible=False, interactive=False)
92
 
93
- gr.Markdown("### ☁️ Publish to Hugging Face Hub")
 
 
 
94
  with gr.Row():
95
- repo_name_input = gr.Textbox(label="Target Repository Name", placeholder="e.g., my-news-vibe-model")
96
- push_to_hub_btn = gr.Button("Push to Hub", variant="secondary")
 
 
97
 
98
  push_status = gr.Markdown("")
99
 
100
- # --- Interactions ---
101
-
102
- # 1. Initial Load: Initialize State and Load Data
103
- demo.load(
104
- fn=refresh_wrapper,
105
- inputs=[session_state],
106
- outputs=[session_state, favorite_list, output]
107
- )
108
 
109
- buttons_to_lock = [
110
- clear_reload_btn,
111
- run_training_btn,
112
- download_dataset_btn,
113
- download_model_btn,
114
- push_to_hub_btn
115
- ]
116
-
117
- # 2. Buttons
118
- clear_reload_btn.click(
119
- fn=lambda: [gr.update(interactive=False)]*len(buttons_to_lock),
120
- outputs=buttons_to_lock
121
- ).then(
122
- fn=refresh_wrapper,
123
- inputs=[session_state],
124
- outputs=[session_state, favorite_list, output]
125
- ).then(
126
- fn=lambda: [gr.update(interactive=True)]*len(buttons_to_lock),
127
- outputs=buttons_to_lock
128
- )
129
 
130
- run_training_btn.click(
131
- fn=lambda: [gr.update(interactive=False)]*len(buttons_to_lock),
132
- outputs=buttons_to_lock
133
- ).then(
134
- fn=training_wrapper,
135
- inputs=[session_state, favorite_list],
136
- outputs=[output]
137
- ).then(
138
- fn=lambda: [gr.update(interactive=True)]*len(buttons_to_lock),
139
- outputs=buttons_to_lock
140
- )
141
 
142
- import_file.change(
143
- fn=import_wrapper,
144
- inputs=[session_state, import_file],
145
- outputs=[download_status]
146
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
147
 
148
- download_dataset_btn.click(
149
- fn=export_wrapper,
150
- inputs=[session_state],
151
- outputs=[dataset_output]
152
- ).then(
153
- lambda p: gr.update(visible=True) if p else gr.update(), inputs=[dataset_output], outputs=[dataset_output]
154
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
 
156
- download_model_btn.click(
157
- fn=lambda: [gr.update(interactive=False)]*len(buttons_to_lock),
158
- outputs=buttons_to_lock
159
- ).then(
160
- lambda: [gr.update(value=None, visible=False), "Zipping..."], None, [model_output, download_status], queue=False
161
- ).then(
162
- fn=download_model_wrapper,
163
- inputs=[session_state],
164
- outputs=[model_output]
165
- ).then(
166
- lambda p: [gr.update(visible=p is not None, value=p), "ZIP ready." if p else "Zipping failed."], [model_output], [model_output, download_status]
167
- ).then(
168
- fn=lambda: [gr.update(interactive=True)]*len(buttons_to_lock),
169
- outputs=buttons_to_lock
170
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
 
172
- # Push to Hub Interaction
173
- push_to_hub_btn.click(
174
- fn=push_to_hub_wrapper,
175
- inputs=[session_state, repo_name_input],
176
- outputs=[push_status]
177
- )
178
 
179
  with gr.Tab("📰 Hacker News Similarity Check"):
180
  with gr.Column():
 
1
  import gradio as gr
2
+ from typing import Optional, Dict, List
3
  from datetime import datetime
4
 
5
  from .config import AppConfig
6
  from .session_manager import HackerNewsFineTuner
7
 
8
+ # --- Constants for Labels ---
9
+ LABEL_FAV = "👍"
10
+ LABEL_NEU = "😐"
11
+ LABEL_DIS = "👎"
12
+
13
  # --- Session Wrappers ---
14
 
15
  def refresh_wrapper(app):
16
  """
17
  Initializes the session if it's not already created, then runs the refresh.
18
+ Returns:
19
+ 1. App instance
20
+ 2. Stories List (List[str])
21
+ 3. Empty Labels Dict (Dict)
22
+ 4. Log text
23
  """
24
  if app is None or callable(app) or isinstance(app, type):
25
  print("Initializing new HackerNewsFineTuner session...")
26
  app = HackerNewsFineTuner(AppConfig)
27
 
28
  # Run the refresh logic
29
+ # choices_list is a simple list of strings: ["Title 1", "Title 2", ...]
30
+ choices_list, log_update = app.refresh_data_and_model()
31
+
32
+ # Reset user labels
33
+ empty_labels = {}
34
+
35
+ return app, choices_list, empty_labels, log_update
36
+
37
+ def on_app_load(app, profile: Optional[gr.OAuthProfile] = None):
38
+ """
39
+ Combined wrapper for initial load:
40
+ 1. Initializes/Refreshes App Session
41
+ 2. Checks OAuth Profile to enable/disable Hub features
42
+ """
43
+ # 1. Reuse the logic from refresh_wrapper
44
+ app, stories, labels, text_update = refresh_wrapper(app)
45
 
46
+ # 2. Check Login Status
47
+ is_logged_in = profile is not None
48
+ username = profile.username if is_logged_in else None
49
+
50
+ hub_interactive = gr.update(interactive=is_logged_in)
51
+
52
+ # Return items matching the output signature of demo.load
53
+ return app, stories, labels, text_update, hub_interactive, hub_interactive, username
54
+
55
+ def update_repo_preview(username, repo_name):
56
+ """Updates the markdown preview to show 'username/repo_name'."""
57
+ if not username:
58
+ return "⚠️ Sign in to see the target repository path."
59
+
60
+ clean_repo = repo_name.strip() if repo_name else "..."
61
+ return f"Target Repository: **`{username}/{clean_repo}`**"
62
 
63
  def import_wrapper(app, file):
64
  return app.import_additional_dataset(file)
 
70
  return app.download_model()
71
 
72
  def push_to_hub_wrapper(app, repo_name, oauth_token: Optional[gr.OAuthToken]):
 
 
 
 
73
  if oauth_token is None:
74
  return "⚠️ You must be logged in to push to the Hub. Please sign in above."
 
 
75
  token_str = oauth_token.token
76
  return app.upload_model(repo_name, token_str)
77
 
78
+ def training_wrapper(app, stories: List[str], labels: Dict[int, str]):
79
+ """
80
+ Parses the Stories and Labels to extract Positive and Negative indices.
81
+ stories: List of titles
82
+ labels: Dictionary of {index: LABEL_FAV | LABEL_DIS | LABEL_NEU}
83
+ """
84
+ pos_ids = []
85
+ neg_ids = []
86
+
87
+ # Iterate through all available stories by index
88
+ for i in range(len(stories)):
89
+ # Get label for this index, default to Neutral if not set
90
+ label = labels.get(i, LABEL_NEU)
91
+
92
+ if label == LABEL_FAV:
93
+ pos_ids.append(i)
94
+ elif label == LABEL_DIS:
95
+ neg_ids.append(i)
96
+
97
+ return app.training(pos_ids, neg_ids)
98
 
99
  def vibe_check_wrapper(app, text):
100
  return app.get_vibe_check(text)
 
107
 
108
  def build_interface() -> gr.Blocks:
109
  with gr.Blocks(title="EmbeddingGemma Modkit") as demo:
 
110
  session_state = gr.State()
111
+ username_state = gr.State()
112
+
113
+ # State variables for the Feed List and User Choices
114
+ stories_state = gr.State([])
115
+ labels_state = gr.State({})
116
+ reset_counter = gr.State(0)
117
 
118
  with gr.Column():
119
  gr.Markdown("# 🤖 EmbeddingGemma Modkit: Fine-Tuning and Mood Reader")
120
  gr.Markdown("This project provides a set of tools to fine-tune [EmbeddingGemma](https://huggingface.co/google/embeddinggemma-300m) to understand your personal taste in Hacker News titles and then use it to score and rank new articles based on their \"vibe\". The core idea is to measure the \"vibe\" of a news title by calculating the semantic similarity between its embedding and the embedding of a fixed anchor phrase, **`MY_FAVORITE_NEWS`**.<br>See [README](https://huggingface.co/spaces/google/embeddinggemma-modkit/blob/main/README.md) for more details.")
 
121
 
122
  with gr.Tab("🚀 Fine-Tuning & Evaluation"):
123
+
124
+ # --- Model Indicator ---
125
+ gr.Dropdown(
126
+ choices=[f"{AppConfig.MODEL_NAME}"],
127
+ value=f"{AppConfig.MODEL_NAME}",
128
+ label="Base Model for Fine-tuning",
129
+ interactive=False
130
+ )
131
+
132
+ # --- Step 0: Login ---
133
+ with gr.Accordion("0️⃣ Step 0: Sign In (Optional)", open=True):
134
+ gr.Markdown("Sign in to Hugging Face if you plan to push your fine-tuned model to the Hub later (Step 3).")
135
  with gr.Row():
136
+ gr.LoginButton(value="Sign in with Hugging Face")
137
+ with gr.Column(scale=3):
138
+ gr.Markdown("")
139
+
140
+ # --- Step 1: Data Selection ---
141
+ with gr.Accordion("1️⃣ Step 1: Select Data Source", open=True):
142
+ gr.Markdown("Select titles from the live Hacker News feed **OR** upload your own CSV dataset to prepare your training data.")
143
 
144
+ with gr.Column():
145
+ # Option A: Live Feed (Radio List)
146
+ with gr.Accordion("Option A: Live Hacker News Feed", open=True):
147
+ gr.Markdown("Rate the stories below to define your vibe.\n\n**⚠️ Note: You must select at least one Favorite and one Dislike to run training.**")
148
+
149
+ with gr.Row():
150
+ reset_all_btn = gr.Button("Reset Selection ↺", variant="secondary", scale=1)
151
+ with gr.Column(scale=3):
152
+ gr.Markdown("")
153
+
154
+ # Dynamic rendering of the story list
155
+ @gr.render(inputs=[stories_state, reset_counter])
156
+ def render_story_list(stories, _counter):
157
+ if not stories:
158
+ gr.Markdown("*No stories loaded. Click 'Reset Model & Fine-tuning state' to fetch data.*")
159
+ return
160
+
161
+ for i, title in enumerate(stories):
162
+ with gr.Row(variant="compact", elem_id=f"story_row_{i}"):
163
+ # Title
164
+ with gr.Column(scale=3):
165
+ gr.Markdown(f"**{i+1}.** {title}")
166
+
167
+ # Radio Selection
168
+ radio = gr.Radio(
169
+ choices=[LABEL_FAV, LABEL_NEU, LABEL_DIS],
170
+ value=LABEL_NEU,
171
+ show_label=False,
172
+ container=False,
173
+ min_width=80,
174
+ scale=1,
175
+ interactive=True
176
+ )
177
+
178
+ # Update logic
179
+ def update_label(new_val, current_labels, idx=i):
180
+ current_labels[idx] = new_val
181
+ return current_labels
182
+
183
+ radio.change(
184
+ fn=update_label,
185
+ inputs=[radio, labels_state],
186
+ outputs=[labels_state]
187
+ )
188
+
189
+ # Option B: Upload
190
+ with gr.Accordion("Option B: Upload Custom Dataset", open=False):
191
+ gr.Markdown("Upload a CSV file with columns (no header required, or header ignored if present): `Anchor`, `Positive`, `Negative`.")
192
+ gr.Markdown("See also: [example_training.dataset.csv](https://huggingface.co/spaces/google/embeddinggemma-modkit/blob/main/example_training.dataset.csv)<br>Example:<br>`MY_FAVORITE_NEWS,Good Title,Bad Title`")
193
+ import_file = gr.File(label="Upload Additional Dataset (.csv)", file_types=[".csv"], height=100)
194
+
195
+ # --- Step 2: Training ---
196
+ with gr.Accordion("2️⃣ Step 2: Run Tuning", open=True):
197
+ gr.Markdown("Fine-tune the model using the data selected or uploaded above.")
198
 
199
  with gr.Row():
200
+ run_training_btn = gr.Button("🚀 Run Fine-Tuning", variant="primary", scale=1)
201
+ clear_reload_btn = gr.Button("Reset Model & Fine-tuning state", scale=1)
 
 
202
 
203
+ output = gr.Textbox(lines=10, label="Training Logs & Search Results", value="Waiting to start...", autoscroll=True)
 
 
204
 
205
+ # --- Step 3: Push to Hub ---
206
+ with gr.Accordion("3️⃣ Step 3: Save to Hugging Face Hub (Optional)", open=False):
207
+ gr.Markdown("Push your fine-tuned model to your personal Hugging Face account.")
208
+
209
  with gr.Row():
210
+ repo_name_input = gr.Textbox(label="Target Repository Name", value="my-embeddinggemma-news-vibe", placeholder="e.g., my-embeddinggemma-news-vibe", interactive=False)
211
+ push_to_hub_btn = gr.Button("Save to Hugging Face Hub", variant="secondary", interactive=False)
212
+
213
+ repo_id_preview = gr.Markdown("Target Repository: (Waiting for input...)")
214
 
215
  push_status = gr.Markdown("")
216
 
217
+ # --- Step 4: Downloads ---
218
+ with gr.Accordion("4️⃣ Step 4: Download Artifacts", open=False):
219
+ gr.Markdown("Export your combined dataset or download the fine-tuned model locally.")
220
+
221
+ with gr.Row():
222
+ download_dataset_btn = gr.Button("💾 Export Dataset", interactive=False)
223
+ download_model_btn = gr.Button("⬇️ Download Model ZIP", interactive=False)
 
224
 
225
+ download_status = gr.Markdown("Ready.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
+ with gr.Row():
228
+ dataset_output = gr.File(label="Download Dataset CSV", height=50, visible=False, interactive=False)
229
+ model_output = gr.File(label="Download Model ZIP", height=50, visible=False, interactive=False)
 
 
 
 
 
 
 
 
230
 
231
+ # --- Interaction Logic ---
232
+
233
+ action_buttons = [
234
+ clear_reload_btn,
235
+ run_training_btn,
236
+ download_dataset_btn,
237
+ download_model_btn
238
+ ]
239
+
240
+ def set_interactivity(interactive: bool):
241
+ """Helper to lock/unlock all main action buttons."""
242
+ return [gr.update(interactive=interactive) for _ in action_buttons]
243
+
244
+ # 1. App Startup
245
+ # ----------------
246
+ demo.load(
247
+ fn=lambda: set_interactivity(False), outputs=action_buttons
248
+ ).then(
249
+ fn=on_app_load,
250
+ inputs=[session_state],
251
+ outputs=[session_state, stories_state, labels_state, output, repo_name_input, push_to_hub_btn, username_state]
252
+ ).then(
253
+ fn=update_repo_preview,
254
+ inputs=[username_state, repo_name_input],
255
+ outputs=[repo_id_preview]
256
+ ).then(
257
+ fn=lambda: [gr.update(interactive=True)]*2, outputs=[clear_reload_btn, run_training_btn]
258
+ )
259
+
260
+ # 2. Reset / Refresh / Clear Selections
261
+ # ----------------
262
+ clear_reload_btn.click(
263
+ fn=lambda: set_interactivity(False), outputs=action_buttons
264
+ ).then(
265
+ fn=refresh_wrapper,
266
+ inputs=[session_state],
267
+ outputs=[session_state, stories_state, labels_state, output]
268
+ ).then(
269
+ fn=lambda: [gr.update(interactive=True)]*2, outputs=[clear_reload_btn, run_training_btn]
270
+ )
271
+
272
+ # Reset Selection Button Logic
273
+ def reset_all_selections(counter):
274
+ # Returns: (incremented counter, empty dict for labels)
275
+ return counter + 1, {}
276
 
277
+ reset_all_btn.click(
278
+ fn=reset_all_selections,
279
+ inputs=[reset_counter],
280
+ outputs=[reset_counter, labels_state]
281
+ )
282
+
283
+ # 3. Import Data
284
+ # ----------------
285
+ import_file.change(
286
+ fn=import_wrapper,
287
+ inputs=[session_state, import_file],
288
+ outputs=[download_status]
289
+ )
290
+
291
+ # 4. Run Training
292
+ # ----------------
293
+ run_training_btn.click(
294
+ fn=lambda: set_interactivity(False), outputs=action_buttons
295
+ ).then(
296
+ fn=training_wrapper,
297
+ inputs=[session_state, stories_state, labels_state],
298
+ outputs=[output]
299
+ ).then(
300
+ # Unlock all buttons (including downloads now that we have a model)
301
+ fn=lambda: set_interactivity(True), outputs=action_buttons
302
+ )
303
+
304
+ # 5. Downloads
305
+ # ----------------
306
+ download_dataset_btn.click(
307
+ fn=export_wrapper,
308
+ inputs=[session_state],
309
+ outputs=[dataset_output]
310
+ ).then(
311
+ # Just show the file output if it exists
312
+ lambda p: gr.update(visible=True) if p else gr.update(),
313
+ inputs=[dataset_output],
314
+ outputs=[dataset_output]
315
+ )
316
 
317
+ download_model_btn.click(
318
+ # Lock UI
319
+ fn=lambda: set_interactivity(False), outputs=action_buttons
320
+ ).then(
321
+ # Reset previous outputs and show "Zipping..."
322
+ fn=lambda: [gr.update(value=None, visible=False), "⏳ Zipping model..."],
323
+ outputs=[model_output, download_status]
324
+ ).then(
325
+ # Generate Zip
326
+ fn=download_model_wrapper,
327
+ inputs=[session_state],
328
+ outputs=[model_output]
329
+ ).then(
330
+ # Update UI with result
331
+ fn=lambda p: [gr.update(visible=p is not None, value=p), "✅ ZIP ready." if p else "❌ Zipping failed."],
332
+ inputs=[model_output],
333
+ outputs=[model_output, download_status]
334
+ ).then(
335
+ # Unlock UI
336
+ fn=lambda: set_interactivity(True), outputs=action_buttons
337
+ )
338
+
339
+ # 6. Push to Hub
340
+ # ----------------
341
+ repo_name_input.change(
342
+ fn=update_repo_preview,
343
+ inputs=[username_state, repo_name_input],
344
+ outputs=[repo_id_preview]
345
+ )
346
 
347
+ push_to_hub_btn.click(
348
+ fn=push_to_hub_wrapper,
349
+ inputs=[session_state, repo_name_input],
350
+ outputs=[push_status]
351
+ )
 
352
 
353
  with gr.Tab("📰 Hacker News Similarity Check"):
354
  with gr.Column():