# Streamlit Space page: fine-tune or re-finetune Gemma models.
import json
import os
from datetime import datetime

import pandas as pd
import streamlit as st

from utils import (
    load_model,
    get_hf_token,
    simulate_training,
    plot_training_metrics,
    load_finetuned_model,
    save_model,
)
st.title("π₯ Fine-tune the Gemma Model")

# -------------------------------
# Finetuning Option Selection
# -------------------------------
finetune_option = st.radio("Select Finetuning Option", ["Fine-tune from scratch", "Refinetune existing model"])

# -------------------------------
# Model Selection Logic
# -------------------------------
# Exactly one of these is populated for downstream code:
#   selected_model   - Hugging Face hub id (fine-tune from scratch)
#   saved_model_path - local .pt checkpoint (re-finetune an existing model)
selected_model = None
saved_model_path = None

# Handle the refinetune branch FIRST so that, when no saved checkpoints exist
# and we fall back to "Fine-tune from scratch", the base-model selectbox below
# is still rendered. (The original if/elif ordering skipped the selectbox on
# fallback, leaving selected_model as None with no way to pick a model.)
if finetune_option == "Refinetune existing model":
    # Dynamically list all saved models from the /models folder
    model_dir = "models"
    if os.path.exists(model_dir):
        saved_models = [f for f in os.listdir(model_dir) if f.endswith(".pt")]
    else:
        saved_models = []

    if saved_models:
        chosen = st.selectbox("Select a saved model to re-finetune", saved_models)
        saved_model_path = os.path.join(model_dir, chosen)
        st.success(f"β Selected model for refinement: `{saved_model_path}`")
    else:
        st.warning("β οΈ No saved models found! Switching to fine-tuning from scratch.")
        finetune_option = "Fine-tune from scratch"

if finetune_option == "Fine-tune from scratch":
    # Display Hugging Face model list
    model_list = [
        "google/gemma-3-1b-pt",
        "google/gemma-3-1b-it",
        "google/gemma-3-4b-pt",
        "google/gemma-3-4b-it",
        "google/gemma-3-12b-pt",
        "google/gemma-3-12b-it",
        "google/gemma-3-27b-pt",
        "google/gemma-3-27b-it",
    ]
    selected_model = st.selectbox("π οΈ Select Gemma Model to Fine-tune", model_list)
# -------------------------------
# Dataset Selection
# -------------------------------
st.subheader("π Dataset Selection")

# Dataset source selection
dataset_option = st.radio("Choose dataset:", ["Upload New Dataset", "Use Existing Dataset (`train_data.csv`)"])

# All training data is consolidated into this single CSV file.
dataset_path = "train_data.csv"

if dataset_option == "Upload New Dataset":
    uploaded_file = st.file_uploader("π€ Upload Dataset (CSV or JSON)", type=["csv", "json"])
    if uploaded_file is not None:
        # Handle CSV or JSON upload. Compare the suffix case-insensitively:
        # the uploader's type= filter accepts e.g. "DATA.CSV", which the
        # original case-sensitive endswith() check wrongly rejected.
        file_name = uploaded_file.name.lower()
        if file_name.endswith(".csv"):
            new_data = pd.read_csv(uploaded_file)
        elif file_name.endswith(".json"):
            json_data = json.load(uploaded_file)
            new_data = pd.json_normalize(json_data)
        else:
            st.error("β Unsupported file format. Please upload CSV or JSON.")
            st.stop()

        # Append to the existing dataset, or create it on first upload.
        # NOTE(review): appending with header=False assumes the uploaded file's
        # column order matches the existing CSV — confirm with data producers.
        if os.path.exists(dataset_path):
            new_data.to_csv(dataset_path, mode='a', index=False, header=False)
            st.success(f"β Data appended to `{dataset_path}`!")
        else:
            new_data.to_csv(dataset_path, index=False)
            st.success(f"β Dataset saved as `{dataset_path}`!")

elif dataset_option == "Use Existing Dataset (`train_data.csv`)":
    if os.path.exists(dataset_path):
        st.success("β Using existing `train_data.csv` for fine-tuning.")
    else:
        st.error("β `train_data.csv` not found! Please upload a new dataset.")
        st.stop()
# -------------------------------
# Hyperparameters Configuration
# -------------------------------
# Guard against non-positive inputs: epochs must be >= 1 (epochs == 0 would
# divide by zero in the progress-bar update), batch size must be >= 1, and the
# learning rate must stay strictly positive. Defaults are unchanged.
learning_rate = st.number_input("π Learning Rate", value=1e-4, format="%.5f", min_value=1e-6)
batch_size = st.number_input("π οΈ Batch Size", value=16, step=1, min_value=1)
epochs = st.number_input("β±οΈ Epochs", value=3, step=1, min_value=1)
# -------------------------------
# Fine-tuning Execution
# -------------------------------
if st.button("π Start Fine-tuning"):
    st.info("Fine-tuning process initiated...")

    # Retrieve Hugging Face Token
    hf_token = get_hf_token()

    # Model loading logic
    if finetune_option == "Refinetune existing model" and saved_model_path:
        # Load the base model first, then overlay the saved checkpoint.
        tokenizer, model = load_model("google/gemma-3-1b-it", hf_token)
        model = load_finetuned_model(model, saved_model_path)
        if model:
            st.success(f"β Loaded saved model: `{saved_model_path}` for refinement!")
        else:
            st.error("β Failed to load the saved model. Aborting.")
            st.stop()
    else:
        # Fine-tune from scratch (load base model)
        if not selected_model:
            st.error("β Please select a model to fine-tune.")
            st.stop()
        tokenizer, model = load_model(selected_model, hf_token)
        if model:
            st.success(f"β Base model loaded: `{selected_model}`")
        else:
            st.error("β Failed to load the base model. Aborting.")
            st.stop()

    # Simulated fine-tuning loop with live metric plots.
    progress_bar = st.progress(0)
    training_placeholder = st.empty()
    for epoch, losses, accs in simulate_training(epochs):
        fig = plot_training_metrics(epoch, losses, accs)
        training_placeholder.pyplot(fig)
        # Clamp to 1.0: st.progress raises on values above 1.
        progress_bar.progress(min(epoch / epochs, 1.0))

    # Build the output filename.
    # BUG FIX: on the refinetune path selected_model is None, so the original
    # selected_model.replace('/', '_') raised AttributeError. Fall back to the
    # saved checkpoint's basename when no hub model was selected.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    if selected_model:
        base_name = selected_model.replace('/', '_')
    else:
        base_name = os.path.splitext(os.path.basename(saved_model_path))[0]
    new_model_name = f"models/fine_tuned_model_{base_name}_{timestamp}.pt"

    # Ensure the target directory exists, then save the fine-tuned model.
    os.makedirs("models", exist_ok=True)
    saved_model_path = save_model(model, new_model_name)
    if saved_model_path:
        st.success(f"β Fine-tuning completed! Model saved as `{saved_model_path}`")
        # Load the fine-tuned model for immediate inference
        model = load_finetuned_model(model, saved_model_path)
        if model:
            st.success("π οΈ Fine-tuned model loaded and ready for inference!")
        else:
            st.error("β Failed to load the fine-tuned model for inference.")
    else:
        st.error("β Failed to save the fine-tuned model.")