bebechien committed on
Commit
e6cb750
·
verified ·
1 Parent(s): 72b6692
Files changed (2) hide show
  1. src/model_trainer.py +9 -29
  2. src/ui.py +23 -63
src/model_trainer.py CHANGED
@@ -4,7 +4,7 @@ from datasets import Dataset
4
  from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
5
  from sentence_transformers.losses import MultipleNegativesRankingLoss
6
  from transformers import TrainerCallback, TrainingArguments
7
- from typing import List, Callable, Optional, Union
8
  from pathlib import Path
9
  from .config import AppConfig
10
 
@@ -57,22 +57,7 @@ def get_top_hits(
57
 
58
  return "\n".join(result)
59
 
60
- def get_available_namespaces(token: str) -> List[str]:
61
- """
62
- Returns a list of namespaces (user and organizations) the user can write to.
63
- First item is always the authenticated user's username.
64
- """
65
- try:
66
- api = HfApi(token=token)
67
- info = api.whoami()
68
- username = info['name']
69
- orgs = [org['name'] for org in info.get('orgs', [])]
70
- return [username] + orgs
71
- except Exception as e:
72
- print(f"Error fetching namespaces: {e}")
73
- return []
74
-
75
- def upload_model_to_hub(folder_path: Path, repo_name: str, token: str, entity: Optional[str] = None) -> str:
76
  """
77
  Uploads a local model folder to the Hugging Face Hub.
78
  Creates the repository if it doesn't exist.
@@ -80,16 +65,12 @@ def upload_model_to_hub(folder_path: Path, repo_name: str, token: str, entity: O
80
  try:
81
  api = HfApi(token=token)
82
 
83
- # Determine the entity (namespace) to use
84
- if entity:
85
- namespace = entity
86
- else:
87
- # Fallback to the authenticated user's username
88
- user_info = api.whoami()
89
- namespace = user_info['name']
90
 
91
  # Construct the full repo ID
92
- repo_id = f"{namespace}/{repo_name}"
93
  print(f"Preparing to upload to: {repo_id}")
94
 
95
  # Create the repo (safe if it already exists)
@@ -107,9 +88,8 @@ def upload_model_to_hub(folder_path: Path, repo_name: str, token: str, entity: O
107
  token=token
108
  )
109
  tags = info.card_data.tags
110
- if "embeddinggemma-tuning-lab" not in tags:
111
- tags.append("embeddinggemma-tuning-lab")
112
- metadata_update(repo_id, {"tags": tags}, overwrite=True, token=token)
113
 
114
  return f"✅ Success! Model published at: {url}"
115
  except Exception as e:
@@ -189,4 +169,4 @@ def train_with_dataset(
189
  # Save the final fine-tuned model
190
  trainer.save_model()
191
 
192
- print(f"Model saved locally to: {output_dir}")
 
4
  from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
5
  from sentence_transformers.losses import MultipleNegativesRankingLoss
6
  from transformers import TrainerCallback, TrainingArguments
7
+ from typing import List, Callable, Optional
8
  from pathlib import Path
9
  from .config import AppConfig
10
 
 
57
 
58
  return "\n".join(result)
59
 
60
+ def upload_model_to_hub(folder_path: Path, repo_name: str, token: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  """
62
  Uploads a local model folder to the Hugging Face Hub.
63
  Creates the repository if it doesn't exist.
 
65
  try:
66
  api = HfApi(token=token)
67
 
68
+ # Get the authenticated user's username
69
+ user_info = api.whoami()
70
+ username = user_info['name']
 
 
 
 
71
 
72
  # Construct the full repo ID
73
+ repo_id = f"{username}/{repo_name}"
74
  print(f"Preparing to upload to: {repo_id}")
75
 
76
  # Create the repo (safe if it already exists)
 
88
  token=token
89
  )
90
  tags = info.card_data.tags
91
+ tags.append("embeddinggemma-tuning-lab")
92
+ metadata_update(repo_id, {"tags": tags}, overwrite=True, token=token)
 
93
 
94
  return f"✅ Success! Model published at: {url}"
95
  except Exception as e:
 
