"""Gradio app for fine-tuning and serving a BERT complaint classifier."""
import json
import logging
import math
import os
import threading
from datetime import datetime
from pathlib import Path

import gradio as gr
import numpy as np
import pandas as pd
import torch
from datasets import Dataset, DatasetDict
from huggingface_hub import login
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from transformers import (
    AutoTokenizer,
    BertForSequenceClassification,
    DataCollatorWithPadding,
    EarlyStoppingCallback,
    Trainer,
    TrainerCallback,
    TrainingArguments,
)

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Global configuration and state
# ---------------------------------------------------------------------------
MODEL_PATH = "local-model"  # directory the fine-tuned model is written to
BASE_MODEL_NAME = "bert-base-uncased"
MAX_LENGTH = 512  # BERT's maximum sequence length in tokens
CSV_PREVIEW_ROWS = 20  # predictions shown inline for a CSV upload
CATEGORIES = ['Online-Safety', 'BroadBand', 'TV-Radio']
idx_to_category = dict(enumerate(CATEGORIES))
category_to_idx = {category: idx for idx, category in enumerate(CATEGORIES)}
TOKEN = None
TRAINING_LOGS = []
CURRENT_MODEL = None
CURRENT_TOKENIZER = None
# Model path / Hub id that CURRENT_MODEL was loaded for; predictions only
# reload when the requested model differs from this.
CURRENT_MODEL_SOURCE = None
# Non-empty while CURRENT_MODEL is the untrained base-model fallback, so the
# warning can be surfaced alongside every prediction made with it.
CURRENT_MODEL_WARNING = ""
_NO_CSV_MESSAGE = "No CSV files found in current directory"

TIPS_MARKDOWN = """
**Dataset Format:**
- CSV file with at least two columns
- One column for text (complaints)
- One column for labels/categories
- Labels can be text (will be auto-mapped) or numeric indices (0, 1, 2)

**Training Tips:**
- Start with 3 epochs and adjust based on results
- Use batch size 8-16 for most datasets
- Learning rate 2e-5 works well for BERT fine-tuning
- Early stopping (patience: 3 evaluations) is applied automatically to limit overfitting

**Token Limits:**
- BERT has a 512 token limit
- Long texts will be automatically truncated
- Monitor the token counter when entering text
"""


# ---------------------------------------------------------------------------
# Dataset handling
# ---------------------------------------------------------------------------
def _build_label_mapping(unique_labels):
    """Map text labels to category indices without ever reusing an index.

    Labels that exactly match an entry of ``CATEGORIES`` keep that category's
    index. Remaining labels take the still-free indices in sorted label order,
    so the result does not depend on row order.

    Raises:
        ValueError: If there are more distinct labels than categories.
    """
    if len(unique_labels) > len(CATEGORIES):
        raise ValueError(
            f"Too many unique labels ({len(unique_labels)}). Expected max {len(CATEGORIES)}"
        )

    mapping = {label: category_to_idx[label] for label in unique_labels if label in category_to_idx}
    # Known labels are assigned first so an unknown label can never claim the
    # index that belongs to a known category appearing later in the data.
    free_indices = iter(sorted(set(range(len(CATEGORIES))) - set(mapping.values())))
    unknown_labels = sorted((label for label in unique_labels if label not in mapping), key=str)
    for label in unknown_labels:
        index = next(free_indices, None)
        if index is None:
            raise ValueError(f"Cannot map label '{label}' to available categories")
        mapping[label] = index

    logger.info("Label mapping: %s", mapping)
    return mapping


