# easy-green-lab / app.py
# SuperPauly's picture
# Update Gradio app with multiple files
# 9c928ab verified
import gradio as gr
import torch
from huggingface_hub import HfApi, login, create_repo, whoami
from transformers import AutoModel, AutoTokenizer, AdamW, get_scheduler
from datasets import load_dataset
from torch.utils.data import DataLoader
import spaces
import time
import os
from typing import Optional, Tuple, Dict, Any
import pandas as pd
from utils import (
load_embedding_model,
load_huggingface_dataset,
prepare_dataset_for_training,
train_model_on_zero_gpu,
save_model_to_hub
)
from config import APP_CONFIG
# Initialize session state variables
# Module-level mutable dict shared by every Gradio event handler below.
# NOTE(review): module-level state is shared across all concurrent users of
# the Space; simultaneous sessions will overwrite each other — confirm this
# is acceptable for the intended single-user workflow.
session_state = {
    "logged_in": False,          # set True after a successful Hub login
    "hf_token": None,            # token the user supplied at login
    "username": None,            # account name returned by whoami()
    "model": None,               # loaded (and later, trained) model object
    "tokenizer": None,           # tokenizer paired with the model
    "dataset": None,             # dataset loaded from the Hub
    "model_loaded": False,       # gates start_training on a loaded model
    "dataset_loaded": False,     # gates start_training on a loaded dataset
    "training_complete": False,  # flips True once training finishes
    "training_history": []       # per-epoch metric dicts (see get_training_history)
}
def update_status():
    """Summarise the current session as a list of (text, severity) pairs.

    Severity is one of "success", "error", "warning" or "info" and is mapped
    to a colour by format_status().
    """
    items = []
    if session_state["logged_in"]:
        items += [
            ("βœ… Logged In", "success"),
            (f"User: {session_state['username']}", "info"),
        ]
    else:
        items += [("❌ Not Logged In", "error")]
    items += (
        [("βœ… Model Loaded", "success")]
        if session_state["model_loaded"]
        else [("⏳ No Model", "warning")]
    )
    items += (
        [("βœ… Dataset Loaded", "success")]
        if session_state["dataset_loaded"]
        else [("⏳ No Dataset", "warning")]
    )
    if session_state["training_complete"]:
        items += [("πŸŽ‰ Training Complete", "success")]
    return items
def login_to_huggingface(token: str) -> Tuple[str, str]:
    """Authenticate against the Hugging Face Hub and record the identity.

    Args:
        token: User-supplied Hub access token.

    Returns:
        (human-readable message, formatted status HTML).
    """
    try:
        login(token=token, add_to_git_credential=True)
        identity = whoami(token=token)
        session_state.update(
            logged_in=True,
            hf_token=token,
            username=identity["name"],
        )
        return "βœ… Successfully logged in to Hugging Face!", format_status(update_status())
    except Exception as e:
        return f"❌ Login failed: {str(e)}", format_status(update_status())
def format_status(status_items):
    """Render (text, severity) pairs as colour-coded monospace HTML.

    Unknown severities fall back to black; an empty/None list yields a
    plain placeholder string.
    """
    if not status_items:
        return "No status available"
    palette = {
        "success": "green",
        "error": "red",
        "warning": "orange",
        "info": "blue",
    }
    rows = "".join(
        f"<div style='color: {palette.get(kind, 'black')}; margin: 2px 0;'>{text}</div>"
        for text, kind in status_items
    )
    return f"<div style='font-family: monospace;'>{rows}</div>"
