Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| import torch | |
| import random | |
| from transformers import ( | |
| GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling | |
| ) | |
| from datasets import Dataset | |
| from huggingface_hub import HfApi | |
| import plotly.graph_objects as go | |
| import time | |
| from datetime import datetime | |
| from typing import Dict, List, Any | |
| import pandas as pd # Added pandas import | |
| # Cyberpunk and Loading Animation Styling | |
def setup_cyberpunk_style():
    """Inject the app's green-on-black "cyberpunk" stylesheet into the page.

    The CSS covers the page body, buttons, text/select/number inputs,
    sliders, alerts, the progress bar, a ``.st-loader`` spinner animation and
    the Plotly mode bar / notifier. It is emitted through ``st.markdown`` with
    ``unsafe_allow_html=True`` so the ``<style>`` tag is passed through as HTML.
    """
    cyberpunk_css = """
    <style>
    body {
        background-color: #000;
        color: #00ff00;
        font-family: 'Monaco', monospace;
    }
    .stButton>button {
        color: #00ff00;
        border: 1px solid #00ff00;
        background-color: transparent;
        transition: 0.3s ease-in-out;
    }
    .stButton>button:hover {
        color: #000;
        background-color: #00ff00;
    }
    .stTextInput>div>div>input, .stSelectbox>div>div>div>div, .stTextArea>div>div>textarea {
        background-color: #111;
        color: #00ff00;
        border: 1px solid #00ff00;
    }
    .stSlider>div>div>div>div, .stNumberInput>div>div>div>div>input {
        background-color: #111;
        color: #00ff00;
    }
    .stMarkdown, .stText, .stDataFrame {
        color: #00ff00;
    }
    .stAlert {
        background-color: #111;
        color: #00ff00;
        border: 1px solid #00ff00;
    }
    .stProgress>div>div>div {
        background-color: #00ff00;
    }
    /* Loading animation */
    .st-loader {
        border: 8px solid #111;
        border-top: 8px solid #00ff00;
        border-radius: 50%;
        width: 60px;
        height: 60px;
        animation: spin 1s linear infinite;
    }
    @keyframes spin {
        0% { transform: rotate(0deg); }
        100% { transform: rotate(360deg); }
    }
    /* Plotly chart styling */
    .modebar {
        background-color: #111 !important;
        border: 1px solid #00ff00 !important;
    }
    .modebar-btn {
        color: #00ff00 !important;
        background-color: transparent !important;
    }
    .modebar-btn:hover {
        background-color: #00ff00 !important;
        color: #000 !important;
    }
    .plotly-notifier {
        background-color: #111 !important;
        color: #00ff00 !important;
        border: 1px solid #00ff00 !important;
    }
    .plotly-notifier a {
        color: #00ff00 !important;
    }
    </style>
    """
    st.markdown(cyberpunk_css, unsafe_allow_html=True)
| # Prepare Dataset Function with Padding Token Fix | |
def prepare_dataset(data, tokenizer, block_size=128):
    """Tokenize raw strings into a fixed-length causal-LM training dataset.

    Args:
        data: List of text samples; each string becomes one training example.
        tokenizer: Hugging Face tokenizer used for encoding. As a side effect
            its ``pad_token`` is set to its ``eos_token`` so that
            ``padding='max_length'`` can be used.
        block_size: Number of tokens every example is truncated/padded to.

    Returns:
        A ``datasets.Dataset`` in torch format with ``input_ids``,
        ``attention_mask`` and ``labels`` columns. ``labels`` equals
        ``input_ids`` on real tokens and ``-100`` on padding.
    """
    # Label value skipped by the cross-entropy loss used by the LM head.
    ignore_index = -100
    tokenizer.pad_token = tokenizer.eos_token

    def tokenize_function(examples):
        return tokenizer(examples['text'], truncation=True, max_length=block_size, padding='max_length')

    def add_labels(examples):
        # Copying input_ids verbatim would make the model learn to predict
        # the padding (pad == eos) whenever the Trainer's default collation
        # is used, so padded positions are masked out via the attention mask.
        return {
            'labels': [
                [token if keep else ignore_index for token, keep in zip(ids, mask)]
                for ids, mask in zip(examples['input_ids'], examples['attention_mask'])
            ]
        }

    raw_dataset = Dataset.from_dict({'text': data})
    tokenized_dataset = raw_dataset.map(tokenize_function, batched=True, remove_columns=['text'])
    tokenized_dataset = tokenized_dataset.map(add_labels, batched=True)
    tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])
    return tokenized_dataset
