# Hugging Face Space app (page captured while the Space showed "Runtime error").
| import gradio as gr | |
| import torch | |
| from huggingface_hub import HfApi, login, create_repo, whoami | |
| from transformers import AutoModel, AutoTokenizer, AdamW, get_scheduler | |
| from datasets import load_dataset | |
| from torch.utils.data import DataLoader | |
| import spaces | |
| import time | |
| import os | |
| from typing import Optional, Tuple, Dict, Any | |
| import pandas as pd | |
| from utils import ( | |
| load_embedding_model, | |
| load_huggingface_dataset, | |
| prepare_dataset_for_training, | |
| train_model_on_zero_gpu, | |
| save_model_to_hub | |
| ) | |
| from config import APP_CONFIG | |
# Module-level application state shared by all Gradio callbacks.
# Tracks authentication, the loaded model/tokenizer/dataset, and training progress.
session_state = dict(
    logged_in=False,          # True after a successful HF Hub login
    hf_token=None,            # token string retained for later Hub API calls
    username=None,            # HF account name of the logged-in user
    model=None,               # loaded (and later fine-tuned) model object
    tokenizer=None,           # tokenizer paired with the model
    dataset=None,             # loaded HF dataset split
    model_loaded=False,
    dataset_loaded=False,
    training_complete=False,
    training_history=[],      # list of per-epoch metric dicts
)
def update_status():
    """Return a list of (message, severity) pairs summarizing session_state."""
    items = []
    if session_state["logged_in"]:
        items.extend([
            ("β Logged In", "success"),
            (f"User: {session_state['username']}", "info"),
        ])
    else:
        items.append(("β Not Logged In", "error"))
    items.append(
        ("β Model Loaded", "success") if session_state["model_loaded"]
        else ("β³ No Model", "warning")
    )
    items.append(
        ("β Dataset Loaded", "success") if session_state["dataset_loaded"]
        else ("β³ No Dataset", "warning")
    )
    if session_state["training_complete"]:
        items.append(("π Training Complete", "success"))
    return items
def login_to_huggingface(token: str) -> Tuple[str, str]:
    """Authenticate against the HF Hub and record the session details.

    Returns a (message, status HTML) pair for the two Gradio outputs.
    """
    try:
        login(token=token, add_to_git_credential=True)
        user = whoami(token=token)
        session_state["logged_in"] = True
        session_state["hf_token"] = token
        session_state["username"] = user["name"]
        return "β Successfully logged in to Hugging Face!", format_status(update_status())
    except Exception as exc:
        # Surface any auth/network failure to the UI instead of crashing.
        return f"β Login failed: {str(exc)}", format_status(update_status())
def format_status(status_items):
    """Render (message, severity) pairs as a monospace HTML status panel.

    Unknown severities fall back to black text; an empty/None input yields a
    plain placeholder string.
    """
    if not status_items:
        return "No status available"
    palette = {
        "success": "green",
        "error": "red",
        "warning": "orange",
        "info": "blue",
    }
    rows = [
        f"<div style='color: {palette.get(kind, 'black')}; margin: 2px 0;'>{text}</div>"
        for text, kind in status_items
    ]
    return "<div style='font-family: monospace;'>" + "".join(rows) + "</div>"
def load_model(model_url: str) -> Tuple[str, str]:
    """Load an embedding model + tokenizer from the HF Hub into session_state.

    Args:
        model_url: Hub repo name/URL, e.g. "sentence-transformers/all-MiniLM-L6-v2".
    Returns:
        A (message, status HTML) pair for the two Gradio outputs.
    """
    if not session_state["logged_in"]:
        return "β Please login first!", format_status(update_status())
    try:
        # BUG FIX: the original wrapped this call in `with gr.Blocks() as demo:`,
        # creating a brand-new, never-rendered Blocks context inside an event
        # handler. That serves no purpose and can interfere with the running
        # app's render context, so the wrapper has been removed.
        model, tokenizer = load_embedding_model(model_url)
        session_state["model"] = model
        session_state["tokenizer"] = tokenizer
        session_state["model_loaded"] = True
        model_info = f"Model: {model.__class__.__name__}\n"
        model_info += f"Parameters: {sum(p.numel() for p in model.parameters()):,}\n"
        model_info += f"Device: {'GPU' if torch.cuda.is_available() else 'CPU'}"
        return f"β Model loaded successfully!\n\n{model_info}", format_status(update_status())
    except Exception as e:
        return f"β Failed to load model: {str(e)}", format_status(update_status())