def load_and_prepare_local_dataset(file_path, text_column, label_column, test_size=0.2):
    """Load a local CSV file and split it into train/validation datasets.

    Args:
        file_path: Path to the CSV file.
        text_column: Column holding the input text.
        label_column: Column holding the labels, either text (mapped via
            ``_build_label_mapping``) or numeric category indices.
        test_size: Fraction of rows held out for validation.

    Returns:
        ``(dataset_dict, text_column, "label_idx")`` where ``dataset_dict`` is a
        ``DatasetDict`` with ``"train"`` and ``"validation"`` splits.

    Raises:
        Exception: Wrapping any loading or validation error; the original error
            is chained as ``__cause__``.
    """
    try:
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Dataset file not found: {file_path}")

        df = pd.read_csv(file_path)

        # Verify required columns exist
        for role, column in (("Text", text_column), ("Label", label_column)):
            if column not in df.columns:
                raise ValueError(
                    f"{role} column '{column}' not found. Available columns: {list(df.columns)}"
                )

        # Clean the data
        df = df.dropna(subset=[text_column, label_column])
        if df.empty:
            raise ValueError("No rows contain both a text and a label value")
        df[text_column] = df[text_column].astype(str)

        # Numeric labels are used as indices directly; anything else (object or
        # pandas string dtype) is treated as text labels.
        if pd.api.types.is_numeric_dtype(df[label_column]):
            df['label_idx'] = df[label_column].astype(int)
        else:
            label_mapping = _build_label_mapping(df[label_column].unique())
            df['label_idx'] = df[label_column].map(label_mapping)

        if df['label_idx'].min() < 0 or df['label_idx'].max() >= len(CATEGORIES):
            raise ValueError(f"Label indices must be between 0 and {len(CATEGORIES) - 1}")

        train_df, val_df = train_test_split(
            df,
            test_size=test_size,
            random_state=42,
            stratify=df['label_idx'],
        )

        # preserve_index=False keeps the pandas index out of the dataset columns.
        columns = [text_column, 'label_idx']
        dataset_dict = DatasetDict({
            'train': Dataset.from_pandas(train_df[columns], preserve_index=False),
            'validation': Dataset.from_pandas(val_df[columns], preserve_index=False),
        })
        return dataset_dict, text_column, 'label_idx'

    except Exception as e:
        raise Exception(f"Error loading dataset: {str(e)}") from e


def preview_dataset(uploaded_file, text_column, label_column):
    """Return a Markdown summary of an uploaded CSV: size, columns, sample, labels."""
    try:
        if uploaded_file is None:
            return "Please upload a dataset file first."

        # Gradio may pass a file object (with .name) or a plain path string.
        file_path = getattr(uploaded_file, 'name', uploaded_file)

        if not os.path.exists(file_path):
            return "❌ File not found. Please try uploading again."
        if not file_path.lower().endswith('.csv'):
            return "❌ Please upload a CSV file (.csv extension required)."

        df = pd.read_csv(file_path)

        preview_info = [
            f"📊 **Dataset Preview: {os.path.basename(file_path)}**",
            f"- **Total rows:** {len(df)}",
            f"- **Columns:** {list(df.columns)}",
            "",
        ]

        if text_column not in df.columns:
            preview_info.append(f"❌ **Text column '{text_column}' not found**")
            return "\n".join(preview_info)
        preview_info.append(f"✅ **Text column '{text_column}' found**")
        if not df.empty:
            preview_info.append(f"- Sample text: {str(df[text_column].iloc[0])[:100]}...")

        if label_column not in df.columns:
            preview_info.append(f"❌ **Label column '{label_column}' not found**")
            return "\n".join(preview_info)
        preview_info.append(f"✅ **Label column '{label_column}' found**")
        preview_info.append("- **Label distribution:**")
        for label, count in df[label_column].value_counts().items():
            preview_info.append(f"  - {label}: {count} ({count / len(df) * 100:.1f}%)")

        return "\n".join(preview_info)

    except Exception as e:
        return f"❌ Error previewing dataset: {str(e)}"


# ---------------------------------------------------------------------------
# Hugging Face Hub helpers
# ---------------------------------------------------------------------------
def login_to_hf(token):
    """Log in to the Hugging Face Hub; remember the token only on success."""
    global TOKEN
    if not token:
        return "❌ Login failed: no token provided"
    try:
        login(token=token)
    except Exception as e:
        return f"❌ Login failed: {str(e)}"
    TOKEN = token
    return "✅ Successfully logged in to Hugging Face"


def validate_hub_model_id(username, model_name):
    """Build a ``username/model-name`` Hub repo id.

    The model name is lower-cased, spaces become ``-``, and any character other
    than ASCII letters/digits, ``-`` or ``_`` is dropped.

    Returns:
        ``(hub_model_id, None)`` on success, ``(None, error_message)`` otherwise.
    """
    username = (username or "").strip()
    model_name = (model_name or "").strip()
    if not username or not model_name:
        return None, "Please provide both username and model name"

    cleaned = "".join(
        c for c in model_name.lower().replace(" ", "-")
        if (c.isascii() and c.isalnum()) or c in "-_"
    )
    if not cleaned:
        return None, "Model name must contain at least one letter, digit, '-' or '_'"

    return f"{username}/{cleaned}", None


# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
def _load_classifier(source):
    """Load a tokenizer and a ``len(CATEGORIES)``-way BERT classifier from *source*.

    *source* is a local directory or a Hub id. Category names are written into
    the model config so saved/pushed models are self-describing.
    """
    tokenizer = AutoTokenizer.from_pretrained(source)
    model = BertForSequenceClassification.from_pretrained(
        source,
        num_labels=len(CATEGORIES),
        id2label=dict(idx_to_category),
        label2id=dict(category_to_idx),
    )
    return tokenizer, model