169
  # Save the final fine-tuned model
170
  trainer.save_model()
171
 
172
+ print(f"Model saved locally to: {output_dir}")
src/ui.py CHANGED
@@ -4,7 +4,6 @@ from datetime import datetime
4
 
5
  from .config import AppConfig
6
  from .session_manager import HackerNewsFineTuner
7
- from .model_trainer import get_available_namespaces
8
 
9
  # --- Constants for Labels ---
10
  LABEL_FAV = "👍"
@@ -46,27 +45,13 @@ def on_app_load(app, profile: Optional[gr.OAuthProfile] = None):
46
  # Return 7 items: App state, Data updates (3), Hub updates (2), Username state (1)
47
  return app, stories, labels, text_update, repo_update, push_update, username
48
 
49
- def update_repo_preview(entity_name, repo_name):
50
- """Updates the markdown preview to show 'entity/repo_name'."""
51
- if not entity_name:
52
- return "⚠️ Please select a namespace (User or Org)."
53
 
54
  clean_repo = repo_name.strip() if repo_name else "..."
55
- return f"Target Repository: **`{entity_name}/{clean_repo}`**"
56
-
57
- def fetch_orgs_wrapper(oauth_token: Optional[gr.OAuthToken]):
58
- if not oauth_token:
59
- return gr.update(choices=[], value=None), "⚠️ Login required to fetch organizations."
60
-
61
- try:
62
- namespaces = get_available_namespaces(oauth_token.token)
63
- if not namespaces:
64
- return gr.update(choices=[], value=None), "❌ Failed to fetch namespaces."
65
-
66
- # Default to the first one (username)
67
- return gr.update(choices=namespaces, value=namespaces[0]), "✅ Organizations loaded."
68
- except Exception as e:
69
- return gr.update(choices=[], value=None), f"❌ Error: {str(e)}"
70
 
71
  def import_wrapper(app, file):
72
  return app.import_additional_dataset(file)
@@ -77,12 +62,11 @@ def export_wrapper(app):
77
  def download_model_wrapper(app):
78
  return app.download_model()
79
 
80
- def push_to_hub_wrapper(app, entity_name, repo_name, oauth_token: Optional[gr.OAuthToken]):
81
  if oauth_token is None:
82
  return "⚠️ You must be logged in to push to the Hub. Please sign in above."
83
  token_str = oauth_token.token
84
- # Pass the selected entity
85
- return app.upload_model(repo_name, token_str, entity=entity_name)
86
 
87
  def training_wrapper(app, stories: List[str], labels: Dict[int, str]):