# NOTE(review): this handler's name shadows `datasets.load_dataset` imported at
# the top of the file (that import is unused past this point). The name is kept
# because the Gradio click wiring references it.
def load_dataset(dataset_name: str, split: str = "train") -> Tuple[str, str]:
    """Load a Hub dataset split and stash it in session_state.

    Returns a (message, status HTML) pair for the two Gradio outputs.
    """
    if not session_state["logged_in"]:
        return "β Please login first!", format_status(update_status())
    try:
        ds = load_huggingface_dataset(dataset_name, split)
        session_state["dataset"] = ds
        session_state["dataset_loaded"] = True
        # Last slot holds the optional "Columns:" line (empty keeps the
        # trailing newline the original message had when columns are absent).
        info_lines = [
            f"Dataset: {dataset_name}",
            f"Split: {split}",
            f"Size: {len(ds):,} samples",
            "",
        ]
        if hasattr(ds, 'column_names'):
            info_lines[-1] = f"Columns: {', '.join(ds.column_names)}"
        dataset_info = "\n".join(info_lines)
        return f"β Dataset loaded successfully!\n\n{dataset_info}", format_status(update_status())
    except Exception as e:
        return f"β Failed to load dataset: {str(e)}", format_status(update_status())
# 5 minutes for training — NOTE(review): this comment suggests a
# `@spaces.GPU(duration=300)` decorator was intended here; confirm against the
# ZeroGPU setup in utils.train_model_on_zero_gpu.
def start_training(
    epochs: int,
    batch_size: int,
    learning_rate: float,
    warmup_steps: int,
    use_zero_gpu: bool,
    repo_name: str,
    create_repo: bool,
    private_repo: bool,
    upload_to_hub: bool
) -> Tuple[str, str]:
    """Fine-tune the loaded embedding model on the loaded dataset.

    Args mirror the Gradio inputs; `create_repo` is a *flag* asking the app to
    create `repo_name` on the Hub before uploading.
    Returns:
        A (result message, status HTML) pair for the two Gradio outputs.
    """
    # BUG FIX: the boolean parameter `create_repo` shadows the
    # `huggingface_hub.create_repo` *function* imported at module level, so the
    # original `create_repo(repo_id=...)` call raised
    # "TypeError: 'bool' object is not callable". Re-import under an alias;
    # the parameter name is kept for backward compatibility.
    from huggingface_hub import create_repo as hf_create_repo

    # Check prerequisites before doing any heavy work.
    if not session_state["logged_in"]:
        return "β Please login first!", format_status(update_status())
    if not session_state["model_loaded"]:
        return "β Please load a model first!", format_status(update_status())
    if not session_state["dataset_loaded"]:
        return "β Please load a dataset first!", format_status(update_status())
    if upload_to_hub and not repo_name:
        return "β Please provide a repository name for upload!", format_status(update_status())
    try:
        # gr.Number delivers floats; coerce the count-like inputs to ints.
        epochs = int(epochs)
        batch_size = int(batch_size)
        warmup_steps = int(warmup_steps)

        # Tokenize / format the dataset for the model.
        dataset = prepare_dataset_for_training(
            session_state["dataset"],
            session_state["tokenizer"]
        )
        # Run the (optionally ZeroGPU-backed) training loop.
        trained_model, training_history = train_model_on_zero_gpu(
            model=session_state["model"],
            tokenizer=session_state["tokenizer"],
            dataset=dataset,
            epochs=epochs,
            batch_size=batch_size,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            use_zero_gpu=use_zero_gpu
        )
        session_state["model"] = trained_model
        session_state["training_complete"] = True
        session_state["training_history"] = training_history

        # Save model locally (timestamped so repeated runs don't collide).
        local_path = f"./trained_model_{int(time.time())}"
        trained_model.save_pretrained(local_path)
        session_state["tokenizer"].save_pretrained(local_path)

        result_msg = f"π Training completed successfully!\n\n"
        result_msg += f"Model saved locally to: {local_path}\n"
        result_msg += f"Training epochs: {epochs}\n"
        # Guard against an empty history so we don't IndexError here.
        if training_history:
            result_msg += f"Final loss: {training_history[-1]['loss']:.4f}\n"

        # Upload to Hub if requested; treat upload as best-effort since
        # training itself already succeeded.
        if upload_to_hub and repo_name:
            try:
                if create_repo:
                    hf_create_repo(
                        repo_id=repo_name,
                        token=session_state["hf_token"],
                        private=private_repo,
                        repo_type="model",
                        exist_ok=True
                    )
                save_model_to_hub(
                    model=trained_model,
                    tokenizer=session_state["tokenizer"],
                    repo_id=repo_name,
                    token=session_state["hf_token"],
                    private=private_repo
                )
                result_msg += f"\nβ Model uploaded to Hub: https://huggingface.co/{repo_name}"
            except Exception as e:
                result_msg += f"\nβ οΈ Upload to Hub failed: {str(e)}"

        return result_msg, format_status(update_status())
    except Exception as e:
        return f"β Training failed: {str(e)}", format_status(update_status())
