# NOTE: Hugging Face Spaces page banner ("Spaces: Sleeping") captured by the
# scraper — not part of the application source.
| import gradio as gr | |
| import torch | |
| import pandas as pd | |
| import os | |
| import json | |
| import logging | |
| import numpy as np | |
| from datetime import datetime | |
| from pathlib import Path | |
| from sklearn.metrics import accuracy_score, classification_report | |
| from sklearn.model_selection import train_test_split | |
| from huggingface_hub import login | |
| from transformers import ( | |
| AutoTokenizer, | |
| BertForSequenceClassification, | |
| TrainingArguments, | |
| Trainer, | |
| DataCollatorWithPadding, | |
| EarlyStoppingCallback, | |
| TrainerCallback | |
| ) | |
| from datasets import Dataset, DatasetDict | |
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables shared across the UI handlers below.
MODEL_PATH = "local-model"  # default directory where the fine-tuned model is saved/loaded
CATEGORIES = ['Online-Safety', 'BroadBand', 'TV-Radio']  # fixed label set; list order defines class indices
idx_to_category = {0: 'Online-Safety', 1: 'BroadBand', 2: 'TV-Radio'}  # class index -> label name
category_to_idx = {'Online-Safety': 0, 'BroadBand': 1, 'TV-Radio': 2}  # label name -> class index
TOKEN = None  # last Hugging Face token handed to login_to_hf()
TRAINING_LOGS = []  # accumulated log lines streamed to the training-tab textbox
CURRENT_MODEL = None  # lazily-loaded BertForSequenceClassification instance (or None)
CURRENT_TOKENIZER = None  # lazily-loaded tokenizer matching CURRENT_MODEL (or None)
def load_and_prepare_local_dataset(file_path, text_column, label_column, test_size=0.2):
    """Load a local CSV and split it into Hugging Face train/validation Datasets.

    Args:
        file_path: Path to the CSV file on disk.
        text_column: Name of the column holding the complaint text.
        label_column: Name of the column holding the label (text or numeric).
        test_size: Fraction of rows held out for validation.

    Returns:
        Tuple of (DatasetDict with 'train'/'validation', text column name,
        'label_idx' — the name of the generated integer-label column).

    Raises:
        Exception: wraps any underlying failure with an "Error loading dataset"
        prefix (callers display the message in the UI).
    """
    try:
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Dataset file not found: {file_path}")
        # Load the CSV file
        df = pd.read_csv(file_path)
        # Verify required columns exist before touching the data.
        if text_column not in df.columns:
            available_cols = list(df.columns)
            raise ValueError(f"Text column '{text_column}' not found. Available columns: {available_cols}")
        if label_column not in df.columns:
            available_cols = list(df.columns)
            raise ValueError(f"Label column '{label_column}' not found. Available columns: {available_cols}")
        # Drop rows missing either field; coerce text to str for the tokenizer.
        df = df.dropna(subset=[text_column, label_column])
        df[text_column] = df[text_column].astype(str)
        # Handle different label formats.
        if df[label_column].dtype == 'object':
            # Text labels: map to class indices, preferring the canonical mapping.
            unique_labels = df[label_column].unique()
            if len(unique_labels) > len(CATEGORIES):
                raise ValueError(f"Too many unique labels ({len(unique_labels)}). Expected max {len(CATEGORIES)}")
            label_mapping = {}
            for label in unique_labels:
                if label in category_to_idx:
                    label_mapping[label] = category_to_idx[label]
                else:
                    # Unknown label text: auto-assign the lowest unused index.
                    available_indices = set(range(len(CATEGORIES))) - set(label_mapping.values())
                    if available_indices:
                        label_mapping[label] = min(available_indices)
                    else:
                        raise ValueError(f"Cannot map label '{label}' to available categories")
            df['label_idx'] = df[label_column].map(label_mapping)
        else:
            # Labels are already numeric.
            df['label_idx'] = df[label_column].astype(int)
        # Verify label indices fall inside the known category range.
        if df['label_idx'].min() < 0 or df['label_idx'].max() >= len(CATEGORIES):
            raise ValueError(f"Label indices must be between 0 and {len(CATEGORIES)-1}")
        # Stratified split preserves class balance in both partitions.
        # BUG FIX: the split previously hard-coded test_size=0.2, silently
        # ignoring the caller-supplied test_size parameter.
        train_df, val_df = train_test_split(
            df,
            test_size=test_size,
            random_state=42,
            stratify=df['label_idx']
        )
        # Convert to Hugging Face datasets (only the columns training needs).
        train_dataset = Dataset.from_pandas(train_df[[text_column, 'label_idx']])
        val_dataset = Dataset.from_pandas(val_df[[text_column, 'label_idx']])
        dataset_dict = DatasetDict({
            'train': train_dataset,
            'validation': val_dataset
        })
        return dataset_dict, text_column, 'label_idx'
    except Exception as e:
        # Chain the cause so the original traceback is preserved for debugging.
        raise Exception(f"Error loading dataset: {str(e)}") from e