88
  """
@@ -142,7 +126,7 @@ def build_interface() -> gr.Blocks:
142
  with gr.Accordion("0️⃣ Step 0: Sign In (Optional)", open=True):
143
  gr.Markdown("Sign in to Hugging Face if you plan to push your fine-tuned model to the Hub later (Step 3).")
144
  with gr.Row():
145
- login_btn = gr.LoginButton(value="Sign in with Hugging Face")
146
  with gr.Column(scale=3):
147
  gr.Markdown("")
148
 
@@ -216,19 +200,11 @@ def build_interface() -> gr.Blocks:
216
  gr.Markdown("Push your fine-tuned model to your personal Hugging Face account.")
217
 
218
  with gr.Row():
219
- # Entity (User/Org) Selection
220
- with gr.Column(scale=1):
221
- with gr.Row():
222
- entity_dropdown = gr.Dropdown(label="Owner / Organization", choices=[], interactive=True, scale=4)
223
- refresh_orgs_btn = gr.Button("🔄", scale=1, size="sm")
224
-
225
- # Repo Name
226
- with gr.Column(scale=2):
227
- repo_name_input = gr.Textbox(label="Target Repository Name", value="my-embeddinggemma-news-vibe", placeholder="e.g., my-embeddinggemma-news-vibe", interactive=False)
228
-
229
- push_to_hub_btn = gr.Button("Save to Hugging Face Hub", variant="secondary", interactive=False)
230
 
231
  repo_id_preview = gr.Markdown("Target Repository: (Waiting for input...)")
 
232
  push_status = gr.Markdown("")
233
 
234
  # --- Step 4: Downloads ---
@@ -267,24 +243,14 @@ def build_interface() -> gr.Blocks:
267
  inputs=[session_state],
268
  outputs=[session_state, stories_state, labels_state, output, repo_name_input, push_to_hub_btn, username_state]
269
  ).then(
270
- fn=lambda: [gr.update(interactive=True)]*2, outputs=[clear_reload_btn, run_training_btn]
271
- )
272
-
273
- # 2. Login Trigger -> Auto Fetch Orgs
274
- # ----------------
275
- # We can try to fetch orgs automatically if the token is available
276
-
277
- refresh_orgs_btn.click(
278
- fn=fetch_orgs_wrapper,
279
- inputs=[login_btn], # Gr.LoginButton acts as the OAuthToken input in this context? No, usually gr.OAuthToken is implicit or separate
280
- outputs=[entity_dropdown, push_status]
281
  ).then(
282
- fn=update_repo_preview,
283
- inputs=[entity_dropdown, repo_name_input],
284
- outputs=[repo_id_preview]
285
  )
286
 
287
- # 3. Reset / Refresh / Clear Selections
288
  # ----------------
289
  clear_reload_btn.click(
290
  fn=lambda: set_interactivity(False), outputs=action_buttons
@@ -313,7 +279,7 @@ def build_interface() -> gr.Blocks:
313
  outputs=[reset_counter, labels_state]
314
  )
315
 
316
- # 4. Import Data
317
  # ----------------
318
  import_file.change(
319
  fn=import_wrapper,
@@ -321,7 +287,7 @@ def build_interface() -> gr.Blocks:
321
  outputs=[download_status]
322
  )
323
 
324
- # 5. Run Training
325
  # ----------------
326
  run_training_btn.click(
327
  fn=lambda: set_interactivity(False), outputs=action_buttons
@@ -338,7 +304,7 @@ def build_interface() -> gr.Blocks:
338
  outputs=[repo_name_input, push_to_hub_btn]
339
  )
340
 
341
- # 6. Downloads
342
  # ----------------
343
  download_dataset_btn.click(
344
  fn=export_wrapper,
@@ -379,17 +345,11 @@ def build_interface() -> gr.Blocks:
379
  outputs=[repo_name_input, push_to_hub_btn]
380
  )
381
 
382
- # 7. Push to Hub
383
  # ----------------
384
- # Update preview on Name change or Entity change
385
  repo_name_input.change(
386
  fn=update_repo_preview,
387
- inputs=[entity_dropdown, repo_name_input],
388
- outputs=[repo_id_preview]
389
- )
390
- entity_dropdown.change(
391
- fn=update_repo_preview,
392
- inputs=[entity_dropdown, repo_name_input],
393
  outputs=[repo_id_preview]
394
  )
395
 
@@ -399,7 +359,7 @@ def build_interface() -> gr.Blocks:
399
  fn=lambda: gr.update(interactive=False), outputs=push_to_hub_btn
400
  ).then(
401
  fn=push_to_hub_wrapper,
402
- inputs=[session_state, entity_dropdown, repo_name_input], # Pass entity dropdown
403
  outputs=[push_status]
404
  ).then(
405
  fn=lambda: set_interactivity(True), outputs=action_buttons
@@ -453,4 +413,4 @@ def build_interface() -> gr.Blocks:
453
  outputs=[vibe_score, vibe_status, style_thml, session_info_display]
454
  )
455
 
456
- return demo
 
4
 
5
  from .config import AppConfig
6
  from .session_manager import HackerNewsFineTuner
 
7
 
8
  # --- Constants for Labels ---
9
  LABEL_FAV = "👍"
 
45
  # Return 7 items: App state, Data updates (3), Hub updates (2), Username state (1)
46
  return app, stories, labels, text_update, repo_update, push_update, username
47
 
48
+ def update_repo_preview(username, repo_name):
49
+ """Updates the markdown preview to show 'username/repo_name'."""
50
+ if not username:
51
+ return "⚠️ Sign in to see the target repository path."
52
 
53
  clean_repo = repo_name.strip() if repo_name else "..."
54
+ return f"Target Repository: **`{username}/{clean_repo}`**"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  def import_wrapper(app, file):
57
  return app.import_additional_dataset(file)
 
62
  def download_model_wrapper(app):
63
  return app.download_model()
64
 
65
+ def push_to_hub_wrapper(app, repo_name, oauth_token: Optional[gr.OAuthToken]):
66
  if oauth_token is None:
67
  return "⚠️ You must be logged in to push to the Hub. Please sign in above."
68
  token_str = oauth_token.token
69
+ return app.upload_model(repo_name, token_str)
 
70
 
71
  def training_wrapper(app, stories: List[str], labels: Dict[int, str]):
72
  """
 
