# (Page-capture residue from the Hugging Face Spaces UI: "Spaces: Sleeping" — not part of the app code.)
import os
import time

import matplotlib.pyplot as plt
import pandas as pd
import streamlit as st
import torch
from datasets import Dataset, load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, Trainer, TrainingArguments
# --- Streamlit page chrome ---
st.set_page_config(page_title="AutoTrain AI", page_icon="π", layout="wide")
st.title("AutoTrain AI π")
st.subheader("Train AI models using PyTorch & Hugging Face Transformers")

# --- Sidebar: every run-time choice the training function reads ---
st.sidebar.header("Configuration")

hf_user = st.sidebar.selectbox("Hugging Face User", ["hennings1984"])
task = st.sidebar.selectbox("Select Task", ["Text Classification", "Sentiment Analysis"])
hardware = st.sidebar.selectbox("Hardware", ["CPU", "Single GPU", "Multi-GPU", "TPU"])
model_choice = st.sidebar.selectbox(
    "Choose Model",
    ["bert-base-uncased", "distilbert-base-uncased", "roberta-base", "None (Custom Model)"],
)
dataset_source = st.sidebar.selectbox("Dataset Source", ["glue/sst2", "imdb", "ag_news", "Custom"])
# --- Custom dataset upload (only shown for the "Custom" source) ---
custom_dataset = None
if dataset_source == "Custom":
    file = st.sidebar.file_uploader("Upload Custom Dataset", type=["csv", "json"])
    if file is not None:
        # Pick the pandas reader from the uploaded file's extension.
        if file.name.endswith(".csv"):
            custom_dataset = pd.read_csv(file)
        else:
            custom_dataset = pd.read_json(file)
        st.sidebar.write(f"Dataset uploaded with {len(custom_dataset)} rows")

# --- Training hyper-parameters ---
epochs = st.sidebar.slider("Number of Epochs", 1, 10, 3)
batch_size = st.sidebar.selectbox("Batch Size", [8, 16, 32, 64], index=1)
learning_rate = st.sidebar.slider("Learning Rate", 1e-6, 1e-3, 2e-5, format="%.6f")
# --- Resolve the torch device from the user's hardware selection ---
# Defaults to CPU and only upgrades when the *requested* accelerator is
# actually present. (The original checked COLAB_TPU_ADDR in an elif that
# ran regardless of the hardware choice, so a CPU/GPU selection could be
# silently overridden with a TPU device.)
device = "cpu"
if hardware in ("Single GPU", "Multi-GPU") and torch.cuda.is_available():
    device = "cuda"
elif hardware == "TPU":
    if os.environ.get("COLAB_TPU_ADDR"):  # Set by Google Colab when a TPU is attached.
        try:
            import torch_xla
            import torch_xla.core.xla_model as xm
            device = xm.xla_device()  # NOTE: a torch.device object, not a str.
        except ImportError:
            st.error("TPU support is available only with 'torch_xla'. Please install it.")
    else:
        st.error("TPU is not available in this environment. Please use GPU or CPU.")

# str() is required: on the TPU path `device` is a torch.device, which
# has no .upper() method (the original crashed here).
st.sidebar.write(f"**Using Device:** {str(device).upper()}")
# --- Checkpoint resume option ---
resume_training = st.sidebar.checkbox("Resume Training from Checkpoint")
if resume_training:
    checkpoint_path = "checkpoint.pth"
else:
    checkpoint_path = None

# --- Output file paths ---
log_file = "train_log.txt"
metrics_file = "metrics.csv"

# --- Training controls ---
st.write("### Model Training Control")
start_train = st.button("Start Training π")
stop_train = st.button("Stop Training β")

# --- Live log / metrics placeholders, refreshed during training ---
st.write("### Training Logs (Live Updates)")
log_area = st.empty()