def preview_dataset(uploaded_file, text_column, label_column):
    """Build a markdown preview of an uploaded CSV dataset.

    Args:
        uploaded_file: Gradio file object (has .name) or a plain path string.
        text_column: Expected text column name.
        label_column: Expected label column name.

    Returns:
        Markdown summary (rows, columns, sample text, label distribution),
        or a human-readable error string prefixed with ❌.
    """
    try:
        if uploaded_file is None:
            return "Please upload a dataset file first."
        # Gradio may hand back a tempfile wrapper or a plain path string.
        file_path = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
        # Check if file exists and is CSV
        if not os.path.exists(file_path):
            return "❌ File not found. Please try uploading again."
        if not file_path.lower().endswith('.csv'):
            return "❌ Please upload a CSV file (.csv extension required)."
        df = pd.read_csv(file_path)
        preview_info = []
        preview_info.append(f"📊 **Dataset Preview: {os.path.basename(file_path)}**")
        preview_info.append(f"- **Total rows:** {len(df)}")
        preview_info.append(f"- **Columns:** {list(df.columns)}")
        preview_info.append("")
        if text_column in df.columns:
            preview_info.append(f"✅ **Text column '{text_column}' found**")
            # BUG FIX: guard the sample row — iloc[0] raised IndexError on an
            # empty CSV (header only, zero data rows).
            if len(df) > 0:
                preview_info.append(f"- Sample text: {str(df[text_column].iloc[0])[:100]}...")
        else:
            preview_info.append(f"❌ **Text column '{text_column}' not found**")
            return "\n".join(preview_info)
        if label_column in df.columns:
            preview_info.append(f"✅ **Label column '{label_column}' found**")
            label_counts = df[label_column].value_counts()
            preview_info.append("- **Label distribution:**")
            for label, count in label_counts.items():
                preview_info.append(f" - {label}: {count} ({count/len(df)*100:.1f}%)")
        else:
            preview_info.append(f"❌ **Label column '{label_column}' not found**")
            return "\n".join(preview_info)
        return "\n".join(preview_info)
    except Exception as e:
        return f"❌ Error previewing dataset: {str(e)}"
def login_to_hf(token):
    """Authenticate this process with the Hugging Face Hub.

    The token is remembered in the module-level TOKEN even when login fails,
    matching the original behavior (later code may retry with it).
    """
    global TOKEN
    TOKEN = token
    try:
        login(token)
    except Exception as err:
        return f"❌ Login failed: {str(err)}"
    return "✅ Successfully logged in to Hugging Face"
def validate_hub_model_id(username, model_name):
    """Validate inputs and construct a `username/model-name` Hub model ID.

    The model name is normalized: stripped, lower-cased, spaces become
    hyphens, and any character that is not alphanumeric, '-' or '_' is
    removed.

    Returns:
        (hub_model_id, None) on success, or (None, error_message) on failure.
    """
    if not username or not model_name:
        return None, "Please provide both username and model name"
    # Normalize the model name to Hub-friendly characters.
    model_name = model_name.strip().lower().replace(" ", "-")
    model_name = ''.join(c for c in model_name if c.isalnum() or c in ['-', '_'])
    # BUG FIX: cleaning could strip the name to nothing (e.g. "!!!"),
    # which previously produced the invalid id "username/".
    if not model_name:
        return None, "Model name contains no valid characters"
    hub_model_id = f"{username}/{model_name}"
    return hub_model_id, None
def load_model(model_path):
    """Load a classifier + tokenizer into the module globals.

    Tries, in order: a local directory, the Hugging Face Hub, and finally
    plain `bert-base-uncased` as a last-resort fallback. Returns a
    status string for the UI.
    """
    global CURRENT_MODEL, CURRENT_TOKENIZER

    def _load_pair(source):
        # Assigns tokenizer first, then model — same side-effect order as
        # loading them inline, so a model failure still leaves the
        # tokenizer global updated.
        global CURRENT_MODEL, CURRENT_TOKENIZER
        CURRENT_TOKENIZER = AutoTokenizer.from_pretrained(source)
        CURRENT_MODEL = BertForSequenceClassification.from_pretrained(
            source,
            num_labels=len(CATEGORIES)
        )

    try:
        # Prefer a local checkpoint directory when it exists.
        if os.path.exists(model_path):
            _load_pair(model_path)
            return f"✅ Model loaded from local path: {model_path}"
        # Otherwise treat the path as a Hub model id.
        try:
            _load_pair(model_path)
            return f"✅ Model loaded from Hugging Face Hub: {model_path}"
        except Exception as hub_error:
            # Last resort: fall back to the base BERT checkpoint.
            _load_pair("bert-base-uncased")
            return f"⚠️ Failed to load specified model, using base BERT model instead. Error: {str(hub_error)}"
    except Exception as e:
        return f"❌ Failed to load model: {str(e)}"