| # Define Model Initialization | |
def initialize_model(model_name="gpt2"):
    """Load a pretrained GPT-2 style checkpoint together with its tokenizer.

    Args:
        model_name: Checkpoint identifier passed to ``from_pretrained``.

    Returns:
        ``(model, tokenizer)``; the tokenizer's pad token is aliased to its
        end-of-sequence token.
    """
    language_model = GPT2LMHeadModel.from_pretrained(model_name)
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    # Reuse EOS as the pad token so padded batches can be built.
    gpt2_tokenizer.pad_token = gpt2_tokenizer.eos_token
    return language_model, gpt2_tokenizer
| # Load Dataset Function with Uploaded File Option | |
def load_dataset(data_source="demo", tokenizer=None, uploaded_file=None):
    """Build the tokenized training dataset from the demo text or an upload.

    Args:
        data_source: ``"demo"`` for the built-in sentences, or
            ``"uploaded file"`` to read ``uploaded_file``.
        tokenizer: Tokenizer forwarded to ``prepare_dataset``.
        uploaded_file: File-like upload exposing ``name``, ``seek`` and
            ``read``. A ``.txt`` file becomes a single sample (so only its
            first ``block_size`` tokens survive truncation in
            ``prepare_dataset``); for ``.csv`` the first column is used, one
            sample per non-empty row.

    Returns:
        The dataset produced by ``prepare_dataset``. When no usable file is
        available a one-line placeholder message is tokenized instead.
    """
    if data_source == "demo":
        data = [
            "In the neon-lit streets of Neo-Tokyo, a lone hacker fights against the oppressive megacorporations.",
            "The rain falls in sheets, washing away the bloodstains from the alleyways.",
            "She plugs into the matrix, seeking answers to questions that have haunted her for years."
        ]
    elif data_source == "uploaded file" and uploaded_file is not None:
        # The upload buffer may already have been consumed by an earlier
        # read (e.g. a previous script rerun); rewind so we never silently
        # train on an empty string.
        uploaded_file.seek(0)
        # Compare case-insensitively so "DATA.TXT" / "rows.CSV" are accepted.
        filename = uploaded_file.name.lower()
        if filename.endswith(".txt"):
            data = [uploaded_file.read().decode("utf-8")]
        elif filename.endswith(".csv"):
            df = pd.read_csv(uploaded_file)
            # Drop empty cells first: astype(str) would otherwise turn them
            # into the literal training text "nan".
            data = df.iloc[:, 0].dropna().astype(str).tolist()
        else:
            data = ["Unsupported file format."]
    else:
        data = ["No file uploaded. Please upload a dataset."]
    return prepare_dataset(data, tokenizer)
| # Train Model Function | |
def train_model(model, train_dataset, tokenizer, epochs=3, batch_size=4, use_ga=False, ga_params=None):
    """Fine-tune ``model`` on ``train_dataset`` and return the Trainer logs.

    Args:
        model: Causal LM to train. It is updated in place: directly in the
            standard path, or by loading the best candidate's weights in the
            genetic-algorithm (GA) path.
        train_dataset: Tokenized dataset handed to ``Trainer``.
        tokenizer: Tokenizer used to build the causal-LM data collator.
        epochs: Number of epochs for the standard path (the GA path samples
            epochs from ``1..ga_params['max_epochs']`` instead).
        batch_size: Per-device batch size for the standard path (the GA path
            samples it from a fixed list).
        use_ga: When true, run the GA hyperparameter search.
        ga_params: Dict with ``population_size``, ``num_generations``,
            ``num_parents``, ``mutation_rate`` and ``max_epochs``. Required
            when ``use_ga`` is true.

    Returns:
        A list of log dicts (``trainer.state.log_history``); for the GA path
        the histories of all candidates are concatenated in training order.

    Raises:
        ValueError: If ``use_ga`` is true but ``ga_params`` is missing.
    """
    # mlm=False selects causal-LM collation; the same collator is used for
    # both paths so their losses are computed the same way.
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    if not use_ga:
        return _train_standard(model, train_dataset, data_collator, epochs, batch_size)
    if not ga_params:
        raise ValueError("ga_params must be provided when use_ga=True")
    return _train_genetic(model, train_dataset, data_collator, ga_params)


