Fix hub ui
Browse files- src/model_trainer.py +29 -9
- src/ui.py +63 -23
src/model_trainer.py
CHANGED
|
@@ -4,7 +4,7 @@ from datasets import Dataset
|
|
| 4 |
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
|
| 5 |
from sentence_transformers.losses import MultipleNegativesRankingLoss
|
| 6 |
from transformers import TrainerCallback, TrainingArguments
|
| 7 |
-
from typing import List, Callable, Optional
|
| 8 |
from pathlib import Path
|
| 9 |
from .config import AppConfig
|
| 10 |
|
|
@@ -57,7 +57,22 @@ def get_top_hits(
|
|
| 57 |
|
| 58 |
return "\n".join(result)
|
| 59 |
|
| 60 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
"""
|
| 62 |
Uploads a local model folder to the Hugging Face Hub.
|
| 63 |
Creates the repository if it doesn't exist.
|
|
@@ -65,12 +80,16 @@ def upload_model_to_hub(folder_path: Path, repo_name: str, token: str) -> str:
|
|
| 65 |
try:
|
| 66 |
api = HfApi(token=token)
|
| 67 |
|
| 68 |
-
#
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
# Construct the full repo ID
|
| 73 |
-
repo_id = f"{
|
| 74 |
print(f"Preparing to upload to: {repo_id}")
|
| 75 |
|
| 76 |
# Create the repo (safe if it already exists)
|
|
@@ -88,8 +107,9 @@ def upload_model_to_hub(folder_path: Path, repo_name: str, token: str) -> str:
|
|
| 88 |
token=token
|
| 89 |
)
|
| 90 |
tags = info.card_data.tags
|
| 91 |
-
|
| 92 |
-
|
|
|
|
| 93 |
|
| 94 |
return f"✅ Success! Model published at: {url}"
|
| 95 |
except Exception as e:
|
|
@@ -169,4 +189,4 @@ def train_with_dataset(
|
|
| 169 |
# Save the final fine-tuned model
|
| 170 |
trainer.save_model()
|
| 171 |
|
| 172 |
-
print(f"Model saved locally to: {output_dir}")
|
|
|
|
| 4 |
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments
|
| 5 |
from sentence_transformers.losses import MultipleNegativesRankingLoss
|
| 6 |
from transformers import TrainerCallback, TrainingArguments
|
| 7 |
+
from typing import List, Callable, Optional, Union
|
| 8 |
from pathlib import Path
|
| 9 |
from .config import AppConfig
|
| 10 |
|
|
|
|
| 57 |
|
| 58 |
return "\n".join(result)
|
| 59 |
|
| 60 |
+
def get_available_namespaces(token: str) -> List[str]:
|
| 61 |
+
"""
|
| 62 |
+
Returns a list of namespaces (user and organizations) the user can write to.
|
| 63 |
+
First item is always the authenticated user's username.
|
| 64 |
+
"""
|
| 65 |
+
try:
|
| 66 |
+
api = HfApi(token=token)
|
| 67 |
+
info = api.whoami()
|
| 68 |
+
username = info['name']
|
| 69 |
+
orgs = [org['name'] for org in info.get('orgs', [])]
|
| 70 |
+
return [username] + orgs
|
| 71 |
+
except Exception as e:
|
| 72 |
+
print(f"Error fetching namespaces: {e}")
|
| 73 |
+
return []
|
| 74 |
+
|
| 75 |
+
def upload_model_to_hub(folder_path: Path, repo_name: str, token: str, entity: Optional[str] = None) -> str:
|
| 76 |
"""
|
| 77 |
Uploads a local model folder to the Hugging Face Hub.
|
| 78 |
Creates the repository if it doesn't exist.
|
|
|
|
| 80 |
try:
|
| 81 |
api = HfApi(token=token)
|
| 82 |
|
| 83 |
+
# Determine the entity (namespace) to use
|
| 84 |
+
if entity:
|
| 85 |
+
namespace = entity
|
| 86 |
+
else:
|
| 87 |
+
# Fallback to the authenticated user's username
|
| 88 |
+
user_info = api.whoami()
|
| 89 |
+
namespace = user_info['name']
|
| 90 |
|
| 91 |
# Construct the full repo ID
|
| 92 |
+
repo_id = f"{namespace}/{repo_name}"
|
| 93 |
print(f"Preparing to upload to: {repo_id}")
|
| 94 |
|
| 95 |
# Create the repo (safe if it already exists)
|
|
|
|
| 107 |
token=token
|
| 108 |
)
|
| 109 |
tags = info.card_data.tags
|
| 110 |
+
if "embeddinggemma-tuning-lab" not in tags:
|
| 111 |
+
tags.append("embeddinggemma-tuning-lab")
|
| 112 |
+
metadata_update(repo_id, {"tags": tags}, overwrite=True, token=token)
|
| 113 |
|
| 114 |
return f"✅ Success! Model published at: {url}"
|
| 115 |
except Exception as e:
|
|
|
|
| 189 |
# Save the final fine-tuned model
|
| 190 |
trainer.save_model()
|
| 191 |
|
| 192 |
+
print(f"Model saved locally to: {output_dir}")
|
src/ui.py
CHANGED
|
@@ -4,6 +4,7 @@ from datetime import datetime
|
|
| 4 |
|
| 5 |
from .config import AppConfig
|
| 6 |
from .session_manager import HackerNewsFineTuner
|
|
|
|
| 7 |
|
| 8 |
# --- Constants for Labels ---
|
| 9 |
LABEL_FAV = "👍"
|
|
@@ -45,13 +46,27 @@ def on_app_load(app, profile: Optional[gr.OAuthProfile] = None):
|
|
| 45 |
# Return 7 items: App state, Data updates (3), Hub updates (2), Username state (1)
|
| 46 |
return app, stories, labels, text_update, repo_update, push_update, username
|
| 47 |
|
| 48 |
-
def update_repo_preview(
|
| 49 |
-
"""Updates the markdown preview to show '
|
| 50 |
-
if not
|
| 51 |
-
return "⚠️
|
| 52 |
|
| 53 |
clean_repo = repo_name.strip() if repo_name else "..."
|
| 54 |
-
return f"Target Repository: **`{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
def import_wrapper(app, file):
|
| 57 |
return app.import_additional_dataset(file)
|
|
@@ -62,11 +77,12 @@ def export_wrapper(app):
|
|
| 62 |
def download_model_wrapper(app):
|
| 63 |
return app.download_model()
|
| 64 |
|
| 65 |
-
def push_to_hub_wrapper(app, repo_name, oauth_token: Optional[gr.OAuthToken]):
|
| 66 |
if oauth_token is None:
|
| 67 |
return "⚠️ You must be logged in to push to the Hub. Please sign in above."
|
| 68 |
token_str = oauth_token.token
|
| 69 |
-
|
|
|
|
| 70 |
|
| 71 |
def training_wrapper(app, stories: List[str], labels: Dict[int, str]):
|
| 72 |
"""
|
|
@@ -126,7 +142,7 @@ def build_interface() -> gr.Blocks:
|
|
| 126 |
with gr.Accordion("0️⃣ Step 0: Sign In (Optional)", open=True):
|
| 127 |
gr.Markdown("Sign in to Hugging Face if you plan to push your fine-tuned model to the Hub later (Step 3).")
|
| 128 |
with gr.Row():
|
| 129 |
-
gr.LoginButton(value="Sign in with Hugging Face")
|
| 130 |
with gr.Column(scale=3):
|
| 131 |
gr.Markdown("")
|
| 132 |
|
|
@@ -200,11 +216,19 @@ def build_interface() -> gr.Blocks:
|
|
| 200 |
gr.Markdown("Push your fine-tuned model to your personal Hugging Face account.")
|
| 201 |
|
| 202 |
with gr.Row():
|
| 203 |
-
|
| 204 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
-
|
| 207 |
|
|
|
|
| 208 |
push_status = gr.Markdown("")
|
| 209 |
|
| 210 |
# --- Step 4: Downloads ---
|
|
@@ -242,15 +266,25 @@ def build_interface() -> gr.Blocks:
|
|
| 242 |
fn=on_app_load,
|
| 243 |
inputs=[session_state],
|
| 244 |
outputs=[session_state, stories_state, labels_state, output, repo_name_input, push_to_hub_btn, username_state]
|
| 245 |
-
).then(
|
| 246 |
-
fn=update_repo_preview,
|
| 247 |
-
inputs=[username_state, repo_name_input],
|
| 248 |
-
outputs=[repo_id_preview]
|
| 249 |
).then(
|
| 250 |
fn=lambda: [gr.update(interactive=True)]*2, outputs=[clear_reload_btn, run_training_btn]
|
| 251 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 252 |
|
| 253 |
-
#
|
| 254 |
# ----------------
|
| 255 |
clear_reload_btn.click(
|
| 256 |
fn=lambda: set_interactivity(False), outputs=action_buttons
|
|
@@ -279,7 +313,7 @@ def build_interface() -> gr.Blocks:
|
|
| 279 |
outputs=[reset_counter, labels_state]
|
| 280 |
)
|
| 281 |
|
| 282 |
-
#
|
| 283 |
# ----------------
|
| 284 |
import_file.change(
|
| 285 |
fn=import_wrapper,
|
|
@@ -287,7 +321,7 @@ def build_interface() -> gr.Blocks:
|
|
| 287 |
outputs=[download_status]
|
| 288 |
)
|
| 289 |
|
| 290 |
-
#
|
| 291 |
# ----------------
|
| 292 |
run_training_btn.click(
|
| 293 |
fn=lambda: set_interactivity(False), outputs=action_buttons
|
|
@@ -304,7 +338,7 @@ def build_interface() -> gr.Blocks:
|
|
| 304 |
outputs=[repo_name_input, push_to_hub_btn]
|
| 305 |
)
|
| 306 |
|
| 307 |
-
#
|
| 308 |
# ----------------
|
| 309 |
download_dataset_btn.click(
|
| 310 |
fn=export_wrapper,
|
|
@@ -345,11 +379,17 @@ def build_interface() -> gr.Blocks:
|
|
| 345 |
outputs=[repo_name_input, push_to_hub_btn]
|
| 346 |
)
|
| 347 |
|
| 348 |
-
#
|
| 349 |
# ----------------
|
|
|
|
| 350 |
repo_name_input.change(
|
| 351 |
fn=update_repo_preview,
|
| 352 |
-
inputs=[
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
outputs=[repo_id_preview]
|
| 354 |
)
|
| 355 |
|
|
@@ -359,7 +399,7 @@ def build_interface() -> gr.Blocks:
|
|
| 359 |
fn=lambda: gr.update(interactive=False), outputs=push_to_hub_btn
|
| 360 |
).then(
|
| 361 |
fn=push_to_hub_wrapper,
|
| 362 |
-
inputs=[session_state, repo_name_input],
|
| 363 |
outputs=[push_status]
|
| 364 |
).then(
|
| 365 |
fn=lambda: set_interactivity(True), outputs=action_buttons
|
|
@@ -413,4 +453,4 @@ def build_interface() -> gr.Blocks:
|
|
| 413 |
outputs=[vibe_score, vibe_status, style_thml, session_info_display]
|
| 414 |
)
|
| 415 |
|
| 416 |
-
return demo
|
|
|
|
| 4 |
|
| 5 |
from .config import AppConfig
|
| 6 |
from .session_manager import HackerNewsFineTuner
|
| 7 |
+
from .model_trainer import get_available_namespaces
|
| 8 |
|
| 9 |
# --- Constants for Labels ---
|
| 10 |
LABEL_FAV = "👍"
|
|
|
|
| 46 |
# Return 7 items: App state, Data updates (3), Hub updates (2), Username state (1)
|
| 47 |
return app, stories, labels, text_update, repo_update, push_update, username
|
| 48 |
|
| 49 |
+
def update_repo_preview(entity_name, repo_name):
|
| 50 |
+
"""Updates the markdown preview to show 'entity/repo_name'."""
|
| 51 |
+
if not entity_name:
|
| 52 |
+
return "⚠️ Please select a namespace (User or Org)."
|
| 53 |
|
| 54 |
clean_repo = repo_name.strip() if repo_name else "..."
|
| 55 |
+
return f"Target Repository: **`{entity_name}/{clean_repo}`**"
|
| 56 |
+
|
| 57 |
+
def fetch_orgs_wrapper(oauth_token: Optional[gr.OAuthToken]):
|
| 58 |
+
if not oauth_token:
|
| 59 |
+
return gr.update(choices=[], value=None), "⚠️ Login required to fetch organizations."
|
| 60 |
+
|
| 61 |
+
try:
|
| 62 |
+
namespaces = get_available_namespaces(oauth_token.token)
|
| 63 |
+
if not namespaces:
|
| 64 |
+
return gr.update(choices=[], value=None), "❌ Failed to fetch namespaces."
|
| 65 |
+
|
| 66 |
+
# Default to the first one (username)
|
| 67 |
+
return gr.update(choices=namespaces, value=namespaces[0]), "✅ Organizations loaded."
|
| 68 |
+
except Exception as e:
|
| 69 |
+
return gr.update(choices=[], value=None), f"❌ Error: {str(e)}"
|
| 70 |
|
| 71 |
def import_wrapper(app, file):
|
| 72 |
return app.import_additional_dataset(file)
|
|
|
|
| 77 |
def download_model_wrapper(app):
|
| 78 |
return app.download_model()
|
| 79 |
|
| 80 |
+
def push_to_hub_wrapper(app, entity_name, repo_name, oauth_token: Optional[gr.OAuthToken]):
|
| 81 |
if oauth_token is None:
|
| 82 |
return "⚠️ You must be logged in to push to the Hub. Please sign in above."
|
| 83 |
token_str = oauth_token.token
|
| 84 |
+
# Pass the selected entity
|
| 85 |
+
return app.upload_model(repo_name, token_str, entity=entity_name)
|
| 86 |
|
| 87 |
def training_wrapper(app, stories: List[str], labels: Dict[int, str]):
|
| 88 |
"""
|
|
|
|
| 142 |
with gr.Accordion("0️⃣ Step 0: Sign In (Optional)", open=True):
|
| 143 |
gr.Markdown("Sign in to Hugging Face if you plan to push your fine-tuned model to the Hub later (Step 3).")
|
| 144 |
with gr.Row():
|
| 145 |
+
login_btn = gr.LoginButton(value="Sign in with Hugging Face")
|
| 146 |
with gr.Column(scale=3):
|
| 147 |
gr.Markdown("")
|
| 148 |
|
|
|
|
| 216 |
gr.Markdown("Push your fine-tuned model to your personal Hugging Face account.")
|
| 217 |
|
| 218 |
with gr.Row():
|
| 219 |
+
# Entity (User/Org) Selection
|
| 220 |
+
with gr.Column(scale=1):
|
| 221 |
+
with gr.Row():
|
| 222 |
+
entity_dropdown = gr.Dropdown(label="Owner / Organization", choices=[], interactive=True, scale=4)
|
| 223 |
+
refresh_orgs_btn = gr.Button("🔄", scale=1, size="sm")
|
| 224 |
+
|
| 225 |
+
# Repo Name
|
| 226 |
+
with gr.Column(scale=2):
|
| 227 |
+
repo_name_input = gr.Textbox(label="Target Repository Name", value="my-embeddinggemma-news-vibe", placeholder="e.g., my-embeddinggemma-news-vibe", interactive=False)
|
| 228 |
|
| 229 |
+
push_to_hub_btn = gr.Button("Save to Hugging Face Hub", variant="secondary", interactive=False)
|
| 230 |
|
| 231 |
+
repo_id_preview = gr.Markdown("Target Repository: (Waiting for input...)")
|
| 232 |
push_status = gr.Markdown("")
|
| 233 |
|
| 234 |
# --- Step 4: Downloads ---
|
|
|
|
| 266 |
fn=on_app_load,
|
| 267 |
inputs=[session_state],
|
| 268 |
outputs=[session_state, stories_state, labels_state, output, repo_name_input, push_to_hub_btn, username_state]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 269 |
).then(
|
| 270 |
fn=lambda: [gr.update(interactive=True)]*2, outputs=[clear_reload_btn, run_training_btn]
|
| 271 |
)
|
| 272 |
+
|
| 273 |
+
# 2. Login Trigger -> Auto Fetch Orgs
|
| 274 |
+
# ----------------
|
| 275 |
+
# We can try to fetch orgs automatically if the token is available
|
| 276 |
+
|
| 277 |
+
refresh_orgs_btn.click(
|
| 278 |
+
fn=fetch_orgs_wrapper,
|
| 279 |
+
inputs=[login_btn], # Gr.LoginButton acts as the OAuthToken input in this context? No, usually gr.OAuthToken is implicit or separate
|
| 280 |
+
outputs=[entity_dropdown, push_status]
|
| 281 |
+
).then(
|
| 282 |
+
fn=update_repo_preview,
|
| 283 |
+
inputs=[entity_dropdown, repo_name_input],
|
| 284 |
+
outputs=[repo_id_preview]
|
| 285 |
+
)
|
| 286 |
|
| 287 |
+
# 3. Reset / Refresh / Clear Selections
|
| 288 |
# ----------------
|
| 289 |
clear_reload_btn.click(
|
| 290 |
fn=lambda: set_interactivity(False), outputs=action_buttons
|
|
|
|
| 313 |
outputs=[reset_counter, labels_state]
|
| 314 |
)
|
| 315 |
|
| 316 |
+
# 4. Import Data
|
| 317 |
# ----------------
|
| 318 |
import_file.change(
|
| 319 |
fn=import_wrapper,
|
|
|
|
| 321 |
outputs=[download_status]
|
| 322 |
)
|
| 323 |
|
| 324 |
+
# 5. Run Training
|
| 325 |
# ----------------
|
| 326 |
run_training_btn.click(
|
| 327 |
fn=lambda: set_interactivity(False), outputs=action_buttons
|
|
|
|
| 338 |
outputs=[repo_name_input, push_to_hub_btn]
|
| 339 |
)
|
| 340 |
|
| 341 |
+
# 6. Downloads
|
| 342 |
# ----------------
|
| 343 |
download_dataset_btn.click(
|
| 344 |
fn=export_wrapper,
|
|
|
|
| 379 |
outputs=[repo_name_input, push_to_hub_btn]
|
| 380 |
)
|
| 381 |
|
| 382 |
+
# 7. Push to Hub
|
| 383 |
# ----------------
|
| 384 |
+
# Update preview on Name change or Entity change
|
| 385 |
repo_name_input.change(
|
| 386 |
fn=update_repo_preview,
|
| 387 |
+
inputs=[entity_dropdown, repo_name_input],
|
| 388 |
+
outputs=[repo_id_preview]
|
| 389 |
+
)
|
| 390 |
+
entity_dropdown.change(
|
| 391 |
+
fn=update_repo_preview,
|
| 392 |
+
inputs=[entity_dropdown, repo_name_input],
|
| 393 |
outputs=[repo_id_preview]
|
| 394 |
)
|
| 395 |
|
|
|
|
| 399 |
fn=lambda: gr.update(interactive=False), outputs=push_to_hub_btn
|
| 400 |
).then(
|
| 401 |
fn=push_to_hub_wrapper,
|
| 402 |
+
inputs=[session_state, entity_dropdown, repo_name_input], # Pass entity dropdown
|
| 403 |
outputs=[push_status]
|
| 404 |
).then(
|
| 405 |
fn=lambda: set_interactivity(True), outputs=action_buttons
|
|
|
|
| 453 |
outputs=[vibe_score, vibe_status, style_thml, session_info_display]
|
| 454 |
)
|
| 455 |
|
| 456 |
+
return demo
|