def tokenize_function(examples, tokenizer, feature_column, max_length=512):
    """Tokenize the text found in *feature_column* of a batch of examples.

    Padding is deliberately left off here; the DataCollatorWithPadding pads
    each batch dynamically at training time.
    """
    texts = examples[feature_column]
    return tokenizer(texts, truncation=True, padding=False, max_length=max_length)
def compute_metrics(eval_pred):
    """Compute accuracy and macro/weighted PRF metrics for Trainer eval.

    *eval_pred* is the (logits, labels) pair the Trainer hands to its
    compute_metrics hook; logits are argmax-ed into class predictions.
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=1)
    # zero_division=0 keeps the report quiet when a class has no predictions.
    report = classification_report(labels, preds, output_dict=True, zero_division=0)
    macro = report['macro avg']
    return {
        'accuracy': accuracy_score(labels, preds),
        'f1_macro': macro['f1-score'],
        'f1_weighted': report['weighted avg']['f1-score'],
        'precision_macro': macro['precision'],
        'recall_macro': macro['recall'],
    }
def train_model_inline(uploaded_file, text_column, label_column, num_epochs, batch_size,
                       learning_rate, hf_token, push_to_hub, username, model_name):
    """Train the model using inline training (no subprocess).

    Generator function: Gradio streams every yielded string into the
    "Training Logs" textbox, so each yield below pushes the full
    accumulated log to the UI.

    Args mirror the training-tab widgets: the uploaded CSV file, the
    text/label column names, the hyperparameters, an optional HF token,
    and the push-to-hub settings.
    """
    global TRAINING_LOGS, MODEL_PATH, CURRENT_MODEL, CURRENT_TOKENIZER
    TRAINING_LOGS = []
    # Optional Hub login first so a bad token surfaces before training starts.
    if hf_token:
        login_result = login_to_hf(hf_token)
        TRAINING_LOGS.append(login_result)
        yield "\n".join(TRAINING_LOGS)
    # Validate hub model ID if pushing to hub
    if push_to_hub:
        hub_model_id, error = validate_hub_model_id(username, model_name)
        if error:
            TRAINING_LOGS.append(f"❌ {error}")
            yield "\n".join(TRAINING_LOGS)
            return
    else:
        hub_model_id = None
    # Validate uploaded file
    if uploaded_file is None:
        TRAINING_LOGS.append("❌ Please upload a dataset file")
        yield "\n".join(TRAINING_LOGS)
        return
    # Gradio file components may hand back a tempfile object or a plain path.
    dataset_file = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
    try:
        # Load and prepare dataset
        TRAINING_LOGS.append(f"📊 Loading dataset from uploaded file...")
        yield "\n".join(TRAINING_LOGS)
        dataset_dict, final_text_col, final_label_col = load_and_prepare_local_dataset(
            dataset_file, text_column, label_column
        )
        TRAINING_LOGS.append(f"✅ Dataset loaded successfully!")
        TRAINING_LOGS.append(f"- Train samples: {len(dataset_dict['train'])}")
        TRAINING_LOGS.append(f"- Validation samples: {len(dataset_dict['validation'])}")
        yield "\n".join(TRAINING_LOGS)
        # Load model and tokenizer (always fine-tunes from the base checkpoint).
        TRAINING_LOGS.append("🤖 Loading BERT model and tokenizer...")
        yield "\n".join(TRAINING_LOGS)
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        model = BertForSequenceClassification.from_pretrained(
            "bert-base-uncased",
            num_labels=len(CATEGORIES)
        )
        TRAINING_LOGS.append("✅ Model and tokenizer loaded")
        yield "\n".join(TRAINING_LOGS)
        # Tokenize datasets
        TRAINING_LOGS.append("🔤 Tokenizing datasets...")
        yield "\n".join(TRAINING_LOGS)
        def tokenize_batch(examples):
            # Closure over the tokenizer/text column for Dataset.map.
            return tokenize_function(examples, tokenizer, final_text_col, 512)
        # Get columns to remove (keep only label column and tokenized features)
        columns_to_remove = [col for col in dataset_dict['train'].column_names if col != final_label_col]
        tokenized_datasets = dataset_dict.map(
            tokenize_batch,
            batched=True,
            remove_columns=columns_to_remove
        )
        # Rename label column to 'labels' (required by Trainer)
        tokenized_datasets = tokenized_datasets.rename_column(final_label_col, 'labels')
        TRAINING_LOGS.append("✅ Tokenization completed")
        yield "\n".join(TRAINING_LOGS)
        # Set up training output directory.
        output_dir = Path(MODEL_PATH)
        output_dir.mkdir(parents=True, exist_ok=True)
        # Derive step-based schedules from dataset size; the min/max clamps
        # keep the values sane for both tiny and large datasets.
        # NOTE(review): total_steps is 0 when batch_size > train-set size,
        # which collapses these schedules to their floors — confirm behavior
        # on very small datasets.
        total_steps = len(tokenized_datasets['train']) // batch_size * num_epochs
        eval_steps = max(10, min(100, total_steps // 4))
        save_steps = max(20, min(500, total_steps // 2))
        logging_steps = max(5, min(50, total_steps // 10))
        warmup_steps = min(500, total_steps // 10)
        TRAINING_LOGS.append(f"📈 Training configuration:")
        TRAINING_LOGS.append(f"- Total steps: {total_steps}")
        TRAINING_LOGS.append(f"- Eval steps: {eval_steps}")
        TRAINING_LOGS.append(f"- Warmup steps: {warmup_steps}")
        yield "\n".join(TRAINING_LOGS)
        # NOTE(review): with load_best_model_at_end=True, transformers
        # requires save_steps to be a round multiple of eval_steps — the
        # clamps above can violate that for some dataset sizes; confirm.
        training_args = TrainingArguments(
            output_dir=str(output_dir),
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            warmup_steps=warmup_steps,
            weight_decay=0.01,
            learning_rate=learning_rate,
            logging_dir=str(output_dir / "logs"),
            logging_steps=logging_steps,
            eval_strategy="steps",  # `eval_strategy` is the newer name (was `evaluation_strategy`)
            eval_steps=eval_steps,
            save_steps=save_steps,
            save_total_limit=2,  # keep only the two most recent checkpoints on disk
            load_best_model_at_end=True,  # restore the best checkpoint by eval_accuracy
            metric_for_best_model="eval_accuracy",
            greater_is_better=True,
            push_to_hub=push_to_hub,
            hub_model_id=hub_model_id if push_to_hub else None,
            report_to=None,  # NOTE(review): some versions treat None as "all" reporters; "none" disables — confirm
            dataloader_num_workers=0,  # single-process data loading (safe inside the Gradio process)
            fp16=torch.cuda.is_available(),  # mixed precision only when a GPU is present
            seed=42,
            remove_unused_columns=False,
        )
        # Data collator pads each batch dynamically (tokenization left padding off).
        data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
        class ProgressCallback(TrainerCallback):
            """Appends human-readable progress lines to the shared log list."""
            def __init__(self, logs_list, total_steps):
                self.logs = logs_list          # shared reference to TRAINING_LOGS
                self.total_steps = total_steps
            def on_train_begin(self, args, state, control, **kwargs):
                self.logs.append("🚀 Starting training...")
                self.log_update()
            def on_step_end(self, args, state, control, **kwargs):
                if state.global_step % args.logging_steps == 0:
                    self.logs.append(f"Step {state.global_step}/{self.total_steps}")
                    self.log_update()
            def on_epoch_end(self, args, state, control, **kwargs):
                epoch = int(state.epoch)
                self.logs.append(f"✅ Epoch {epoch} completed")
                self.log_update()
            def on_evaluate(self, args, state, control, logs=None, **kwargs):
                if logs:
                    acc = logs.get('eval_accuracy', 0)
                    loss = logs.get('eval_loss', 0)
                    self.logs.append(f"📊 Eval - Accuracy: {acc:.4f}, Loss: {loss:.4f}")
                    self.log_update()
            def log_update(self):
                # Intentionally a no-op: a TrainerCallback cannot yield into the
                # surrounding generator, so progress lines accumulate in
                # TRAINING_LOGS and only reach the UI at the next top-level
                # yield (i.e. after trainer.train() returns).
                pass
        # Create trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_datasets['train'],
            eval_dataset=tokenized_datasets['validation'],
            tokenizer=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=3), ProgressCallback(TRAINING_LOGS, total_steps)]
        )
        # Train the model (blocking; UI stays on the last yielded log until done).
        try:
            trainer.train()
            TRAINING_LOGS.append("✅ Training completed successfully!")
            yield "\n".join(TRAINING_LOGS)
        except Exception as e:
            TRAINING_LOGS.append(f"❌ Training failed: {str(e)}")
            yield "\n".join(TRAINING_LOGS)
            return
        # Save the model
        TRAINING_LOGS.append("💾 Saving model...")
        yield "\n".join(TRAINING_LOGS)
        trainer.save_model()
        tokenizer.save_pretrained(output_dir)
        # Update globals so the Predict tab immediately uses the new weights.
        CURRENT_MODEL = model
        CURRENT_TOKENIZER = tokenizer
        TRAINING_LOGS.append("✅ Model saved successfully!")
        yield "\n".join(TRAINING_LOGS)
        # Final evaluation
        TRAINING_LOGS.append("📊 Running final evaluation...")
        yield "\n".join(TRAINING_LOGS)
        try:
            eval_results = trainer.evaluate()
            TRAINING_LOGS.append("📊 Final Results:")
            for key, value in eval_results.items():
                if isinstance(value, float):
                    TRAINING_LOGS.append(f" {key}: {value:.4f}")
                else:
                    TRAINING_LOGS.append(f" {key}: {value}")
            # Persist metrics alongside the saved model.
            with open(output_dir / "eval_results.json", "w") as f:
                json.dump(eval_results, f, indent=2)
        except Exception as e:
            TRAINING_LOGS.append(f"⚠️ Evaluation error: {str(e)}")
        yield "\n".join(TRAINING_LOGS)
        # Push to hub if requested
        if push_to_hub and hub_model_id:
            TRAINING_LOGS.append(f"🤗 Pushing to Hugging Face Hub: {hub_model_id}")
            yield "\n".join(TRAINING_LOGS)
            try:
                trainer.push_to_hub()
                TRAINING_LOGS.append(f"✅ Successfully pushed to {hub_model_id}")
            except Exception as e:
                TRAINING_LOGS.append(f"❌ Push to Hub failed: {str(e)}")
            yield "\n".join(TRAINING_LOGS)
        TRAINING_LOGS.append("\n✨ Training completed! Your model is ready to use.")
        yield "\n".join(TRAINING_LOGS)
    except Exception as e:
        # Catch-all so the UI shows the failure instead of a dead generator.
        TRAINING_LOGS.append(f"❌ Error during training: {str(e)}")
        yield "\n".join(TRAINING_LOGS)
def predict_text(text, model_path):
    """Classify a single complaint string and return a markdown report."""
    global CURRENT_MODEL, CURRENT_TOKENIZER
    # (Re)load when nothing is loaded yet, or when a non-default path is given.
    if CURRENT_MODEL is None or model_path != MODEL_PATH:
        status = load_model(model_path)
        if status.startswith("❌"):
            return status
    try:
        if not text.strip():
            return "Please enter some text to classify."
        # Flag inputs longer than BERT's 512-token window.
        full_encoding = CURRENT_TOKENIZER(text, truncation=False)
        was_truncated = len(full_encoding['input_ids']) > 512
        # Tokenize (truncated) and run a single forward pass.
        encoded = CURRENT_TOKENIZER(text, return_tensors="pt", truncation=True, max_length=512)
        with torch.no_grad():
            logits = CURRENT_MODEL(**encoded).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)
        top_idx = probs.argmax().item()
        top_prob = probs.max().item()
        category = idx_to_category[top_idx]
        warning = "\n\n⚠️ Note: This complaint was truncated to fit BERT's 512 token limit." if was_truncated else ""
        lines = [
            f"**Complaint:** {text}",
            f"\n**Predicted Category:** {category}",
            f"**Confidence:** {top_prob:.4f}",
            "\n**All Class Probabilities:**",
        ]
        for i, cat in enumerate(CATEGORIES):
            lines.append(f"- {cat}: {probs[0][i].item():.4f}")
        lines.append(warning)
        return "\n".join(lines)
    except Exception as e:
        return f"❌ Prediction error: {str(e)}"
def predict_csv(csv_file, model_path):
    """Classify every complaint in an uploaded CSV file.

    The CSV must contain a 'complaint' column. Every row is classified and
    written to prediction_results.csv; the on-screen report shows at most
    the first 20 rows.

    Returns:
        (report_text, results_csv_path) on success, (error_text, None) on failure.
    """
    global CURRENT_MODEL, CURRENT_TOKENIZER
    # Load the model if needed
    if CURRENT_MODEL is None or model_path != MODEL_PATH:
        load_result = load_model(model_path)
        if load_result.startswith("❌"):
            return load_result, None
    try:
        # Gradio may pass a tempfile wrapper or a plain path string.
        file_path = csv_file.name if hasattr(csv_file, 'name') else csv_file
        df = pd.read_csv(file_path)
        if 'complaint' not in df.columns:
            return "❌ CSV file must have a 'complaint' column", None
        results = []
        predictions_list = []
        truncated_count = 0
        display_limit = 20  # cap on rows rendered in the UI report
        # BUG FIX: the loop previously broke out after 20 rows, so the
        # "full results" CSV and the truncation count silently covered only
        # the first 20 complaints, and a "... and 0 more" line was printed
        # for exactly-20-row files. Now every row is classified; only the
        # on-screen report is capped.
        for i, complaint in enumerate(df['complaint'].astype(str)):
            # Flag complaints longer than BERT's 512-token window.
            original_tokens = CURRENT_TOKENIZER(complaint, truncation=False)
            was_truncated = len(original_tokens['input_ids']) > 512
            if was_truncated:
                truncated_count += 1
            # Predict
            inputs = CURRENT_TOKENIZER(complaint, return_tensors="pt", truncation=True, max_length=512)
            with torch.no_grad():
                outputs = CURRENT_MODEL(**inputs)
            predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)
            predicted_idx = predictions.argmax().item()
            confidence = predictions.max().item()
            predicted_category = idx_to_category[predicted_idx]
            predictions_list.append({
                'complaint': complaint,
                'predicted_category': predicted_category,
                'confidence': confidence,
                'truncated': was_truncated
            })
            if i < display_limit:
                truncation_mark = " ⚠️" if was_truncated else ""
                preview = complaint if len(complaint) <= 50 else complaint[:47] + "..."
                results.append(f"Complaint {i+1}{truncation_mark}: {preview}")
                results.append(f"Predicted: {predicted_category} (confidence: {confidence:.3f})\n")
        if len(df) > display_limit:
            results.append(f"... and {len(df) - display_limit} more (showing first {display_limit} out of {len(df)} complaints)")
        if truncated_count > 0:
            results.append(f"\n⚠️ {truncated_count} complaints were truncated to fit BERT's 512 token limit.")
        # Save full results to a CSV file
        results_df = pd.DataFrame(predictions_list)
        results_file = "prediction_results.csv"
        results_df.to_csv(results_file, index=False)
        results.append(f"\n💾 Full results saved to {results_file}")
        return "\n".join(results), results_file
    except Exception as e:
        return f"❌ CSV processing failed: {str(e)}", None
def push_to_hub_after_training(model_path, username, model_name, token):
    """Upload a locally saved model and tokenizer to the Hugging Face Hub.

    Returns a status string (✅ on success, ❌ with a reason otherwise).
    """
    try:
        if not token:
            return "❌ Please provide a Hugging Face token"
        hub_model_id, error = validate_hub_model_id(username, model_name)
        if error:
            return f"❌ {error}"
        # Authenticate before touching the Hub.
        login(token)
        if not os.path.exists(model_path):
            return "❌ No trained model found. Please train a model first."
        # Load the saved artifacts from disk.
        try:
            model = BertForSequenceClassification.from_pretrained(model_path)
            tokenizer = AutoTokenizer.from_pretrained(model_path)
        except Exception as load_err:
            return f"❌ Failed to load model: {str(load_err)}"
        # Upload both artifacts under the same repo id.
        try:
            model.push_to_hub(hub_model_id)
            tokenizer.push_to_hub(hub_model_id)
        except Exception as push_err:
            return f"❌ Failed to push to Hub: {str(push_err)}"
        return f"✅ Successfully pushed model to {hub_model_id}"
    except Exception as e:
        return f"❌ Error: {str(e)}"
def count_tokens(text):
    """Report the BERT token count for *text*, flagging counts over 512."""
    global CURRENT_TOKENIZER
    if text is None:
        return "Enter text to see token count"
    # Lazily load a default tokenizer if none has been set up yet.
    if CURRENT_TOKENIZER is None:
        try:
            CURRENT_TOKENIZER = AutoTokenizer.from_pretrained("bert-base-uncased")
            logger.info("Fallback: Tokenizer loaded in count_tokens function.")
        except Exception as e:
            logger.error(f"Failed to load tokenizer in count_tokens fallback: {e}")
            return "❌ Error: Tokenizer not loaded. Please load a model or check logs."
    # Defensive re-check after the fallback attempt.
    if CURRENT_TOKENIZER is None:
        return "❌ Error: Tokenizer is still not available."
    n = len(CURRENT_TOKENIZER(text, truncation=False)['input_ids'])
    if n > 512:
        return f"⚠️ **Token count: {n}/512** - Text will be truncated for BERT"
    return f"Token count: {n}/512"
def get_available_datasets():
    """List CSV files in the current directory with their row counts.

    Returns:
        A non-empty list of display strings — either "name.csv (N rows)" /
        "name.csv (Error reading)" entries, or a single placeholder when no
        CSV files exist.
    """
    available_files = []
    # sorted() makes the listing deterministic across filesystems.
    for file in sorted(os.listdir(".")):
        if file.endswith(".csv"):
            try:
                df = pd.read_csv(file)
                available_files.append(f"{file} ({len(df)} rows)")
            # BUG FIX: was a bare `except:`, which also swallowed
            # KeyboardInterrupt/SystemExit.
            except Exception:
                available_files.append(f"{file} (Error reading)")
    if not available_files:
        available_files = ["No CSV files found in current directory"]
    return available_files
def display_available_datasets():
    """Render the CSV-file listing as a markdown bullet list."""
    datasets = get_available_datasets()
    # get_available_datasets() always returns a non-empty list (it inserts a
    # placeholder entry when nothing is found), so the fallback branch below
    # is effectively dead; kept for parity with the original.
    if datasets:
        bullets = "\n".join(f"- {file}" for file in datasets)
        return "**Available CSV files:**\n\n" + bullets
    return "No CSV files found in the current directory."
# Initialize tokenizer on startup so the live token counter works before any
# model has been loaded or trained.
if CURRENT_TOKENIZER is None:
    try:
        CURRENT_TOKENIZER = AutoTokenizer.from_pretrained("bert-base-uncased")
        print("✅ Tokenizer initialized successfully")
    except Exception as e:
        # Non-fatal: count_tokens() has its own lazy-load fallback.
        print(f"⚠️ Warning: Could not initialize tokenizer globally: {e}")
print("🚀 Launching BERT Complaint Classifier...")
print("📍 Available at: http://localhost:7860")
# The entire Gradio UI definition must be within a single Blocks context so
# that event wiring below can reference the components.
with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app:
    gr.Markdown("# BERT Complaint Classifier 🗣️🤖")
    gr.Markdown("Fine-tune a BERT model or use an existing one to classify customer complaints.")
    # ---- Tab 1: fine-tuning -------------------------------------------------
    with gr.Tab("Fine-tune Model"):
        gr.Markdown("## 🏋️ Fine-tune a New Model")
        with gr.Column(variant="panel"):
            gr.Markdown("### 🛠️ Training Configuration")
            with gr.Row():
                uploaded_file = gr.File(label="Upload Training CSV File", type="filepath", file_types=[".csv"])
                preview_btn = gr.Button("Preview Dataset")
            preview_output = gr.Markdown("Dataset info will appear here")
            with gr.Row():
                text_column_input = gr.Textbox(label="Text Column Name", value="complaint")
                label_column_input = gr.Textbox(label="Label Column Name", value="category")
            gr.Markdown("---")
            # Hyperparameters (defaults follow common BERT fine-tuning practice).
            with gr.Row():
                num_epochs_slider = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of Epochs")
                batch_size_slider = gr.Slider(minimum=4, maximum=32, step=4, value=8, label="Batch Size")
                learning_rate_slider = gr.Slider(minimum=1e-6, maximum=1e-4, step=1e-6, value=2e-5, label="Learning Rate")
            gr.Markdown("---")
            gr.Markdown("### ☁️ Hugging Face Hub (Optional)")
            with gr.Row():
                push_to_hub_checkbox = gr.Checkbox(label="Push to Hugging Face Hub")
                hf_token_input = gr.Textbox(label="Hugging Face Token", type="password")
            with gr.Row():
                hf_username_input = gr.Textbox(label="Hugging Face Username")
                hf_model_name_input = gr.Textbox(label="Model Name (for Hub)", value="bert-complaint-classifier")
            train_btn = gr.Button("🚀 Start Training", variant="primary")
        gr.Markdown("---")
        training_log_output = gr.Textbox(label="Training Logs", lines=20, max_lines=20, interactive=False)
    # ---- Tab 2: prediction (single text and batch CSV sub-tabs) -------------
    with gr.Tab("Predict"):
        gr.Markdown("## 🔮 Make Predictions")
        gr.Markdown("Choose a method to classify complaints.")
        with gr.Tab("Predict Single Text"):
            with gr.Column(variant="panel"):
                gr.Markdown("### Classify a Single Complaint")
                model_path_input = gr.Textbox(label="Model Path", value=MODEL_PATH, placeholder="Enter model path or HuggingFace model ID")
                load_model_btn = gr.Button("Load Model")
                model_status = gr.Textbox(label="Model Status", interactive=False)
                gr.Markdown("---")
                text_input = gr.Textbox(
                    label="Enter complaint text",
                    lines=3,
                    placeholder="Type your complaint here..."
                )
                # Live token counter (updated on every keystroke via text_input.change).
                token_counter = gr.Textbox(label="Token Count", interactive=False, value="Enter text to see token count")
                predict_btn = gr.Button("🔮 Predict Category", variant="primary")
                prediction_output = gr.Markdown("Prediction results will appear here")
        with gr.Tab("Predict CSV File"):
            with gr.Column(variant="panel"):
                gr.Markdown("### Classify Multiple Complaints from CSV")
                gr.Markdown("Upload a CSV file with a 'complaint' column to classify multiple complaints at once.")
                csv_model_path = gr.Textbox(label="Model Path", value=MODEL_PATH, placeholder="Enter model path or HuggingFace model ID")
                csv_load_btn = gr.Button("Load Model")
                csv_model_status = gr.Textbox(label="Model Status", interactive=False)
                gr.Markdown("---")
                csv_file_input = gr.File(label="Upload CSV File", type="filepath", file_types=[".csv"])
                csv_predict_btn = gr.Button("🔮 Predict All", variant="primary")
                csv_prediction_output = gr.Markdown("CSV prediction results will appear here")
                csv_download = gr.File(label="Download Results", interactive=False)
    # ---- Tab 3: push an existing local model to the Hub ----------------------
    with gr.Tab("Push to Hub"):
        gr.Markdown("## 🤗 Push Trained Model to Hugging Face Hub")
        gr.Markdown("Upload your locally trained model to the Hugging Face Hub for sharing.")
        with gr.Column(variant="panel"):
            hub_model_path = gr.Textbox(label="Local Model Path", value=MODEL_PATH)
            hub_username = gr.Textbox(label="Hugging Face Username")
            hub_model_name = gr.Textbox(label="Model Name", value="bert-complaint-classifier")
            hub_token = gr.Textbox(label="Hugging Face Token", type="password")
            push_hub_btn = gr.Button("🚀 Push to Hub", variant="primary")
            push_hub_output = gr.Markdown("Push results will appear here")
    # ---- Tab 4: dataset/category reference -----------------------------------
    with gr.Tab("Dataset Info"):
        gr.Markdown("## 📊 Dataset Information")
        gr.Markdown("View information about available datasets and model categories.")
        with gr.Column(variant="panel"):
            gr.Markdown("### 🎯 Model Categories")
            categories_info = gr.Markdown(f"**Available Categories:**\n\n" + "\n".join([f"- **{cat}** (index: {idx})" for idx, cat in idx_to_category.items()]))
            gr.Markdown("---")
            gr.Markdown("### 📁 Available Datasets")
            datasets_btn = gr.Button("🔍 Scan for CSV Files")
            datasets_info = gr.Markdown("Click 'Scan for CSV Files' to see available datasets")
            gr.Markdown("---")
            gr.Markdown("### 💡 Tips")
            gr.Markdown("""