def _train_standard(model, train_dataset, data_collator, epochs, batch_size):
    """Run one plain Trainer fit on ``model`` and return its log history."""
    training_args = TrainingArguments(
        output_dir="./results",
        overwrite_output_dir=True,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        save_steps=10_000,
        save_total_limit=2,
        logging_dir="./logs",
        logging_steps=1,
        logging_strategy='steps',
        # The string "none" disables reporting integrations; Python None
        # does not (it falls back to the library default).
        report_to="none",
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
    )
    trainer.train()
    return trainer.state.log_history


def _candidate_fitness(train_result, trainer):
    """Return a candidate's training loss, or ``inf`` when none was recorded."""
    fitness = train_result.metrics.get('train_loss')
    if fitness is not None:
        return fitness
    # Fall back to the most recent logged step loss; the history may be empty
    # or end with a summary entry that has no 'loss' key.
    for entry in reversed(trainer.state.log_history):
        if 'loss' in entry:
            return entry['loss']
    return float('inf')


def _train_genetic(model, train_dataset, data_collator, ga_params):
    """Search learning rate / epochs / batch size with a genetic algorithm.

    Every candidate starts from the weights ``model`` had on entry; the best
    candidate's trained weights are copied back into ``model``.
    """
    param_bounds = {
        'learning_rate': (1e-5, 5e-5),
        'epochs': (1, ga_params['max_epochs']),
        'batch_size': [2, 4, 8, 16]
    }
    population = create_ga_population(ga_params['population_size'], param_bounds)

    # CPU snapshot of the starting weights. Candidates are rebuilt from the
    # model's own class and config rather than a hard-coded 'gpt2'
    # checkpoint, so copying the winner back works for every architecture.
    initial_state = {
        name: tensor.detach().cpu().clone()
        for name, tensor in model.state_dict().items()
    }

    best_fitness = float('inf')
    all_losses = []
    num_generations = ga_params['num_generations']
    for generation in range(num_generations):
        fitnesses = []
        for idx, individual in enumerate(population):
            candidate = type(model)(model.config)
            candidate.load_state_dict(initial_state)
            training_args = TrainingArguments(
                output_dir=f"./results/ga_{generation}_{idx}",
                num_train_epochs=individual['epochs'],
                per_device_train_batch_size=individual['batch_size'],
                learning_rate=individual['learning_rate'],
                logging_steps=1,
                logging_strategy='steps',
                report_to="none",  # "none" (string) disables integrations
            )
            trainer = Trainer(
                model=candidate,
                args=training_args,
                data_collator=data_collator,
                train_dataset=train_dataset,
            )
            train_result = trainer.train()
            fitness = _candidate_fitness(train_result, trainer)
            fitnesses.append(fitness)
            all_losses.extend(trainer.state.log_history)
            if fitness < best_fitness:
                best_fitness = fitness
                model.load_state_dict(candidate.state_dict())
            # Release the candidate before building the next one.
            del trainer, candidate
            torch.cuda.empty_cache()

        # Breeding after the final generation would produce a population
        # that is never evaluated, so skip it.
        if generation + 1 < num_generations:
            parents = select_ga_parents(population, fitnesses, ga_params['num_parents'])
            offspring_size = ga_params['population_size'] - ga_params['num_parents']
            offspring = ga_crossover(parents, offspring_size)
            offspring = ga_mutation(offspring, param_bounds, ga_params['mutation_rate'])
            population = parents + offspring
    return all_losses
| # GA-related functions | |
def create_ga_population(size: int, param_bounds: Dict[str, Any]) -> List[Dict[str, Any]]:
    """Sample ``size`` random hyperparameter sets to seed the genetic search.

    ``param_bounds`` supplies an inclusive ``(low, high)`` pair for
    ``learning_rate`` (float) and ``epochs`` (int), and a list of allowed
    values for ``batch_size``.
    """
    lr_low, lr_high = param_bounds['learning_rate']
    epoch_low, epoch_high = param_bounds['epochs']
    batch_choices = param_bounds['batch_size']
    return [
        {
            'learning_rate': random.uniform(lr_low, lr_high),
            'epochs': random.randint(epoch_low, epoch_high),
            'batch_size': random.choice(batch_choices),
        }
        for _ in range(size)
    ]
