# Hugging Face Spaces app: a Streamlit UI for fine-tuning BERT on an uploaded CSV.
# (The "Spaces: Sleeping" lines were page-scrape residue, not part of the program.)
import io
import os

import pandas as pd
import streamlit as st
import torch
from datasets import Dataset
from sklearn.model_selection import train_test_split
from transformers import (
    BertConfig,
    BertForSequenceClassification,
    BertTokenizer,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)
# --- Page header and input widgets ---
st.title("Train a Model with Your Dataset")

# The user supplies a labelled CSV; only .csv files are accepted.
uploaded_file = st.file_uploader(
    "Upload a CSV file with text and labels", type=["csv"]
)

# Checked (default) -> fine-tune pre-trained BERT weights;
# unchecked -> train the same architecture from random initialization.
use_base_model = st.checkbox("Use Pre-trained Base Model (BERT)", value=True)
# Main workflow: once a CSV is uploaded, validate it, split/tokenize it,
# train the selected model, and offer the trained weights for download.
if uploaded_file is not None:
    # Load the uploaded CSV into a DataFrame.
    df = pd.read_csv(uploaded_file)

    # Preview the data so the user can sanity-check the upload.
    st.write("Uploaded Dataset:")
    st.write(df.head())

    # The app requires exactly these two columns.
    if 'text' not in df.columns or 'label' not in df.columns:
        st.error("The CSV file must contain 'text' and 'label' columns!")
    else:
        # Derive the label count from the data instead of hard-coding 2,
        # so multi-class CSVs work too (binary data behaves exactly as before).
        num_labels = int(df['label'].nunique())

        # 80/20 train/test split, then wrap each half as a HF Dataset.
        train_data, test_data = train_test_split(df, test_size=0.2)
        train_dataset = Dataset.from_pandas(train_data)
        test_dataset = Dataset.from_pandas(test_data)

        # Pre-trained BERT tokenizer (used for both model variants).
        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

        def tokenize_function(examples):
            """Tokenize a batch of examples to fixed-length model inputs."""
            return tokenizer(examples['text'], padding="max_length", truncation=True)

        train_dataset = train_dataset.map(tokenize_function, batched=True)
        test_dataset = test_dataset.map(tokenize_function, batched=True)

        if use_base_model:
            # Fine-tune from pre-trained weights.
            model = BertForSequenceClassification.from_pretrained(
                'bert-base-uncased', num_labels=num_labels
            )
        else:
            # Train from scratch: same architecture, random initialization.
            config = BertConfig(num_labels=num_labels)
            model = BertForSequenceClassification(config)

        training_args = TrainingArguments(
            output_dir='./results',
            num_train_epochs=3,
            per_device_train_batch_size=8,
            per_device_eval_batch_size=8,
            logging_dir='./logs',
            evaluation_strategy='epoch',
            save_strategy='epoch',
            logging_steps=100,
            report_to="none",  # To prevent logging to external services like wandb
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=test_dataset,
        )

        # Streamlit progress widgets, updated from the training callback below.
        progress_bar = st.progress(0)
        progress_text = st.empty()

        class ProgressCallback(TrainerCallback):
            """Update the Streamlit progress bar as training advances.

            FIX: Trainer callbacks must be TrainerCallback instances — the
            original passed a bare function, which the Trainer cannot invoke.
            Step counters live on `state` (global_step / max_steps), not on
            the `args` object the original indexed like a dict.
            """

            def on_step_end(self, args, state, control, **kwargs):
                if state.max_steps:  # guard against division by zero
                    pct = int(state.global_step / state.max_steps * 100)
                    # st.progress accepts an int in [0, 100]; the original
                    # passed a float in 0-100, which Streamlit rejects.
                    progress_bar.progress(min(pct, 100))
                    progress_text.text(f"Training Progress: {pct}%")

        if st.button('Start Training'):
            with st.spinner('Training in progress...'):
                trainer.add_callback(ProgressCallback())
                trainer.train()
                st.success('Training Complete!')

            # Persist the trained weights.
            model_path = "./trained_model"
            model.save_pretrained(model_path)

            # FIX: os.listdir returns bare filenames, so they must be joined
            # to model_path — the original checked os.path.isfile(name)
            # against the CWD and effectively always summed to 0.
            model_size = sum(
                os.path.getsize(os.path.join(model_path, f))
                for f in os.listdir(model_path)
                if os.path.isfile(os.path.join(model_path, f))
            )
            st.write(f"Trained model size: {model_size / (1024 * 1024):.2f} MB")

            # FIX: st.download_button needs bytes/str; a raw state_dict (a
            # dict of tensors) is not downloadable. Serialize it in memory.
            buffer = io.BytesIO()
            torch.save(model.state_dict(), buffer)
            st.download_button(
                label="Download Trained Model",
                data=buffer.getvalue(),
                file_name="trained_model.pth",
                mime="application/octet-stream",
            )