def load_model(model_url: str) -> Tuple[str, str]:
    """Load an embedding model and tokenizer from the Hub into the session.

    Args:
        model_url: Hub repository id or URL of the model.

    Returns:
        (human-readable message, formatted status HTML).
    """
    if not session_state["logged_in"]:
        return "❌ Please login first!", format_status(update_status())
    try:
        # BUG FIX: the original wrapped this in `with gr.Blocks() as demo:`,
        # which created a throwaway Blocks context inside an event handler.
        # Model loading has nothing to do with UI construction, so the
        # context manager is removed.
        model, tokenizer = load_embedding_model(model_url)
        session_state["model"] = model
        session_state["tokenizer"] = tokenizer
        session_state["model_loaded"] = True
        model_info = f"Model: {model.__class__.__name__}\n"
        model_info += f"Parameters: {sum(p.numel() for p in model.parameters()):,}\n"
        # Reports CUDA availability, not where the model actually resides.
        model_info += f"Device: {'GPU' if torch.cuda.is_available() else 'CPU'}"
        status = update_status()
        return f"βœ… Model loaded successfully!\n\n{model_info}", format_status(status)
    except Exception as e:
        return f"❌ Failed to load model: {str(e)}", format_status(update_status())
def load_dataset(dataset_name: str, split: str = "train") -> Tuple[str, str]:
    """Fetch a dataset from the Hub into the session and describe it.

    NOTE(review): this handler shadows `datasets.load_dataset` imported at
    the top of the file; within this module the name resolves here. The
    actual fetch is delegated to utils.load_huggingface_dataset.

    Args:
        dataset_name: Hub dataset repository name.
        split: Which split to load (default "train").

    Returns:
        (human-readable message, formatted status HTML).
    """
    if not session_state["logged_in"]:
        return "❌ Please login first!", format_status(update_status())
    try:
        ds = load_huggingface_dataset(dataset_name, split)
        session_state["dataset"] = ds
        session_state["dataset_loaded"] = True
        info = (
            f"Dataset: {dataset_name}\n"
            f"Split: {split}\n"
            f"Size: {len(ds):,} samples\n"
        )
        if hasattr(ds, 'column_names'):
            info += f"Columns: {', '.join(ds.column_names)}"
        return f"βœ… Dataset loaded successfully!\n\n{info}", format_status(update_status())
    except Exception as e:
        return f"❌ Failed to load dataset: {str(e)}", format_status(update_status())
@spaces.GPU(duration=300)  # 5 minutes for training
def start_training(
    epochs: int,
    batch_size: int,
    learning_rate: float,
    warmup_steps: int,
    use_zero_gpu: bool,
    repo_name: str,
    create_repo: bool,
    private_repo: bool,
    upload_to_hub: bool
) -> Tuple[str, str]:
    """Fine-tune the loaded model on the loaded dataset, save, optionally upload.

    Requires a logged-in session with a model and dataset already loaded.
    Saves the trained model locally, and pushes it to the Hub when
    upload_to_hub is set (creating the repo first if create_repo is set).

    Returns:
        (human-readable result message, formatted status HTML).
    """
    # Check prerequisites
    if not session_state["logged_in"]:
        return "❌ Please login first!", format_status(update_status())
    if not session_state["model_loaded"]:
        return "❌ Please load a model first!", format_status(update_status())
    if not session_state["dataset_loaded"]:
        return "❌ Please load a dataset first!", format_status(update_status())
    if upload_to_hub and not repo_name:
        return "❌ Please provide a repository name for upload!", format_status(update_status())
    try:
        # Prepare dataset (tokenization etc. delegated to utils)
        dataset = prepare_dataset_for_training(
            session_state["dataset"],
            session_state["tokenizer"]
        )
        # Start training
        trained_model, training_history = train_model_on_zero_gpu(
            model=session_state["model"],
            tokenizer=session_state["tokenizer"],
            dataset=dataset,
            epochs=epochs,
            batch_size=batch_size,
            learning_rate=learning_rate,
            warmup_steps=warmup_steps,
            use_zero_gpu=use_zero_gpu
        )
        session_state["model"] = trained_model
        session_state["training_complete"] = True
        session_state["training_history"] = training_history
        # Save model locally (timestamped so repeated runs don't collide)
        local_path = f"./trained_model_{int(time.time())}"
        trained_model.save_pretrained(local_path)
        session_state["tokenizer"].save_pretrained(local_path)
        result_msg = f"πŸŽ‰ Training completed successfully!\n\n"
        result_msg += f"Model saved locally to: {local_path}\n"
        result_msg += f"Training epochs: {epochs}\n"
        # Guard against an empty history (e.g. a 0-step run) before indexing.
        if training_history:
            result_msg += f"Final loss: {training_history[-1]['loss']:.4f}\n"
        # Upload to Hub if requested
        if upload_to_hub and repo_name:
            try:
                if create_repo:
                    # BUG FIX: the boolean parameter `create_repo` shadows
                    # huggingface_hub.create_repo, so the original code called
                    # a bool as a function (TypeError). Go through HfApi,
                    # which is imported at module level and not shadowed.
                    HfApi(token=session_state["hf_token"]).create_repo(
                        repo_id=repo_name,
                        private=private_repo,
                        repo_type="model",
                        exist_ok=True
                    )
                save_model_to_hub(
                    model=trained_model,
                    tokenizer=session_state["tokenizer"],
                    repo_id=repo_name,
                    token=session_state["hf_token"],
                    private=private_repo
                )
                result_msg += f"\nβœ… Model uploaded to Hub: https://huggingface.co/{repo_name}"
            except Exception as e:
                # Upload is best-effort: training success is still reported.
                result_msg += f"\n⚠️ Upload to Hub failed: {str(e)}"
        status = update_status()
        return result_msg, format_status(status)
    except Exception as e:
        return f"❌ Training failed: {str(e)}", format_status(update_status())
