bebechien commited on
Commit
72b6692
·
verified ·
1 Parent(s): fc5f2ab

Fix hub ui

Browse files
Files changed (2) hide show
  1. src/model_trainer.py +29 -9
  2. src/ui.py +63 -23
src/model_trainer.py CHANGED
@@ -4,7 +4,7 @@ from datasets import Dataset
4
  from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
5
  from sentence_transformers.losses import MultipleNegativesRankingLoss
6
  from transformers import TrainerCallback, TrainingArguments
7
- from typing import List, Callable, Optional
8
  from pathlib import Path
9
  from .config import AppConfig
10
 
@@ -57,7 +57,22 @@ def get_top_hits(
57
 
58
  return "\n".join(result)
59
 
60
- def upload_model_to_hub(folder_path: Path, repo_name: str, token: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  """
62
  Uploads a local model folder to the Hugging Face Hub.
63
  Creates the repository if it doesn't exist.
@@ -65,12 +80,16 @@ def upload_model_to_hub(folder_path: Path, repo_name: str, token: str) -> str:
65
  try:
66
  api = HfApi(token=token)
67
 
68
- # Get the authenticated user's username
69
- user_info = api.whoami()
70
- username = user_info['name']
 
 
 
 
71
 
72
  # Construct the full repo ID
73
- repo_id = f"{username}/{repo_name}"
74
  print(f"Preparing to upload to: {repo_id}")
75
 
76
  # Create the repo (safe if it already exists)
@@ -88,8 +107,9 @@ def upload_model_to_hub(folder_path: Path, repo_name: str, token: str) -> str:
88
  token=token
89
  )
90
  tags = info.card_data.tags
91
- tags.append("embeddinggemma-tuning-lab")
92
- metadata_update(repo_id, {"tags": tags}, overwrite=True, token=token)
 
93
 
94
  return f"✅ Success! Model published at: {url}"
95
  except Exception as e:
@@ -169,4 +189,4 @@ def train_with_dataset(
169
  # Save the final fine-tuned model
170
  trainer.save_model()
171
 
172
- print(f"Model saved locally to: {output_dir}")
 
4
  from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
5
  from sentence_transformers.losses import MultipleNegativesRankingLoss
6
  from transformers import TrainerCallback, TrainingArguments
7
+ from typing import List, Callable, Optional, Union
8
  from pathlib import Path
9
  from .config import AppConfig
10
 
 
57
 
58
  return "\n".join(result)
59
 
60
+ def get_available_namespaces(token: str) -> List[str]:
61
+ """
62
+ Returns a list of namespaces (user and organizations) the user can write to.
63
+ First item is always the authenticated user's username.
64
+ """
65
+ try:
66
+ api = HfApi(token=token)
67
+ info = api.whoami()
68
+ username = info['name']
69
+ orgs = [org['name'] for org in info.get('orgs', [])]
70
+ return [username] + orgs
71
+ except Exception as e:
72
+ print(f"Error fetching namespaces: {e}")
73
+ return []
74
+
75
+ def upload_model_to_hub(folder_path: Path, repo_name: str, token: str, entity: Optional[str] = None) -> str:
76
  """
77
  Uploads a local model folder to the Hugging Face Hub.
78
  Creates the repository if it doesn't exist.
 
80
  try:
81
  api = HfApi(token=token)
82
 
83
+ # Determine the entity (namespace) to use
84
+ if entity:
85
+ namespace = entity
86
+ else:
87
+ # Fallback to the authenticated user's username
88
+ user_info = api.whoami()
89
+ namespace = user_info['name']
90
 
91
  # Construct the full repo ID
92
+ repo_id = f"{namespace}/{repo_name}"
93
  print(f"Preparing to upload to: {repo_id}")
94
 
95
  # Create the repo (safe if it already exists)
 
107
  token=token
108
  )
109
  tags = info.card_data.tags
110
+ if "embeddinggemma-tuning-lab" not in tags:
111
+ tags.append("embeddinggemma-tuning-lab")
112
+ metadata_update(repo_id, {"tags": tags}, overwrite=True, token=token)
113
 
114
  return f"✅ Success! Model published at: {url}"
115
  except Exception as e:
 
189
  # Save the final fine-tuned model
190
  trainer.save_model()
191
 
192
+ print(f"Model saved locally to: {output_dir}")
src/ui.py CHANGED
@@ -4,6 +4,7 @@ from datetime import datetime
4
 
5
  from .config import AppConfig
6
  from .session_manager import HackerNewsFineTuner
 
7
 
8
  # --- Constants for Labels ---
9
  LABEL_FAV = "👍"
@@ -45,13 +46,27 @@ def on_app_load(app, profile: Optional[gr.OAuthProfile] = None):
45
  # Return 7 items: App state, Data updates (3), Hub updates (2), Username state (1)
46
  return app, stories, labels, text_update, repo_update, push_update, username
47
 
48
- def update_repo_preview(username, repo_name):
49
- """Updates the markdown preview to show 'username/repo_name'."""
50
- if not username:
51
- return "⚠️ Sign in to see the target repository path."
52
 
53
  clean_repo = repo_name.strip() if repo_name else "..."
54
- return f"Target Repository: **`{username}/{clean_repo}`**"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
 
56
  def import_wrapper(app, file):
57
  return app.import_additional_dataset(file)
@@ -62,11 +77,12 @@ def export_wrapper(app):
62
  def download_model_wrapper(app):
63
  return app.download_model()
64
 
65
- def push_to_hub_wrapper(app, repo_name, oauth_token: Optional[gr.OAuthToken]):
66
  if oauth_token is None:
67
  return "⚠️ You must be logged in to push to the Hub. Please sign in above."
68
  token_str = oauth_token.token
69
- return app.upload_model(repo_name, token_str)
 
70
 
71
  def training_wrapper(app, stories: List[str], labels: Dict[int, str]):
72
  """
@@ -126,7 +142,7 @@ def build_interface() -> gr.Blocks:
126
  with gr.Accordion("0️⃣ Step 0: Sign In (Optional)", open=True):
127
  gr.Markdown("Sign in to Hugging Face if you plan to push your fine-tuned model to the Hub later (Step 3).")
128
  with gr.Row():
129
- gr.LoginButton(value="Sign in with Hugging Face")
130
  with gr.Column(scale=3):
131
  gr.Markdown("")
132
 
@@ -200,11 +216,19 @@ def build_interface() -> gr.Blocks:
200
  gr.Markdown("Push your fine-tuned model to your personal Hugging Face account.")
201
 
202
  with gr.Row():
203
- repo_name_input = gr.Textbox(label="Target Repository Name", value="my-embeddinggemma-news-vibe", placeholder="e.g., my-embeddinggemma-news-vibe", interactive=False)
204
- push_to_hub_btn = gr.Button("Save to Hugging Face Hub", variant="secondary", interactive=False)
 
 
 
 
 
 
 
205
 
206
- repo_id_preview = gr.Markdown("Target Repository: (Waiting for input...)")
207
 
 
208
  push_status = gr.Markdown("")
209
 
210
  # --- Step 4: Downloads ---
@@ -242,15 +266,25 @@ def build_interface() -> gr.Blocks:
242
  fn=on_app_load,
243
  inputs=[session_state],
244
  outputs=[session_state, stories_state, labels_state, output, repo_name_input, push_to_hub_btn, username_state]
245
- ).then(
246
- fn=update_repo_preview,
247
- inputs=[username_state, repo_name_input],
248
- outputs=[repo_id_preview]
249
  ).then(
250
  fn=lambda: [gr.update(interactive=True)]*2, outputs=[clear_reload_btn, run_training_btn]
251
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
252
 
253
- # 2. Reset / Refresh / Clear Selections
254
  # ----------------
255
  clear_reload_btn.click(
256
  fn=lambda: set_interactivity(False), outputs=action_buttons
@@ -279,7 +313,7 @@ def build_interface() -> gr.Blocks:
279
  outputs=[reset_counter, labels_state]
280
  )
281
 
282
- # 3. Import Data
283
  # ----------------
284
  import_file.change(
285
  fn=import_wrapper,
@@ -287,7 +321,7 @@ def build_interface() -> gr.Blocks:
287
  outputs=[download_status]
288
  )
289
 
290
- # 4. Run Training
291
  # ----------------
292
  run_training_btn.click(
293
  fn=lambda: set_interactivity(False), outputs=action_buttons
@@ -304,7 +338,7 @@ def build_interface() -> gr.Blocks:
304
  outputs=[repo_name_input, push_to_hub_btn]
305
  )
306
 
307
- # 5. Downloads
308
  # ----------------
309
  download_dataset_btn.click(
310
  fn=export_wrapper,
@@ -345,11 +379,17 @@ def build_interface() -> gr.Blocks:
345
  outputs=[repo_name_input, push_to_hub_btn]
346
  )
347
 
348
- # 6. Push to Hub
349
  # ----------------
 
350
  repo_name_input.change(
351
  fn=update_repo_preview,
352
- inputs=[username_state, repo_name_input],
 
 
 
 
 
353
  outputs=[repo_id_preview]
354
  )
355
 
@@ -359,7 +399,7 @@ def build_interface() -> gr.Blocks:
359
  fn=lambda: gr.update(interactive=False), outputs=push_to_hub_btn
360
  ).then(
361
  fn=push_to_hub_wrapper,
362
- inputs=[session_state, repo_name_input],
363
  outputs=[push_status]
364
  ).then(
365
  fn=lambda: set_interactivity(True), outputs=action_buttons
@@ -413,4 +453,4 @@ def build_interface() -> gr.Blocks:
413
  outputs=[vibe_score, vibe_status, style_thml, session_info_display]
414
  )
415
 
416
- return demo
 
4
 
5
  from .config import AppConfig
6
  from .session_manager import HackerNewsFineTuner
7
+ from .model_trainer import get_available_namespaces
8
 
9
  # --- Constants for Labels ---
10
  LABEL_FAV = "👍"
 
46
  # Return 7 items: App state, Data updates (3), Hub updates (2), Username state (1)
47
  return app, stories, labels, text_update, repo_update, push_update, username
48
 
49
+ def update_repo_preview(entity_name, repo_name):
50
+ """Updates the markdown preview to show 'entity/repo_name'."""
51
+ if not entity_name:
52
+ return "⚠️ Please select a namespace (User or Org)."
53
 
54
  clean_repo = repo_name.strip() if repo_name else "..."
55
+ return f"Target Repository: **`{entity_name}/{clean_repo}`**"
56
+
57
+ def fetch_orgs_wrapper(oauth_token: Optional[gr.OAuthToken]):
58
+ if not oauth_token:
59
+ return gr.update(choices=[], value=None), "⚠️ Login required to fetch organizations."
60
+
61
+ try:
62
+ namespaces = get_available_namespaces(oauth_token.token)
63
+ if not namespaces:
64
+ return gr.update(choices=[], value=None), "❌ Failed to fetch namespaces."
65
+
66
+ # Default to the first one (username)
67
+ return gr.update(choices=namespaces, value=namespaces[0]), "✅ Organizations loaded."
68
+ except Exception as e:
69
+ return gr.update(choices=[], value=None), f"❌ Error: {str(e)}"
70
 
71
  def import_wrapper(app, file):
72
  return app.import_additional_dataset(file)
 
77
  def download_model_wrapper(app):
78
  return app.download_model()
79
 
80
+ def push_to_hub_wrapper(app, entity_name, repo_name, oauth_token: Optional[gr.OAuthToken]):
81
  if oauth_token is None:
82
  return "⚠️ You must be logged in to push to the Hub. Please sign in above."
83
  token_str = oauth_token.token
84
+ # Pass the selected entity
85
+ return app.upload_model(repo_name, token_str, entity=entity_name)
86
 
87
  def training_wrapper(app, stories: List[str], labels: Dict[int, str]):
88
  """
 
142
  with gr.Accordion("0️⃣ Step 0: Sign In (Optional)", open=True):
143
  gr.Markdown("Sign in to Hugging Face if you plan to push your fine-tuned model to the Hub later (Step 3).")
144
  with gr.Row():
145
+ login_btn = gr.LoginButton(value="Sign in with Hugging Face")
146
  with gr.Column(scale=3):
147
  gr.Markdown("")
148
 
 
216
  gr.Markdown("Push your fine-tuned model to your personal Hugging Face account.")
217
 
218
  with gr.Row():
219
+ # Entity (User/Org) Selection
220
+ with gr.Column(scale=1):
221
+ with gr.Row():
222
+ entity_dropdown = gr.Dropdown(label="Owner / Organization", choices=[], interactive=True, scale=4)
223
+ refresh_orgs_btn = gr.Button("🔄", scale=1, size="sm")
224
+
225
+ # Repo Name
226
+ with gr.Column(scale=2):
227
+ repo_name_input = gr.Textbox(label="Target Repository Name", value="my-embeddinggemma-news-vibe", placeholder="e.g., my-embeddinggemma-news-vibe", interactive=False)
228
 
229
+ push_to_hub_btn = gr.Button("Save to Hugging Face Hub", variant="secondary", interactive=False)
230
 
231
+ repo_id_preview = gr.Markdown("Target Repository: (Waiting for input...)")
232
  push_status = gr.Markdown("")
233
 
234
  # --- Step 4: Downloads ---
 
266
  fn=on_app_load,
267
  inputs=[session_state],
268
  outputs=[session_state, stories_state, labels_state, output, repo_name_input, push_to_hub_btn, username_state]
 
 
 
 
269
  ).then(
270
  fn=lambda: [gr.update(interactive=True)]*2, outputs=[clear_reload_btn, run_training_btn]
271
  )