st.write("### Training Metrics π")
# Training Function
def train_model():
    """Fine-tune the selected model on the selected dataset, streaming
    logs, metrics, checkpoints, and a progress bar into the Streamlit UI.

    Reads all configuration from the module-level sidebar widgets
    (model_choice, dataset_source, custom_dataset, epochs, batch_size,
    learning_rate, resume_training, checkpoint_path, device, ...).
    Writes: per-epoch ``checkpoint_epoch_N.pth`` files, ``log_file``,
    and ``metrics_file``. Returns None.
    """
    # str() because `device` may be a torch.device (TPU path), not a str.
    st.success(f"Training started for {task} with {model_choice} on {str(device).upper()}")

    # Guard clause: custom model support is not implemented.
    if model_choice == "None (Custom Model)":
        st.error("Custom model support not yet implemented. Please use a base model.")
        return

    tokenizer = AutoTokenizer.from_pretrained(model_choice)
    model = AutoModelForSequenceClassification.from_pretrained(model_choice, num_labels=2)

    # Load dataset
    if dataset_source != "Custom":
        # "glue/sst2" is a <path>/<config> pair: load_dataset takes them as
        # separate arguments ("glue", "sst2") — a single path string fails.
        path, _, config = dataset_source.partition("/")
        dataset = load_dataset(path, config or None)
    else:
        if custom_dataset is None:
            # Button can be pressed before any file is uploaded.
            st.error("Please upload a custom dataset before starting training.")
            return
        # Dataset.from_pandas yields a single unsplit Dataset; split it so
        # the "train"/"test" indexing below works.
        dataset = Dataset.from_pandas(custom_dataset).train_test_split(test_size=0.2)

    # Tokenization function.
    # NOTE(review): assumes the text column is named "text"; glue/sst2 uses
    # "sentence" — confirm against the chosen dataset's schema.
    def tokenize_function(examples):
        return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=256)

    tokenized_datasets = dataset.map(tokenize_function, batched=True)
    train_dataset = tokenized_datasets["train"]
    eval_dataset = tokenized_datasets.get("validation", tokenized_datasets["test"])

    # Resume from checkpoint if requested. map_location lets a GPU-saved
    # checkpoint load on a CPU-only machine; weights move with model.to().
    if resume_training and os.path.exists(checkpoint_path):
        model.load_state_dict(torch.load(checkpoint_path, map_location="cpu"))

    # Move model to the selected device.
    model.to(device)

    # One epoch per Trainer.train() call: the loop below already iterates
    # `epochs` times. (The original also passed num_train_epochs=epochs
    # here, training epochs**2 epochs in total.)
    training_args = TrainingArguments(
        output_dir="./results",
        evaluation_strategy="epoch",
        logging_dir="./logs",
        logging_steps=5,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,
        num_train_epochs=1,
        save_strategy="epoch",
        learning_rate=learning_rate,
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
    )

    # Progress bar for training
    progress_bar = st.progress(0)

    # Per-epoch training loop so the UI can update between epochs.
    metrics = []
    with open(log_file, "w") as log_handle:
        log_handle.write("Starting training...\n")
        log_handle.flush()

        for epoch in range(epochs):
            trainer.train()
            results = trainer.evaluate()

            # Save a per-epoch checkpoint of the raw model weights.
            torch.save(model.state_dict(), f"checkpoint_epoch_{epoch+1}.pth")

            # Log results (flush so the file tails correctly mid-run).
            log_text = f"Epoch {epoch+1}: Loss = {results['eval_loss']:.4f}, Accuracy = {results.get('eval_accuracy', 0):.4f}\n"
            log_handle.write(log_text)
            log_handle.flush()

            # Persist metrics every epoch so they survive a crash/restart.
            metrics.append({"epoch": epoch + 1, "loss": results["eval_loss"], "accuracy": results.get("eval_accuracy", 0)})
            pd.DataFrame(metrics).to_csv(metrics_file, index=False)

            # Update logs & metrics in UI
            log_area.text(log_text)
            st.line_chart(pd.DataFrame(metrics).set_index("epoch"))

            # Update progress bar
            progress_bar.progress((epoch + 1) / epochs)
            time.sleep(2)

    # Display final results
    st.write("### Final Results π")
    final_metrics = pd.DataFrame(metrics)
    st.line_chart(final_metrics.set_index("epoch"))
    st.write(final_metrics)
# Dispatch on whichever control button was pressed during this rerun.
if start_train:
    train_model()

if stop_train:
    st.warning("Training stopped manually.")