126
  with gr.Accordion("0️⃣ Step 0: Sign In (Optional)", open=True):
127
  gr.Markdown("Sign in to Hugging Face if you plan to push your fine-tuned model to the Hub later (Step 3).")
128
  with gr.Row():
129
+ gr.LoginButton(value="Sign in with Hugging Face")
130
  with gr.Column(scale=3):
131
  gr.Markdown("")
132
 
 
200
  gr.Markdown("Push your fine-tuned model to your personal Hugging Face account.")
201
 
202
  with gr.Row():
203
+ repo_name_input = gr.Textbox(label="Target Repository Name", value="my-embeddinggemma-news-vibe", placeholder="e.g., my-embeddinggemma-news-vibe", interactive=False)
204
+ push_to_hub_btn = gr.Button("Save to Hugging Face Hub", variant="secondary", interactive=False)
 
 
 
 
 
 
 
 
 
205
 
206
  repo_id_preview = gr.Markdown("Target Repository: (Waiting for input...)")
207
+
208
  push_status = gr.Markdown("")
209
 
210
  # --- Step 4: Downloads ---
 
243
  inputs=[session_state],
244
  outputs=[session_state, stories_state, labels_state, output, repo_name_input, push_to_hub_btn, username_state]
245
  ).then(
246
+ fn=update_repo_preview,
247
+ inputs=[username_state, repo_name_input],
248
+ outputs=[repo_id_preview]
 
 
 
 
 
 
 
 
249
  ).then(
250
+ fn=lambda: [gr.update(interactive=True)]*2, outputs=[clear_reload_btn, run_training_btn]
 
 
251
  )
252
 
253
+ # 2. Reset / Refresh / Clear Selections
254
  # ----------------
255
  clear_reload_btn.click(
256
  fn=lambda: set_interactivity(False), outputs=action_buttons
 
279
  outputs=[reset_counter, labels_state]
280
  )
281
 
282
+ # 3. Import Data
283
  # ----------------
284
  import_file.change(
285
  fn=import_wrapper,
 
287
  outputs=[download_status]
288
  )
289
 
290
+ # 4. Run Training
291
  # ----------------
292
  run_training_btn.click(
293
  fn=lambda: set_interactivity(False), outputs=action_buttons
 
304
  outputs=[repo_name_input, push_to_hub_btn]
305
  )
306
 
307
+ # 5. Downloads
308
  # ----------------
309
  download_dataset_btn.click(
310
  fn=export_wrapper,
 
345
  outputs=[repo_name_input, push_to_hub_btn]
346
  )
347
 
348
+ # 6. Push to Hub
349
  # ----------------
 
350
  repo_name_input.change(
351
  fn=update_repo_preview,
352
+ inputs=[username_state, repo_name_input],
 
 
 
 
 
353
  outputs=[repo_id_preview]
354
  )
355
 
 
359
  fn=lambda: gr.update(interactive=False), outputs=push_to_hub_btn
360
  ).then(
361
  fn=push_to_hub_wrapper,
362
+ inputs=[session_state, repo_name_input],
363
  outputs=[push_status]
364
  ).then(
365
  fn=lambda: set_interactivity(True), outputs=action_buttons
 
413
  outputs=[vibe_score, vibe_status, style_thml, session_info_display]
414
  )
415
 
416
+ return demo