def load_model(model_path):
    """Load a model into the global slots and return a status message.

    Tries *model_path* as a local directory, then as a Hub id. If the Hub load
    fails, falls back to the untrained base BERT model and returns a message
    starting with "⚠️". Any other failure returns a message starting with "❌"
    and leaves the previously loaded model untouched.
    """
    global CURRENT_MODEL, CURRENT_TOKENIZER, CURRENT_MODEL_SOURCE, CURRENT_MODEL_WARNING
    try:
        if os.path.exists(model_path):
            tokenizer, model = _load_classifier(model_path)
            message = f"✅ Model loaded from local path: {model_path}"
        else:
            try:
                tokenizer, model = _load_classifier(model_path)
                message = f"✅ Model loaded from Hugging Face Hub: {model_path}"
            except Exception as hub_error:
                tokenizer, model = _load_classifier(BASE_MODEL_NAME)
                message = (
                    "⚠️ Failed to load specified model, using base BERT model instead. "
                    f"Error: {str(hub_error)}"
                )
    except Exception as e:
        return f"❌ Failed to load model: {str(e)}"

    # Swap tokenizer and model together so they never come from different sources.
    CURRENT_TOKENIZER, CURRENT_MODEL = tokenizer, model
    CURRENT_MODEL_SOURCE = model_path
    CURRENT_MODEL_WARNING = message if message.startswith("⚠️") else ""
    return message


def _ensure_model_loaded(model_path):
    """Make sure the model for *model_path* is in memory.

    Returns:
        ``(ready, message)``. When ``ready`` is False, ``message`` is the
        load error. Otherwise ``message`` is "" or the base-model fallback
        warning that should be shown with predictions.
    """
    if CURRENT_MODEL is None or CURRENT_MODEL_SOURCE != model_path:
        result = load_model(model_path)
        if result.startswith("❌"):
            return False, result
    return True, CURRENT_MODEL_WARNING


# ---------------------------------------------------------------------------
# Training
# ---------------------------------------------------------------------------
def tokenize_function(examples, tokenizer, feature_column, max_length=MAX_LENGTH):
    """Tokenize ``examples[feature_column]`` with truncation and no padding.

    Padding is left to ``DataCollatorWithPadding`` at batch time.
    """
    return tokenizer(
        examples[feature_column],
        truncation=True,
        padding=False,
        max_length=max_length,
    )


def compute_metrics(eval_pred):
    """Compute accuracy plus macro/weighted F1 and macro precision/recall."""
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)

    accuracy = accuracy_score(labels, predictions)
    report = classification_report(labels, predictions, output_dict=True, zero_division=0)

    return {
        'accuracy': accuracy,
        'f1_macro': report['macro avg']['f1-score'],
        'f1_weighted': report['weighted avg']['f1-score'],
        'precision_macro': report['macro avg']['precision'],
        'recall_macro': report['macro avg']['recall'],
    }