def select_ga_parents(population: List[Dict[str, Any]], fitnesses: List[float], num_parents: int) -> List[Dict[str, Any]]:
    """Return the ``num_parents`` individuals with the lowest fitness (loss).

    ``fitnesses[i]`` is the score of ``population[i]``; lower is better. The
    returned list holds the original individual objects, best first.
    """
    ranking = np.argsort(fitnesses)
    best_indices = ranking[:num_parents]
    return [population[index] for index in best_indices]
def ga_crossover(parents: List[Dict[str, Any]], offspring_size: int) -> List[Dict[str, Any]]:
    """Breed ``offspring_size`` children by mixing genes of random parents.

    For each child two parents are drawn independently (the same parent may
    be drawn twice), and every gene is inherited from one of the two at
    random.
    """
    genes = ('learning_rate', 'epochs', 'batch_size')

    def breed() -> Dict[str, Any]:
        first = random.choice(parents)
        second = random.choice(parents)
        return {gene: random.choice([first[gene], second[gene]]) for gene in genes}

    return [breed() for _ in range(offspring_size)]
def ga_mutation(offspring: List[Dict[str, Any]], param_bounds: Dict[str, Any], mutation_rate: float = 0.1) -> List[Dict[str, Any]]:
    """Randomly resample genes of each child, in place.

    Each of the three genes is independently redrawn from ``param_bounds``
    with probability ``mutation_rate``. The same list object that was passed
    in is returned.
    """
    resamplers = (
        ('learning_rate', lambda: random.uniform(*param_bounds['learning_rate'])),
        ('epochs', lambda: random.randint(*param_bounds['epochs'])),
        ('batch_size', lambda: random.choice(param_bounds['batch_size'])),
    )
    for child in offspring:
        for gene, resample in resamplers:
            if random.random() < mutation_rate:
                child[gene] = resample()
    return offspring