def get_training_history():
    """Return recorded per-epoch training metrics as a pandas DataFrame.

    Before any training has run, yields an empty frame with the expected
    column headers so the UI table renders sensibly.
    """
    history = session_state["training_history"]
    if history:
        return pd.DataFrame(history)
    return pd.DataFrame(columns=["Epoch", "Loss", "Learning Rate"])
def create_interface():
    """Create the Gradio interface.

    Layout: a left column (scale=1) for login, model loading and dataset
    loading; a right column (scale=2) for training configuration, repository
    settings and results; a training-history table below. Event handlers are
    wired at the end, plus two polling loaders that refresh the history table
    and the status panel every 5 seconds.
    """
    with gr.Blocks(
        title="Embedding Model Trainer",
        theme=gr.themes.Soft(),
        css="""
        .main-header {
            text-align: center;
            font-size: 2.5em;
            font-weight: bold;
            color: #1f77b4;
            margin-bottom: 1em;
        }
        .status-box {
            padding: 10px;
            border-radius: 5px;
            background-color: #f0f0f0;
            font-family: monospace;
            min-height: 100px;
        }
        """
    ) as demo:
        gr.HTML('<div class="main-header">πŸ€– Embedding Model Trainer</div>')
        gr.HTML('<p style="text-align: center;">Built with <a href="https://huggingface.co/spaces/akhaliq/anycoder" target="_blank">anycoder</a></p>')
        with gr.Row():
            with gr.Column(scale=1):
                # Login Section
                gr.Markdown("## πŸ” Hugging Face Login")
                with gr.Group():
                    token_input = gr.Textbox(
                        label="Hugging Face Token",
                        type="password",
                        placeholder="Enter your HF token...",
                        info="Get your token from https://huggingface.co/settings/tokens"
                    )
                    login_btn = gr.Button("πŸ”‘ Login", variant="primary")
                    login_output = gr.Textbox(label="Login Status", interactive=False)
                # Status Display (initial render; refreshed by demo.load below)
                gr.Markdown("## πŸ“Š Status")
                status_display = gr.HTML(format_status(update_status()))
                # Model Loading
                gr.Markdown("## πŸ“₯ Load Model")
                with gr.Group():
                    model_url = gr.Textbox(
                        label="Model URL/Name",
                        placeholder="e.g., sentence-transformers/all-MiniLM-L6-v2",
                        info="Enter Hugging Face model repository URL or name"
                    )
                    load_model_btn = gr.Button("πŸ“₯ Load Model", variant="secondary")
                    model_output = gr.Textbox(label="Model Status", interactive=False, lines=5)
                # Dataset Loading
                gr.Markdown("## πŸ“Š Load Dataset")
                with gr.Group():
                    dataset_name = gr.Textbox(
                        label="Dataset Name",
                        placeholder="e.g., imdb",
                        info="Enter Hugging Face dataset name"
                    )
                    dataset_split = gr.Dropdown(
                        choices=["train", "test", "validation"],
                        value="train",
                        label="Dataset Split"
                    )
                    load_dataset_btn = gr.Button("πŸ“Š Load Dataset", variant="secondary")
                    dataset_output = gr.Textbox(label="Dataset Status", interactive=False, lines=5)
            with gr.Column(scale=2):
                # Training Configuration
                gr.Markdown("## βš™οΈ Training Configuration")
                with gr.Row():
                    with gr.Column():
                        epochs = gr.Number(
                            label="Training Epochs",
                            value=3,
                            minimum=1,
                            maximum=100,
                            step=1
                        )
                        batch_size = gr.Number(
                            label="Batch Size",
                            value=16,
                            minimum=1,
                            maximum=128,
                            step=1
                        )
                    with gr.Column():
                        # NOTE(review): gr.Number documents a `precision`
                        # parameter, not `format` — confirm `format="%.6f"`
                        # is honoured on the Gradio version pinned here.
                        learning_rate = gr.Number(
                            label="Learning Rate",
                            value=2e-5,
                            minimum=1e-6,
                            maximum=1e-1,
                            format="%.6f"
                        )
                        warmup_steps = gr.Number(
                            label="Warmup Steps",
                            value=100,
                            minimum=0,
                            maximum=1000,
                            step=10
                        )
                use_zero_gpu = gr.Checkbox(
                    label="Use Zero GPU",
                    value=True,
                    info="Enable Zero GPU for training (recommended)"
                )
                # Repository Settings
                gr.Markdown("## πŸ“€ Repository Settings")
                with gr.Row():
                    with gr.Column():
                        repo_name = gr.Textbox(
                            label="Repository Name",
                            placeholder="my-fine-tuned-model",
                            info="Name for your model repository"
                        )
                        # NOTE(review): this local name shadows the imported
                        # huggingface_hub.create_repo within this function.
                        create_repo = gr.Checkbox(
                            label="Create New Repository",
                            value=True,
                            info="Create a new repository if it doesn't exist"
                        )
                    with gr.Column():
                        private_repo = gr.Checkbox(
                            label="Private Repository",
                            value=False,
                            info="Make the repository private"
                        )
                        upload_to_hub = gr.Checkbox(
                            label="Upload to Hub",
                            value=True,
                            info="Upload trained model to Hugging Face Hub"
                        )
                # Training Button
                train_btn = gr.Button(
                    "πŸš€ Start Training",
                    variant="primary",
                    size="lg"
                )
                training_output = gr.Textbox(
                    label="Training Results",
                    interactive=False,
                    lines=8
                )
        # Training History
        with gr.Row():
            gr.Markdown("## πŸ“ˆ Training History")
            history_df = gr.Dataframe(
                label="Training Metrics",
                value=get_training_history(),
                interactive=False
            )
        # Event Handlers
        login_btn.click(
            login_to_huggingface,
            inputs=[token_input],
            outputs=[login_output, status_display]
        )
        load_model_btn.click(
            load_model,
            inputs=[model_url],
            outputs=[model_output, status_display]
        )
        load_dataset_btn.click(
            load_dataset,
            inputs=[dataset_name, dataset_split],
            outputs=[dataset_output, status_display]
        )
        # Inputs are passed positionally, matching start_training's signature.
        train_btn.click(
            start_training,
            inputs=[
                epochs,
                batch_size,
                learning_rate,
                warmup_steps,
                use_zero_gpu,
                repo_name,
                create_repo,
                private_repo,
                upload_to_hub
            ],
            outputs=[training_output, status_display]
        )
        # Auto-refresh status and history every 5 seconds
        demo.load(
            get_training_history,
            outputs=[history_df],
            every=5
        )
        demo.load(
            lambda: format_status(update_status()),
            outputs=[status_display],
            every=5
        )
    return demo
# Build and launch the interface when run as a script.
# NOTE: share=True is a no-op inside a hosted Hugging Face Space.
if __name__ == "__main__":
    app = create_interface()
    app.launch(share=True, show_error=True, show_api=True)