def get_training_history():
    """Return recorded per-epoch metrics as a DataFrame (schema-only if empty)."""
    history = session_state["training_history"]
    if history:
        return pd.DataFrame(history)
    # Empty frame keeps the expected column headers visible in the UI.
    return pd.DataFrame(columns=["Epoch", "Loss", "Learning Rate"])
def create_interface():
    """Build and return the full Gradio Blocks UI.

    Layout: a left column for login / model / dataset loading, a wider right
    column for training configuration and results, plus an auto-refreshing
    status panel and training-history table. All callbacks are wired to the
    module-level handler functions.
    """
    with gr.Blocks(
        title="Embedding Model Trainer",
        theme=gr.themes.Soft(),
        css="""
        .main-header {
            text-align: center;
            font-size: 2.5em;
            font-weight: bold;
            color: #1f77b4;
            margin-bottom: 1em;
        }
        .status-box {
            padding: 10px;
            border-radius: 5px;
            background-color: #f0f0f0;
            font-family: monospace;
            min-height: 100px;
        }
        """
    ) as demo:
        gr.HTML('<div class="main-header">π€ Embedding Model Trainer</div>')
        gr.HTML('<p style="text-align: center;">Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a></p>')

        with gr.Row():
            with gr.Column(scale=1):
                # --- Login section ---
                gr.Markdown("## π Hugging Face Login")
                with gr.Group():
                    token_input = gr.Textbox(
                        label="Hugging Face Token",
                        type="password",
                        placeholder="Enter your HF token...",
                        info="Get your token from https://huggingface.co/settings/tokens"
                    )
                    login_btn = gr.Button("π Login", variant="primary")
                    login_output = gr.Textbox(label="Login Status", interactive=False)

                # --- Live status panel ---
                gr.Markdown("## π Status")
                status_display = gr.HTML(format_status(update_status()))

                # --- Model loading ---
                gr.Markdown("## π₯ Load Model")
                with gr.Group():
                    model_url = gr.Textbox(
                        label="Model URL/Name",
                        placeholder="e.g., sentence-transformers/all-MiniLM-L6-v2",
                        info="Enter Hugging Face model repository URL or name"
                    )
                    load_model_btn = gr.Button("π₯ Load Model", variant="secondary")
                    model_output = gr.Textbox(label="Model Status", interactive=False, lines=5)

                # --- Dataset loading ---
                gr.Markdown("## π Load Dataset")
                with gr.Group():
                    dataset_name = gr.Textbox(
                        label="Dataset Name",
                        placeholder="e.g., imdb",
                        info="Enter Hugging Face dataset name"
                    )
                    dataset_split = gr.Dropdown(
                        choices=["train", "test", "validation"],
                        value="train",
                        label="Dataset Split"
                    )
                    load_dataset_btn = gr.Button("π Load Dataset", variant="secondary")
                    dataset_output = gr.Textbox(label="Dataset Status", interactive=False, lines=5)

            with gr.Column(scale=2):
                # --- Training hyper-parameters ---
                gr.Markdown("## βοΈ Training Configuration")
                with gr.Row():
                    with gr.Column():
                        epochs = gr.Number(
                            label="Training Epochs",
                            value=3,
                            minimum=1,
                            maximum=100,
                            step=1
                        )
                        batch_size = gr.Number(
                            label="Batch Size",
                            value=16,
                            minimum=1,
                            maximum=128,
                            step=1
                        )
                    with gr.Column():
                        # NOTE(review): `format` is not a documented gr.Number
                        # kwarg in recent Gradio releases (`precision` is) —
                        # verify against the installed Gradio version.
                        learning_rate = gr.Number(
                            label="Learning Rate",
                            value=2e-5,
                            minimum=1e-6,
                            maximum=1e-1,
                            format="%.6f"
                        )
                        warmup_steps = gr.Number(
                            label="Warmup Steps",
                            value=100,
                            minimum=0,
                            maximum=1000,
                            step=10
                        )
                use_zero_gpu = gr.Checkbox(
                    label="Use Zero GPU",
                    value=True,
                    info="Enable Zero GPU for training (recommended)"
                )

                # --- Hub repository settings ---
                gr.Markdown("## π€ Repository Settings")
                with gr.Row():
                    with gr.Column():
                        repo_name = gr.Textbox(
                            label="Repository Name",
                            placeholder="my-fine-tuned-model",
                            info="Name for your model repository"
                        )
                        # Renamed from `create_repo` so this local no longer
                        # shadows `huggingface_hub.create_repo` imported at
                        # module level.
                        create_repo_checkbox = gr.Checkbox(
                            label="Create New Repository",
                            value=True,
                            info="Create a new repository if it doesn't exist"
                        )
                    with gr.Column():
                        private_repo = gr.Checkbox(
                            label="Private Repository",
                            value=False,
                            info="Make the repository private"
                        )
                        upload_to_hub = gr.Checkbox(
                            label="Upload to Hub",
                            value=True,
                            info="Upload trained model to Hugging Face Hub"
                        )

                # --- Training trigger + results ---
                train_btn = gr.Button(
                    "π Start Training",
                    variant="primary",
                    size="lg"
                )
                training_output = gr.Textbox(
                    label="Training Results",
                    interactive=False,
                    lines=8
                )

        # --- Training history table ---
        with gr.Row():
            gr.Markdown("## π Training History")
            history_df = gr.Dataframe(
                label="Training Metrics",
                value=get_training_history(),
                interactive=False
            )

        # --- Event handlers ---
        login_btn.click(
            login_to_huggingface,
            inputs=[token_input],
            outputs=[login_output, status_display]
        )
        load_model_btn.click(
            load_model,
            inputs=[model_url],
            outputs=[model_output, status_display]
        )
        load_dataset_btn.click(
            load_dataset,
            inputs=[dataset_name, dataset_split],
            outputs=[dataset_output, status_display]
        )
        train_btn.click(
            start_training,
            inputs=[
                epochs,
                batch_size,
                learning_rate,
                warmup_steps,
                use_zero_gpu,
                repo_name,
                create_repo_checkbox,
                private_repo,
                upload_to_hub
            ],
            outputs=[training_output, status_display]
        )

        # Auto-refresh the history table and status panel every 5 seconds.
        demo.load(
            get_training_history,
            outputs=[history_df],
            every=5
        )
        demo.load(
            lambda: format_status(update_status()),
            outputs=[status_display],
            every=5
        )
    return demo
# Create and launch the interface
if __name__ == "__main__":
    demo = create_interface()
    # share=True requests a public gradio.live tunnel; show_error surfaces
    # handler exceptions in the UI; show_api publishes the auto-generated
    # API docs page.
    demo.launch(
        share=True,
        show_error=True,
        show_api=True
    )