def _compute_step_schedule(num_train_samples, batch_size, num_epochs):
    """Derive eval/save/logging/warmup step counts from the dataset size.

    ``save_steps`` is always a whole multiple of ``eval_steps``, which
    ``load_best_model_at_end=True`` requires. The step total assumes a single
    device and no gradient accumulation.
    """
    batch_size = max(1, int(batch_size))
    # The Trainer's dataloader keeps the last partial batch, hence ceil.
    total_steps = math.ceil(num_train_samples / batch_size) * int(num_epochs)
    eval_steps = max(10, min(100, total_steps // 4))
    target_save_steps = max(20, min(500, total_steps // 2))
    save_steps = eval_steps * max(1, target_save_steps // eval_steps)
    return {
        "total_steps": total_steps,
        "eval_steps": eval_steps,
        "save_steps": save_steps,
        "logging_steps": max(5, min(50, total_steps // 10)),
        "warmup_steps": min(500, total_steps // 10),
    }


class ProgressCallback(TrainerCallback):
    """Append human-readable training progress messages to a shared list.

    The list is read by the UI generator while training runs in a worker
    thread, which is how progress is streamed to Gradio.
    """

    def __init__(self, logs_list, total_steps):
        self.logs = logs_list
        self.total_steps = total_steps

    def on_train_begin(self, args, state, control, **kwargs):
        self.logs.append("🚀 Starting training...")

    def on_step_end(self, args, state, control, **kwargs):
        if args.logging_steps and state.global_step % args.logging_steps == 0:
            # Prefer the Trainer's own step total; fall back to our estimate.
            total = state.max_steps or self.total_steps
            self.logs.append(f"Step {state.global_step}/{total}")

    def on_epoch_end(self, args, state, control, **kwargs):
        # state.epoch is a float that can land just below the integer boundary.
        self.logs.append(f"✅ Epoch {round(state.epoch)} completed")

    def on_evaluate(self, args, state, control, metrics=None, **kwargs):
        # The Trainer passes evaluation results as ``metrics`` (not ``logs``).
        if metrics:
            acc = metrics.get('eval_accuracy', 0)
            loss = metrics.get('eval_loss', 0)
            self.logs.append(f"📊 Eval - Accuracy: {acc:.4f}, Loss: {loss:.4f}")


def train_model_inline(uploaded_file, text_column, label_column, num_epochs,
                       batch_size, learning_rate, hf_token, push_to_hub,
                       username, model_name):
    """Fine-tune BERT on an uploaded CSV, streaming progress to the UI.

    This is a generator. After each stage it yields the full log text so far.
    ``trainer.train()`` runs in a worker thread so that callback messages can be
    streamed while training is in progress. On success the trained model
    becomes the current prediction model and is saved to ``MODEL_PATH``. It is
    optionally pushed to the Hub.
    """
    global TRAINING_LOGS, CURRENT_MODEL, CURRENT_TOKENIZER, CURRENT_MODEL_SOURCE, CURRENT_MODEL_WARNING
    TRAINING_LOGS = []
    logs = TRAINING_LOGS

    def emit(*messages):
        """Append messages to the log and return the full log text."""
        logs.extend(messages)
        return "\n".join(logs)

    if hf_token:
        yield emit(login_to_hf(hf_token))

    hub_model_id = None
    if push_to_hub:
        hub_model_id, error = validate_hub_model_id(username, model_name)
        if error:
            yield emit(f"❌ {error}")
            return

    if uploaded_file is None:
        yield emit("❌ Please upload a dataset file")
        return

    dataset_file = getattr(uploaded_file, 'name', uploaded_file)

    try:
        # Gradio sliders may deliver floats; the Trainer needs ints here.
        num_epochs = int(num_epochs)
        batch_size = int(batch_size)
        learning_rate = float(learning_rate)

        # Load and prepare dataset
        yield emit("📊 Loading dataset from uploaded file...")
        dataset_dict, final_text_col, final_label_col = load_and_prepare_local_dataset(
            dataset_file, text_column, label_column
        )
        yield emit(
            "✅ Dataset loaded successfully!",
            f"- Train samples: {len(dataset_dict['train'])}",
            f"- Validation samples: {len(dataset_dict['validation'])}",
        )

        # Load model and tokenizer
        yield emit("🤖 Loading BERT model and tokenizer...")
        tokenizer, model = _load_classifier(BASE_MODEL_NAME)
        yield emit("✅ Model and tokenizer loaded")

        # Tokenize datasets
        yield emit("🔤 Tokenizing datasets...")

        def tokenize_batch(examples):
            return tokenize_function(examples, tokenizer, final_text_col, MAX_LENGTH)

        # Keep only the label column alongside the new tokenizer outputs.
        columns_to_remove = [
            col for col in dataset_dict['train'].column_names if col != final_label_col
        ]
        tokenized_datasets = dataset_dict.map(
            tokenize_batch, batched=True, remove_columns=columns_to_remove
        )
        # The Trainer expects the target column to be called 'labels'.
        tokenized_datasets = tokenized_datasets.rename_column(final_label_col, 'labels')
        yield emit("✅ Tokenization completed")

        # Set up training
        output_dir = Path(MODEL_PATH)
        output_dir.mkdir(parents=True, exist_ok=True)

        schedule = _compute_step_schedule(len(tokenized_datasets['train']), batch_size, num_epochs)
        yield emit(
            "📈 Training configuration:",
            f"- Total steps: {schedule['total_steps']}",
            f"- Eval steps: {schedule['eval_steps']}",
            f"- Save steps: {schedule['save_steps']}",
            f"- Warmup steps: {schedule['warmup_steps']}",
        )

        training_args = TrainingArguments(
            output_dir=str(output_dir),
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            warmup_steps=schedule["warmup_steps"],
            weight_decay=0.01,
            learning_rate=learning_rate,
            logging_dir=str(output_dir / "logs"),
            logging_steps=schedule["logging_steps"],
            eval_strategy="steps",
            eval_steps=schedule["eval_steps"],
            save_steps=schedule["save_steps"],
            save_total_limit=2,
            load_best_model_at_end=True,
            metric_for_best_model="eval_accuracy",
            greater_is_better=True,
            push_to_hub=bool(push_to_hub),
            hub_model_id=hub_model_id,
            # "none" disables logging integrations; None would mean "all
            # installed integrations" (wandb, tensorboard, ...).
            report_to="none",
            dataloader_num_workers=0,
            fp16=torch.cuda.is_available(),
            seed=42,
            remove_unused_columns=False,
        )

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_datasets['train'],
            eval_dataset=tokenized_datasets['validation'],
            tokenizer=tokenizer,
            data_collator=DataCollatorWithPadding(tokenizer=tokenizer),
            compute_metrics=compute_metrics,
            callbacks=[
                EarlyStoppingCallback(early_stopping_patience=3),
                ProgressCallback(logs, schedule["total_steps"]),
            ],
        )

        # Run training in a worker thread so callback messages can be streamed
        # to the UI while it runs. NOTE: cancelling the Gradio event stops the
        # streaming, but not the worker thread itself.
        outcome = {}

        def run_training():
            try:
                trainer.train()
            except Exception as exc:  # reported back to the generator below
                outcome["error"] = exc

        worker = threading.Thread(target=run_training, name="bert-training", daemon=True)
        worker.start()
        shown = len(logs)
        while worker.is_alive():
            worker.join(timeout=1.0)
            if len(logs) != shown:
                shown = len(logs)
                yield "\n".join(logs)

        if "error" in outcome:
            yield emit(f"❌ Training failed: {str(outcome['error'])}")
            return
        yield emit("✅ Training completed successfully!")

        # Save the model
        yield emit("💾 Saving model...")
        trainer.save_model()
        tokenizer.save_pretrained(output_dir)

        # The trained model becomes the one used for predictions on MODEL_PATH.
        CURRENT_MODEL = trainer.model
        CURRENT_TOKENIZER = tokenizer
        CURRENT_MODEL_SOURCE = MODEL_PATH
        CURRENT_MODEL_WARNING = ""
        yield emit("✅ Model saved successfully!")

        # Final evaluation
        yield emit("📊 Running final evaluation...")
        try:
            eval_results = trainer.evaluate()
            logs.append("📊 Final Results:")
            for key, value in eval_results.items():
                if isinstance(value, float):
                    logs.append(f"  {key}: {value:.4f}")
                else:
                    logs.append(f"  {key}: {value}")
            with open(output_dir / "eval_results.json", "w", encoding="utf-8") as f:
                json.dump(eval_results, f, indent=2)
        except Exception as e:
            logs.append(f"⚠️ Evaluation error: {str(e)}")
        yield "\n".join(logs)

        # Push to hub if requested
        if push_to_hub and hub_model_id:
            yield emit(f"🤗 Pushing to Hugging Face Hub: {hub_model_id}")
            try:
                trainer.push_to_hub()
                logs.append(f"✅ Successfully pushed to {hub_model_id}")
            except Exception as e:
                logs.append(f"❌ Push to Hub failed: {str(e)}")
            yield "\n".join(logs)

        yield emit("\n✨ Training completed! Your model is ready to use.")

    except Exception as e:
        yield emit(f"❌ Error during training: {str(e)}")


# ---------------------------------------------------------------------------
# Prediction
# ---------------------------------------------------------------------------
def _classify_text(text):
    """Classify *text* with the current model.

    Returns:
        ``(probabilities, was_truncated)``, where ``probabilities`` is a 1-D CPU
        tensor with one entry per category and ``was_truncated`` tells whether
        the text exceeded ``MAX_LENGTH`` tokens.
    """
    was_truncated = len(CURRENT_TOKENIZER(text, truncation=False)['input_ids']) > MAX_LENGTH
    inputs = CURRENT_TOKENIZER(text, return_tensors="pt", truncation=True, max_length=MAX_LENGTH)
    # After GPU training the model lives on CUDA; inputs must follow it.
    inputs = {name: tensor.to(CURRENT_MODEL.device) for name, tensor in inputs.items()}
    CURRENT_MODEL.eval()  # disable dropout for deterministic predictions
    with torch.no_grad():
        logits = CURRENT_MODEL(**inputs).logits
    return torch.softmax(logits, dim=-1)[0].cpu(), was_truncated


def predict_text(text, model_path):
    """Classify one complaint and return a Markdown report of the prediction."""
    ready, load_message = _ensure_model_loaded(model_path)
    if not ready:
        return load_message

    try:
        if not text or not text.strip():
            return "Please enter some text to classify."

        probabilities, was_truncated = _classify_text(text)
        predicted_category = idx_to_category[int(probabilities.argmax())]
        confidence = float(probabilities.max())

        result = [load_message, ""] if load_message else []
        result.append(f"**Complaint:** {text}")
        result.append(f"\n**Predicted Category:** {predicted_category}")
        result.append(f"**Confidence:** {confidence:.4f}")
        result.append("\n**All Class Probabilities:**")
        result.extend(
            f"- {category}: {probabilities[i].item():.4f}" for i, category in enumerate(CATEGORIES)
        )
        result.append(
            "\n\n⚠️ Note: This complaint was truncated to fit BERT's 512 token limit."
            if was_truncated else ""
        )
        return "\n".join(result)

    except Exception as e:
        return f"❌ Prediction error: {str(e)}"


def predict_csv(csv_file, model_path):
    """Classify every row of a CSV that has a ``complaint`` column.

    Every row is classified and written to ``prediction_results.csv``. Only the
    first ``CSV_PREVIEW_ROWS`` rows are echoed in the returned Markdown.

    Returns:
        ``(markdown_summary, results_file_path)``, or ``(error_message, None)``.
    """
    ready, load_message = _ensure_model_loaded(model_path)
    if not ready:
        return load_message, None

    try:
        if csv_file is None:
            return "❌ Please upload a CSV file", None

        df = pd.read_csv(getattr(csv_file, 'name', csv_file))
        if 'complaint' not in df.columns:
            return "❌ CSV file must have a 'complaint' column", None

        results = [load_message, ""] if load_message else []
        predictions_list = []
        truncated_count = 0

        for i, complaint in enumerate(df['complaint'].astype(str)):
            probabilities, was_truncated = _classify_text(complaint)
            truncated_count += int(was_truncated)
            predicted_category = idx_to_category[int(probabilities.argmax())]
            confidence = float(probabilities.max())

            predictions_list.append({
                'complaint': complaint,
                'predicted_category': predicted_category,
                'confidence': confidence,
                'truncated': was_truncated,
            })

            if i < CSV_PREVIEW_ROWS:
                truncation_mark = " ⚠️" if was_truncated else ""
                preview = complaint if len(complaint) <= 50 else complaint[:47] + "..."
                results.append(f"Complaint {i + 1}{truncation_mark}: {preview}")
                results.append(f"Predicted: {predicted_category} (confidence: {confidence:.3f})\n")

        if len(df) > CSV_PREVIEW_ROWS:
            results.append(
                f"... and {len(df) - CSV_PREVIEW_ROWS} more "
                f"(showing first {CSV_PREVIEW_ROWS} out of {len(df)} complaints)"
            )

        if truncated_count > 0:
            results.append(
                f"\n⚠️ {truncated_count} complaints were truncated to fit BERT's 512 token limit."
            )

        # Explicit columns keep the header even when the CSV has no rows.
        results_df = pd.DataFrame(
            predictions_list,
            columns=['complaint', 'predicted_category', 'confidence', 'truncated'],
        )
        results_file = "prediction_results.csv"
        results_df.to_csv(results_file, index=False)
        results.append(f"\n💾 Full results saved to {results_file}")

        return "\n".join(results), results_file

    except Exception as e:
        return f"❌ CSV processing failed: {str(e)}", None


def push_to_hub_after_training(model_path, username, model_name, token):
    """Push a locally saved model and tokenizer to the Hugging Face Hub."""
    try:
        if not token:
            return "❌ Please provide a Hugging Face token"

        hub_model_id, error = validate_hub_model_id(username, model_name)
        if error:
            return f"❌ {error}"

        # Cheap local check before any network call.
        if not os.path.exists(model_path):
            return "❌ No trained model found. Please train a model first."

        login(token=token)

        try:
            model = BertForSequenceClassification.from_pretrained(model_path)
            tokenizer = AutoTokenizer.from_pretrained(model_path)
        except Exception as e:
            return f"❌ Failed to load model: {str(e)}"

        try:
            model.push_to_hub(hub_model_id)
            tokenizer.push_to_hub(hub_model_id)
        except Exception as e:
            return f"❌ Failed to push to Hub: {str(e)}"
        return f"✅ Successfully pushed model to {hub_model_id}"

    except Exception as e:
        return f"❌ Error: {str(e)}"


def count_tokens(text):
    """Return a token-count message for *text* (loads the base tokenizer if needed)."""
    global CURRENT_TOKENIZER

    if not text:
        return "Enter text to see token count"

    if CURRENT_TOKENIZER is None:
        try:
            CURRENT_TOKENIZER = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
            logger.info("Fallback: Tokenizer loaded in count_tokens function.")
        except Exception as e:
            logger.error("Failed to load tokenizer in count_tokens fallback: %s", e)
            return "❌ Error: Tokenizer not loaded. Please load a model or check logs."

    count = len(CURRENT_TOKENIZER(text, truncation=False)['input_ids'])
    if count > MAX_LENGTH:
        return f"⚠️ **Token count: {count}/{MAX_LENGTH}** - Text will be truncated for BERT"
    return f"Token count: {count}/{MAX_LENGTH}"


def get_available_datasets():
    """List CSV files in the working directory as ``"name (N rows)"`` strings.

    Returns a one-element list holding a "no files" message when none exist.
    """
    available_files = []
    for filename in sorted(os.listdir(".")):
        if not filename.lower().endswith(".csv"):
            continue
        try:
            available_files.append(f"{filename} ({len(pd.read_csv(filename))} rows)")
        except Exception:  # unreadable/malformed CSV: still list it
            available_files.append(f"{filename} (Error reading)")
    return available_files or [_NO_CSV_MESSAGE]


def display_available_datasets():
    """Return a Markdown list of the CSV files found in the working directory."""
    entries = get_available_datasets()
    if entries == [_NO_CSV_MESSAGE]:
        return "No CSV files found in the current directory."
    return "**Available CSV files:**\n\n" + "\n".join(f"- {entry}" for entry in entries)


# Initialize tokenizer on startup so the token counter responds immediately.
if CURRENT_TOKENIZER is None:
    try:
        CURRENT_TOKENIZER = AutoTokenizer.from_pretrained(BASE_MODEL_NAME)
        print("✅ Tokenizer initialized successfully")
    except Exception as e:
        print(f"⚠️ Warning: Could not initialize tokenizer globally: {e}")


# ---------------------------------------------------------------------------
# Gradio UI (everything, including event wiring, lives inside one Blocks)
# ---------------------------------------------------------------------------
with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app:
    gr.Markdown("# BERT Complaint Classifier 🗣️🤖")
    gr.Markdown("Fine-tune a BERT model or use an existing one to classify customer complaints.")

    with gr.Tab("Fine-tune Model"):
        gr.Markdown("## 🏋️ Fine-tune a New Model")
        with gr.Column(variant="panel"):
            gr.Markdown("### 🛠️ Training Configuration")
            with gr.Row():
                uploaded_file = gr.File(label="Upload Training CSV File", type="filepath", file_types=[".csv"])
                preview_btn = gr.Button("Preview Dataset")
            preview_output = gr.Markdown("Dataset info will appear here")

            with gr.Row():
                text_column_input = gr.Textbox(label="Text Column Name", value="complaint")
                label_column_input = gr.Textbox(label="Label Column Name", value="category")

            gr.Markdown("---")

            with gr.Row():
                num_epochs_slider = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of Epochs")
                batch_size_slider = gr.Slider(minimum=4, maximum=32, step=4, value=8, label="Batch Size")
                learning_rate_slider = gr.Slider(minimum=1e-6, maximum=1e-4, step=1e-6, value=2e-5, label="Learning Rate")

            gr.Markdown("---")
            gr.Markdown("### ☁️ Hugging Face Hub (Optional)")
            with gr.Row():
                push_to_hub_checkbox = gr.Checkbox(label="Push to Hugging Face Hub")
                hf_token_input = gr.Textbox(label="Hugging Face Token", type="password")
            with gr.Row():
                hf_username_input = gr.Textbox(label="Hugging Face Username")
                hf_model_name_input = gr.Textbox(label="Model Name (for Hub)", value="bert-complaint-classifier")

            train_btn = gr.Button("🚀 Start Training", variant="primary")

        gr.Markdown("---")
        training_log_output = gr.Textbox(label="Training Logs", lines=20, max_lines=20, interactive=False)

    with gr.Tab("Predict"):
        gr.Markdown("## 🔮 Make Predictions")
        gr.Markdown("Choose a method to classify complaints.")

        with gr.Tab("Predict Single Text"):
            with gr.Column(variant="panel"):
                gr.Markdown("### Classify a Single Complaint")
                model_path_input = gr.Textbox(
                    label="Model Path",
                    value=MODEL_PATH,
                    placeholder="Enter model path or HuggingFace model ID",
                )
                load_model_btn = gr.Button("Load Model")
                model_status = gr.Textbox(label="Model Status", interactive=False)

                gr.Markdown("---")

                text_input = gr.Textbox(
                    label="Enter complaint text",
                    lines=3,
                    placeholder="Type your complaint here...",
                )
                token_counter = gr.Textbox(label="Token Count", interactive=False, value="Enter text to see token count")
                predict_btn = gr.Button("🔮 Predict Category", variant="primary")
                prediction_output = gr.Markdown("Prediction results will appear here")

        with gr.Tab("Predict CSV File"):
            with gr.Column(variant="panel"):
                gr.Markdown("### Classify Multiple Complaints from CSV")
                gr.Markdown("Upload a CSV file with a 'complaint' column to classify multiple complaints at once.")

                csv_model_path = gr.Textbox(
                    label="Model Path",
                    value=MODEL_PATH,
                    placeholder="Enter model path or HuggingFace model ID",
                )
                csv_load_btn = gr.Button("Load Model")
                csv_model_status = gr.Textbox(label="Model Status", interactive=False)

                gr.Markdown("---")

                csv_file_input = gr.File(label="Upload CSV File", type="filepath", file_types=[".csv"])
                csv_predict_btn = gr.Button("🔮 Predict All", variant="primary")
                csv_prediction_output = gr.Markdown("CSV prediction results will appear here")
                csv_download = gr.File(label="Download Results", interactive=False)

    with gr.Tab("Push to Hub"):
        gr.Markdown("## 🤗 Push Trained Model to Hugging Face Hub")
        gr.Markdown("Upload your locally trained model to the Hugging Face Hub for sharing.")
        with gr.Column(variant="panel"):
            hub_model_path = gr.Textbox(label="Local Model Path", value=MODEL_PATH)
            hub_username = gr.Textbox(label="Hugging Face Username")
            hub_model_name = gr.Textbox(label="Model Name", value="bert-complaint-classifier")
            hub_token = gr.Textbox(label="Hugging Face Token", type="password")
            push_hub_btn = gr.Button("🚀 Push to Hub", variant="primary")
            push_hub_output = gr.Markdown("Push results will appear here")

    with gr.Tab("Dataset Info"):
        gr.Markdown("## 📊 Dataset Information")
        gr.Markdown("View information about available datasets and model categories.")
        with gr.Column(variant="panel"):
            gr.Markdown("### 🎯 Model Categories")
            categories_info = gr.Markdown(
                "**Available Categories:**\n\n"
                + "\n".join(f"- **{cat}** (index: {idx})" for idx, cat in idx_to_category.items())
            )

            gr.Markdown("---")

            gr.Markdown("### 📁 Available Datasets")
            datasets_btn = gr.Button("🔍 Scan for CSV Files")
            datasets_info = gr.Markdown("Click 'Scan for CSV Files' to see available datasets")

            gr.Markdown("---")

            gr.Markdown("### 💡 Tips")
            gr.Markdown(TIPS_MARKDOWN)

    # Connect functions to UI components
    preview_btn.click(
        preview_dataset,
        inputs=[uploaded_file, text_column_input, label_column_input],
        outputs=preview_output,
    )

    train_btn.click(
        train_model_inline,
        inputs=[
            uploaded_file,
            text_column_input,
            label_column_input,
            num_epochs_slider,
            batch_size_slider,
            learning_rate_slider,
            hf_token_input,
            push_to_hub_checkbox,
            hf_username_input,
            hf_model_name_input,
        ],
        outputs=training_log_output,
    )

    load_model_btn.click(load_model, inputs=model_path_input, outputs=model_status)

    predict_btn.click(
        predict_text,
        inputs=[text_input, model_path_input],
        outputs=prediction_output,
    )

    text_input.change(count_tokens, inputs=text_input, outputs=token_counter)

    csv_load_btn.click(load_model, inputs=csv_model_path, outputs=csv_model_status)

    csv_predict_btn.click(
        predict_csv,
        inputs=[csv_file_input, csv_model_path],
        outputs=[csv_prediction_output, csv_download],
    )

    push_hub_btn.click(
        push_to_hub_after_training,
        inputs=[hub_model_path, hub_username, hub_model_name, hub_token],
        outputs=push_hub_output,
    )

    datasets_btn.click(display_available_datasets, outputs=datasets_info)

    # Run a check for available datasets on app load
    app.load(display_available_datasets, outputs=datasets_info)


# Launch the Gradio app
if __name__ == "__main__":
    print("🚀 Launching BERT Complaint Classifier...")
    print("📍 Available at: http://localhost:7860")
    app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True,
    )