272
+
273
+ # 2. Login Trigger -> Auto Fetch Orgs
274
+ # ----------------
275
+ # We can try to fetch orgs automatically if the token is available
276
+
277
+ refresh_orgs_btn.click(
278
+ fn=fetch_orgs_wrapper,
279
+ inputs=[login_btn], # Gr.LoginButton acts as the OAuthToken input in this context? No, usually gr.OAuthToken is implicit or separate
280
+ outputs=[entity_dropdown, push_status]
281
+ ).then(
282
+ fn=update_repo_preview,
283
+ inputs=[entity_dropdown, repo_name_input],
284
+ outputs=[repo_id_preview]
285
+ )
286
 
287
+ # 3. Reset / Refresh / Clear Selections
288
  # ----------------
289
  clear_reload_btn.click(
290
  fn=lambda: set_interactivity(False), outputs=action_buttons
 
313
  outputs=[reset_counter, labels_state]
314
  )
315
 
316
+ # 4. Import Data
317
  # ----------------
318
  import_file.change(
319
  fn=import_wrapper,
 
321
  outputs=[download_status]
322
  )
323
 
324
+ # 5. Run Training
325
  # ----------------
326
  run_training_btn.click(
327
  fn=lambda: set_interactivity(False), outputs=action_buttons
 
338
  outputs=[repo_name_input, push_to_hub_btn]
339
  )
340
 
341
+ # 6. Downloads
342
  # ----------------
343
  download_dataset_btn.click(
344
  fn=export_wrapper,
 
379
  outputs=[repo_name_input, push_to_hub_btn]
380
  )
381
 
382
+ # 7. Push to Hub
383
  # ----------------
384
+ # Update preview on Name change or Entity change
385
  repo_name_input.change(
386
  fn=update_repo_preview,
387
+ inputs=[entity_dropdown, repo_name_input],
388
+ outputs=[repo_id_preview]
389
+ )
390
+ entity_dropdown.change(
391
+ fn=update_repo_preview,
392
+ inputs=[entity_dropdown, repo_name_input],
393
  outputs=[repo_id_preview]
394
  )
395
 
 
399
  fn=lambda: gr.update(interactive=False), outputs=push_to_hub_btn
400
  ).then(
401
  fn=push_to_hub_wrapper,
402
+ inputs=[session_state, entity_dropdown, repo_name_input], # Pass entity dropdown
403
  outputs=[push_status]
404
  ).then(
405
  fn=lambda: set_interactivity(True), outputs=action_buttons
 
453
  outputs=[vibe_score, vibe_status, style_thml, session_info_display]
454
  )
455
 
456
+ return demo