**Dataset Format:**
- CSV file with at least two columns
- One column for text (complaints)
- One column for labels/categories
- Labels can be text (will be auto-mapped) or numeric indices (0, 1, 2)
**Training Tips:**
- Start with 3 epochs and adjust based on results
- Use batch size 8-16 for most datasets
- Learning rate 2e-5 works well for BERT fine-tuning
- Enable early stopping to prevent overfitting
**Token Limits:**
- BERT has a 512 token limit
- Long texts will be automatically truncated
- Monitor the token counter when entering text
""")
    # ---- Event wiring (must stay inside the Blocks context) ------------------
    preview_btn.click(
        preview_dataset,
        inputs=[uploaded_file, text_column_input, label_column_input],
        outputs=preview_output
    )
    # train_model_inline is a generator, so Gradio streams each yielded log
    # snapshot into the training_log_output textbox.
    train_btn.click(
        train_model_inline,
        inputs=[
            uploaded_file,
            text_column_input,
            label_column_input,
            num_epochs_slider,
            batch_size_slider,
            learning_rate_slider,
            hf_token_input,
            push_to_hub_checkbox,
            hf_username_input,
            hf_model_name_input,
        ],
        outputs=training_log_output,
    )
    load_model_btn.click(
        load_model,
        inputs=model_path_input,
        outputs=model_status
    )
    predict_btn.click(
        predict_text,
        inputs=[text_input, model_path_input],
        outputs=prediction_output
    )
    # Live token counting while the user types.
    text_input.change(
        count_tokens,
        inputs=text_input,
        outputs=token_counter
    )
    csv_load_btn.click(
        load_model,
        inputs=csv_model_path,
        outputs=csv_model_status
    )
    csv_predict_btn.click(
        predict_csv,
        inputs=[csv_file_input, csv_model_path],
        outputs=[csv_prediction_output, csv_download]
    )
    push_hub_btn.click(
        push_to_hub_after_training,
        inputs=[hub_model_path, hub_username, hub_model_name, hub_token],
        outputs=push_hub_output
    )
    datasets_btn.click(
        display_available_datasets,
        outputs=datasets_info
    )
    # Run a check for available datasets when the page first loads.
    app.load(display_available_datasets, outputs=datasets_info)
# Launch the Gradio app only when executed directly (not on import).
if __name__ == "__main__":
    app.launch(
        server_name="0.0.0.0",  # listen on all interfaces (needed in Spaces/containers)
        server_port=7860,
        share=False,  # no public Gradio share link
        show_error=True  # surface Python tracebacks in the UI
    )