| # Main App Logic | |
def main():
    """Render the Streamlit app: configuration sidebar, training, loss plot, inference.

    State kept in ``st.session_state``:
        ``model`` / ``tokenizer`` / ``model_name``: the currently loaded
            checkpoint, reloaded only when the sidebar selection changes.
        ``hf_api``: an ``HfApi`` client built from the entered token.
        ``logs``: the log history returned by the last training run.
    """
    setup_cyberpunk_style()
    st.markdown('<h1 class="main-title">Neural Training Hub</h1>', unsafe_allow_html=True)

    # Sidebar Configuration with Additional Options
    with st.sidebar:
        st.markdown("### Configuration Panel")

        # Hugging Face API Token Input
        hf_token = st.text_input("Enter your Hugging Face Token", type="password")
        if hf_token:
            # HfApi.set_access_token is not part of the current
            # huggingface_hub API; the token is supplied to the client
            # constructor instead and the client is kept for later use.
            st.session_state['hf_api'] = HfApi(token=hf_token)
            st.success("Hugging Face token added successfully!")

        # Training Parameters
        training_epochs = st.slider("Training Epochs", min_value=1, max_value=5, value=3)
        batch_size = st.slider("Batch Size", min_value=2, max_value=8, value=4)
        model_choice = st.selectbox("Model Selection", ("gpt2", "distilgpt2", "gpt2-medium"))

        # Dataset Source Selection
        data_source = st.selectbox("Data Source", ("demo", "uploaded file"))
        uploaded_file = st.file_uploader("Upload a text file", type=["txt", "csv"]) if data_source == "uploaded file" else None

        # NOTE(review): the learning rate, warmup steps and weight decay
        # collected below are shown in the UI but are not forwarded to
        # train_model, whose signature has no parameters for them.
        custom_learning_rate = st.number_input("Learning Rate", min_value=1e-6, max_value=5e-4, value=3e-5, step=1e-6, format="%.6f")

        # Advanced Settings Toggle
        advanced_toggle = st.checkbox("Advanced Training Settings")
        if advanced_toggle:
            warmup_steps = st.slider("Warmup Steps", min_value=0, max_value=500, value=100)
            weight_decay = st.slider("Weight Decay", min_value=0.0, max_value=0.1, step=0.01, value=0.01)
        else:
            warmup_steps = 100
            weight_decay = 0.01

        # Training method selection
        training_method = st.selectbox("Training Method", ("Standard", "Genetic Algorithm"))
        if training_method == "Genetic Algorithm":
            st.markdown("### GA Parameters")
            ga_params = {
                'population_size': st.slider("Population Size", min_value=4, max_value=10, value=6),
                'num_generations': st.slider("Number of Generations", min_value=1, max_value=5, value=3),
                'num_parents': st.slider("Number of Parents", min_value=2, max_value=4, value=2),
                'mutation_rate': st.slider("Mutation Rate", min_value=0.0, max_value=1.0, value=0.1),
                'max_epochs': training_epochs
            }
        else:
            ga_params = None

    # Initialize model and tokenizer; reload only when the selected
    # checkpoint differs from the one cached in the session.
    if 'model' not in st.session_state or st.session_state.get('model_name') != model_choice:
        model, tokenizer = initialize_model(model_name=model_choice)
        st.session_state['model'] = model
        st.session_state['tokenizer'] = tokenizer
        st.session_state['model_name'] = model_choice
    else:
        model = st.session_state['model']
        tokenizer = st.session_state['tokenizer']

    # Go Button to Start Training
    if st.button("Go"):
        if data_source == "uploaded file" and uploaded_file is None:
            # Without this guard the model would be trained on the
            # placeholder sentence load_dataset returns for a missing file.
            st.warning("No file uploaded. Please upload a dataset.")
        else:
            # Tokenize only when training is requested, not on every rerun
            # triggered by a widget interaction.
            train_dataset = load_dataset(data_source, tokenizer, uploaded_file=uploaded_file)

            st.markdown("### Model Training Progress")
            progress_bar = st.progress(0)
            status_text = st.empty()
            status_text.text("Training in progress...")

            # Train the model
            if training_method == "Standard":
                logs = train_model(model, train_dataset, tokenizer, training_epochs, batch_size)
            else:
                logs = train_model(model, train_dataset, tokenizer, training_epochs, batch_size, use_ga=True, ga_params=ga_params)

            # Update progress bar to 100%
            progress_bar.progress(100)
            status_text.text("Training complete!")

            # Store the model and logs in st.session_state
            st.session_state['model'] = model
            st.session_state['logs'] = logs

    # Plot the losses if available
    if 'logs' in st.session_state:
        logs = st.session_state['logs']
        losses = [log['loss'] for log in logs if 'loss' in log]
        steps = list(range(len(losses)))
        if losses:
            fig = go.Figure()
            fig.add_trace(go.Scatter(x=steps, y=losses, mode='lines+markers', name='Training Loss', line=dict(color='#00ff9d')))
            fig.update_layout(
                title="Training Progress",
                xaxis_title="Training Steps",
                yaxis_title="Loss",
                template="plotly_dark",
                plot_bgcolor='rgba(0,0,0,0)',
                paper_bgcolor='rgba(0,0,0,0)',
                font=dict(color='#00ff9d')
            )
            st.plotly_chart(fig, use_container_width=True)
        else:
            st.write("No loss data available to plot.")
    else:
        st.write("Train the model to see the loss plot.")

    # Inference with the current session model
    st.markdown("### Model Inference")
    with st.form("inference_form"):
        user_input = st.text_input("Enter prompt for the model:")
        submitted = st.form_submit_button("Generate")

    if submitted:
        if 'model' in st.session_state:
            model = st.session_state['model']
            tokenizer = st.session_state['tokenizer']
            if not user_input.strip():
                # An empty prompt would tokenize to a zero-length input.
                st.warning("Please enter a prompt.")
            else:
                # Training may have moved the model to another device, so
                # the inputs are moved to wherever the model currently is.
                inputs = tokenizer(user_input, return_tensors="pt").to(model.device)
                # Training leaves dropout enabled; switch it off for sampling.
                model.eval()
                with torch.no_grad():
                    outputs = model.generate(
                        inputs['input_ids'],
                        attention_mask=inputs['attention_mask'],
                        max_length=100,
                        num_return_sequences=1,
                        pad_token_id=tokenizer.eos_token_id,
                    )
                response = tokenizer.decode(outputs[0], skip_special_tokens=True)
                st.write("Model output:", response)
        else:
            st.write("Please train the model first.")


if __name__ == "__main__":
    main()