# Provenance (Hugging Face file-view header): author msmaje, commit 1a06556 "Update app.py" (verified)
import gradio as gr
import torch
import pandas as pd
import os
import json
import logging
import numpy as np
from datetime import datetime
from pathlib import Path
from sklearn.metrics import accuracy_score, classification_report
from sklearn.model_selection import train_test_split
from huggingface_hub import login
from transformers import (
AutoTokenizer,
BertForSequenceClassification,
TrainingArguments,
Trainer,
DataCollatorWithPadding,
EarlyStoppingCallback,
TrainerCallback
)
from datasets import Dataset, DatasetDict
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables
# Directory fine-tuned models are saved to, and the default path the UI loads from.
MODEL_PATH = "local-model"
# Single source of truth for the label set: a category's position is its class id.
CATEGORIES = ['Online-Safety', 'BroadBand', 'TV-Radio']
# Lookups derived from CATEGORIES so the three views of the label set cannot drift apart.
idx_to_category = dict(enumerate(CATEGORIES))
category_to_idx = {name: idx for idx, name in idx_to_category.items()}
# Token passed to the most recent login_to_hf() call (None until then).
TOKEN = None
# Log lines accumulated by train_model_inline and streamed to the UI.
TRAINING_LOGS = []
# Model/tokenizer pair used for prediction; populated by training or load_model().
CURRENT_MODEL = None
CURRENT_TOKENIZER = None
def load_and_prepare_local_dataset(file_path, text_column, label_column, test_size=0.2):
    """Load a local CSV file and split it into train/validation datasets.

    Args:
        file_path: Path to the CSV file.
        text_column: Name of the column holding the input text.
        label_column: Name of the column holding the labels. Numeric labels are
            used as class indices directly and must lie in
            ``[0, len(CATEGORIES))``. Non-numeric labels are mapped through
            ``category_to_idx``. Unknown names are assigned the remaining free
            indices, and a warning is logged for each.
        test_size: Fraction of rows held out for validation.

    Returns:
        ``(dataset_dict, text_column, 'label_idx')``. ``dataset_dict`` is a
        ``DatasetDict`` with ``'train'`` and ``'validation'`` splits. Each
        split contains only the text column and the integer ``'label_idx'``
        column.

    Raises:
        Exception: Any failure is re-raised as ``Exception("Error loading
            dataset: ...")``, with the original error chained as ``__cause__``.
    """
    try:
        if not os.path.exists(file_path):
            raise FileNotFoundError(f"Dataset file not found: {file_path}")

        df = pd.read_csv(file_path)

        # Verify required columns exist
        if text_column not in df.columns:
            raise ValueError(f"Text column '{text_column}' not found. Available columns: {list(df.columns)}")
        if label_column not in df.columns:
            raise ValueError(f"Label column '{label_column}' not found. Available columns: {list(df.columns)}")

        # Drop incomplete rows. .copy() makes the column assignments below act
        # on an independent frame rather than a view of the original.
        df = df.dropna(subset=[text_column, label_column]).copy()
        if df.empty:
            raise ValueError("Dataset has no rows with both a text and a label value")
        df[text_column] = df[text_column].astype(str)

        if pd.api.types.is_numeric_dtype(df[label_column]):
            # Labels are already class indices.
            df['label_idx'] = df[label_column].astype(int)
        else:
            # Text labels (object, string or categorical dtype): map names to indices.
            # astype(object) gives plain Python values, so .map() yields plain ints.
            labels = df[label_column].astype(object)
            unique_labels = labels.unique()
            if len(unique_labels) > len(CATEGORIES):
                raise ValueError(f"Too many unique labels ({len(unique_labels)}). Expected max {len(CATEGORIES)}")
            df['label_idx'] = labels.map(_build_label_mapping(unique_labels)).astype(int)

        # Verify label indices are valid
        if df['label_idx'].min() < 0 or df['label_idx'].max() >= len(CATEGORIES):
            raise ValueError(f"Label indices must be between 0 and {len(CATEGORIES)-1}")

        train_df, val_df = _split_train_validation(df, test_size)

        # preserve_index=False stops the shuffled pandas index from becoming
        # an extra '__index_level_0__' column in the datasets.
        columns = [text_column, 'label_idx']
        dataset_dict = DatasetDict({
            'train': Dataset.from_pandas(train_df[columns], preserve_index=False),
            'validation': Dataset.from_pandas(val_df[columns], preserve_index=False),
        })
        return dataset_dict, text_column, 'label_idx'
    except Exception as e:
        raise Exception(f"Error loading dataset: {str(e)}") from e


def _build_label_mapping(unique_labels):
    """Map text labels to class indices, reserving known category names first."""
    # Known names are resolved before any auto-assignment. Otherwise an unknown
    # label seen first could take the index a known name needs, and two
    # labels would end up mapped to the same class.
    mapping = {label: category_to_idx[label] for label in unique_labels if label in category_to_idx}
    free_indices = sorted(set(range(len(CATEGORIES))) - set(mapping.values()))
    unknown = [label for label in unique_labels if label not in mapping]
    if len(unknown) > len(free_indices):
        raise ValueError(f"Cannot map label '{unknown[len(free_indices)]}' to available categories")
    for label, idx in zip(unknown, free_indices):
        logger.warning("Label %r is not a known category; assigning it index %d (%s)",
                       label, idx, idx_to_category[idx])
        mapping[label] = idx
    return mapping


def _split_train_validation(df, test_size):
    """Split ``df`` into train/validation frames, stratified by ``label_idx`` when possible."""
    try:
        return train_test_split(df, test_size=test_size, random_state=42, stratify=df['label_idx'])
    except ValueError as err:
        # sklearn refuses to stratify when a class has fewer than 2 rows, or
        # when the validation split is smaller than the number of classes.
        logger.warning("Stratified split failed (%s); falling back to a random split.", err)
        return train_test_split(df, test_size=test_size, random_state=42)
def preview_dataset(uploaded_file, text_column, label_column):
    """Build a Markdown summary of an uploaded CSV for the preview panel.

    Args:
        uploaded_file: Path string (Gradio ``type="filepath"``) or an object
            with a ``.name`` attribute pointing at the file.
        text_column: Column expected to hold the complaint text.
        label_column: Column expected to hold the labels.

    Returns:
        Markdown describing the row count, the columns, a sample text and the
        label distribution. Problems are reported inline with ❌/⚠️ markers;
        unexpected errors give "❌ Error previewing dataset: ...".
    """
    try:
        if uploaded_file is None:
            return "Please upload a dataset file first."
        # Get the file path from the uploaded file
        file_path = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file
        if not os.path.exists(file_path):
            return "❌ File not found. Please try uploading again."
        if not file_path.lower().endswith('.csv'):
            return "❌ Please upload a CSV file (.csv extension required)."

        df = pd.read_csv(file_path)
        preview_info = [
            f"📊 **Dataset Preview: {os.path.basename(file_path)}**",
            f"- **Total rows:** {len(df)}",
            f"- **Columns:** {list(df.columns)}",
            "",
        ]

        if text_column not in df.columns:
            preview_info.append(f"❌ **Text column '{text_column}' not found**")
            return "\n".join(preview_info)
        preview_info.append(f"✅ **Text column '{text_column}' found**")

        # A header-only CSV has no sample to show (iloc[0] would raise).
        if df.empty:
            preview_info.append("⚠️ **The dataset has no data rows**")
            return "\n".join(preview_info)

        # Show the first non-missing text, not a literal "nan".
        non_null_texts = df[text_column].dropna()
        sample = str(non_null_texts.iloc[0])[:100] if not non_null_texts.empty else ""
        preview_info.append(f"- Sample text: {sample}...")

        if label_column in df.columns:
            preview_info.append(f"✅ **Label column '{label_column}' found**")
            preview_info.append("- **Label distribution:**")
            for label, count in df[label_column].value_counts().items():
                preview_info.append(f" - {label}: {count} ({count/len(df)*100:.1f}%)")
        else:
            preview_info.append(f"❌ **Label column '{label_column}' not found**")
        return "\n".join(preview_info)
    except Exception as e:
        return f"❌ Error previewing dataset: {str(e)}"
def login_to_hf(token):
    """Authenticate with the Hugging Face Hub and remember the token.

    The token is stored in the module-level ``TOKEN`` before the login is
    attempted, so it is recorded even when authentication fails.

    Returns:
        A status string prefixed with ✅ on success or ❌ on failure.
    """
    global TOKEN
    TOKEN = token
    try:
        login(token)
    except Exception as auth_error:
        return f"❌ Login failed: {str(auth_error)}"
    return "✅ Successfully logged in to Hugging Face"
def validate_hub_model_id(username, model_name):
    """Validate the inputs and build a ``"username/model-name"`` Hub repo id.

    The model name is sanitised as follows:

    * it is lower-cased and spaces become hyphens;
    * characters other than letters, digits, ``-`` and ``_`` are dropped;
    * runs of hyphens are collapsed;
    * leading/trailing hyphens are removed.

    The Hub rejects repo names containing ``--`` or starting/ending with ``-``.

    Returns:
        ``(hub_model_id, None)`` on success, or ``(None, error_message)``.
    """
    import re

    # Treat None and whitespace-only input as missing.
    username = (username or "").strip()
    model_name = (model_name or "").strip()
    if not username or not model_name:
        return None, "Please provide both username and model name"

    model_name = model_name.lower().replace(" ", "-")
    model_name = ''.join(c for c in model_name if c.isalnum() or c in ['-', '_'])
    model_name = re.sub(r"-{2,}", "-", model_name).strip("-")
    # Sanitising can strip everything (e.g. "!!!"), which would give "user/".
    if not model_name:
        return None, "Model name must contain at least one letter or digit"

    return f"{username}/{model_name}", None
def load_model(model_path):
    """Load a sequence-classification model and tokenizer for prediction.

    Resolution order:
      1. ``model_path`` exists on disk: load from there. A failure here is
         reported as an error, not replaced by a fallback.
      2. Otherwise ``model_path`` is treated as a Hugging Face Hub model id.
      3. If the Hub load fails, fall back to ``bert-base-uncased``. That
         checkpoint has no trained classification head, so its predictions
         are not meaningful until it is fine-tuned.

    The model and tokenizer are assigned to ``CURRENT_MODEL`` /
    ``CURRENT_TOKENIZER`` together, and only after both have loaded. A
    half-failed load therefore never pairs a tokenizer from one checkpoint
    with a model from another.

    Returns:
        A status string starting with ✅ (loaded), ⚠️ (fallback used) or
        ❌ (failure).
    """
    global CURRENT_MODEL, CURRENT_TOKENIZER

    if not model_path:
        return "❌ Failed to load model: please provide a model path or Hugging Face model ID"

    def _load_pair(source):
        # Load both objects before any global is touched.
        tokenizer = AutoTokenizer.from_pretrained(source)
        model = BertForSequenceClassification.from_pretrained(source, num_labels=len(CATEGORIES))
        return model, tokenizer

    try:
        if os.path.exists(model_path):
            CURRENT_MODEL, CURRENT_TOKENIZER = _load_pair(model_path)
            return f"✅ Model loaded from local path: {model_path}"
        try:
            CURRENT_MODEL, CURRENT_TOKENIZER = _load_pair(model_path)
            return f"✅ Model loaded from Hugging Face Hub: {model_path}"
        except Exception as hub_error:
            CURRENT_MODEL, CURRENT_TOKENIZER = _load_pair("bert-base-uncased")
            return f"⚠️ Failed to load specified model, using base BERT model instead. Error: {str(hub_error)}"
    except Exception as e:
        return f"❌ Failed to load model: {str(e)}"
def tokenize_function(examples, tokenizer, feature_column, max_length=512):
    """Tokenize one batch of examples, truncating to ``max_length`` tokens.

    No padding is applied here. Batches are padded later by the data
    collator, so each batch is only as long as its longest sequence.
    """
    texts = examples[feature_column]
    tokenizer_kwargs = dict(truncation=True, padding=False, max_length=max_length)
    return tokenizer(texts, **tokenizer_kwargs)
def compute_metrics(eval_pred):
    """Summarise evaluation predictions for the Trainer.

    Args:
        eval_pred: ``(logits, labels)`` pair. The predicted class is the
            arg-max of the logits along axis 1.

    Returns:
        Dict with accuracy, macro-averaged F1/precision/recall and
        weighted-average F1. Undefined scores count as 0 (``zero_division=0``).
    """
    logits, labels = eval_pred
    predicted = np.argmax(logits, axis=1)
    report = classification_report(labels, predicted, output_dict=True, zero_division=0)
    macro = report['macro avg']
    return {
        'accuracy': accuracy_score(labels, predicted),
        'f1_macro': macro['f1-score'],
        'f1_weighted': report['weighted avg']['f1-score'],
        'precision_macro': macro['precision'],
        'recall_macro': macro['recall'],
    }
def train_model_inline(uploaded_file, text_column, label_column, num_epochs, batch_size,
                       learning_rate, hf_token, push_to_hub, username, model_name):
    """Fine-tune ``bert-base-uncased`` on an uploaded CSV, streaming progress.

    This is a generator used as a Gradio event handler. Every ``yield`` is
    the full log text so far.

    Pipeline:
      1. Optional Hub login.
      2. Load and split the dataset.
      3. Tokenize.
      4. Set up the Trainer.
      5. Train in a worker thread, so callback messages reach the UI while
         training runs.
      6. Save to ``MODEL_PATH``.
      7. Run a final evaluation.
      8. Optionally push to the Hub.

    Args:
        uploaded_file: CSV path (or object with ``.name``) from the upload widget.
        text_column, label_column: Column names passed to
            ``load_and_prepare_local_dataset``.
        num_epochs, batch_size, learning_rate: Training hyper-parameters,
            coerced to int/int/float because UI sliders may deliver floats.
        hf_token: Optional Hugging Face token; a login runs when it is non-empty.
        push_to_hub: Whether to push the trained model to the Hub.
        username, model_name: Used to build the Hub repo id when pushing.

    Side effects:
        * Resets and appends to ``TRAINING_LOGS``.
        * Writes checkpoints, the final model/tokenizer and
          ``eval_results.json`` under ``MODEL_PATH``.
        * After a successful ``trainer.train()``, replaces ``CURRENT_MODEL``
          and ``CURRENT_TOKENIZER``.
    """
    import math
    import threading

    global TRAINING_LOGS, CURRENT_MODEL, CURRENT_TOKENIZER
    TRAINING_LOGS = []

    def render():
        # Snapshot first: during training the worker thread appends to the list.
        return "\n".join(list(TRAINING_LOGS))

    if hf_token:
        TRAINING_LOGS.append(login_to_hf(hf_token))
        yield render()

    # Validate hub model ID if pushing to hub
    hub_model_id = None
    if push_to_hub:
        hub_model_id, error = validate_hub_model_id(username, model_name)
        if error:
            TRAINING_LOGS.append(f"❌ {error}")
            yield render()
            return

    if uploaded_file is None:
        TRAINING_LOGS.append("❌ Please upload a dataset file")
        yield render()
        return
    dataset_file = uploaded_file.name if hasattr(uploaded_file, 'name') else uploaded_file

    try:
        # UI sliders may deliver floats; the DataLoader needs an integer batch size.
        num_epochs = int(num_epochs)
        batch_size = int(batch_size)
        learning_rate = float(learning_rate)

        TRAINING_LOGS.append("📊 Loading dataset from uploaded file...")
        yield render()
        dataset_dict, final_text_col, final_label_col = load_and_prepare_local_dataset(
            dataset_file, text_column, label_column
        )
        TRAINING_LOGS.append("✅ Dataset loaded successfully!")
        TRAINING_LOGS.append(f"- Train samples: {len(dataset_dict['train'])}")
        TRAINING_LOGS.append(f"- Validation samples: {len(dataset_dict['validation'])}")
        yield render()

        TRAINING_LOGS.append("🤖 Loading BERT model and tokenizer...")
        yield render()
        tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        model = BertForSequenceClassification.from_pretrained(
            "bert-base-uncased",
            num_labels=len(CATEGORIES)
        )
        TRAINING_LOGS.append("✅ Model and tokenizer loaded")
        yield render()

        TRAINING_LOGS.append("🔤 Tokenizing datasets...")
        yield render()

        def tokenize_batch(examples):
            return tokenize_function(examples, tokenizer, final_text_col, 512)

        # Drop everything except the label; tokenization adds input_ids etc.
        columns_to_remove = [col for col in dataset_dict['train'].column_names if col != final_label_col]
        tokenized_datasets = dataset_dict.map(
            tokenize_batch,
            batched=True,
            remove_columns=columns_to_remove
        )
        # Trainer expects the target column to be named 'labels'.
        tokenized_datasets = tokenized_datasets.rename_column(final_label_col, 'labels')
        TRAINING_LOGS.append("✅ Tokenization completed")
        yield render()

        output_dir = Path(MODEL_PATH)
        output_dir.mkdir(parents=True, exist_ok=True)

        # Single-device estimate: the final partial batch is still a step.
        steps_per_epoch = max(1, math.ceil(len(tokenized_datasets['train']) / batch_size))
        total_steps = steps_per_epoch * num_epochs
        # Capped at total_steps so tiny runs still evaluate/checkpoint at least once.
        eval_steps = min(total_steps, max(10, min(100, total_steps // 4)))
        # load_best_model_at_end requires save_steps to be a whole multiple of
        # eval_steps; otherwise TrainingArguments raises a ValueError.
        save_steps = eval_steps * max(1, min(500, total_steps // 2) // eval_steps)
        logging_steps = max(5, min(50, total_steps // 10))
        warmup_steps = min(500, total_steps // 10)

        TRAINING_LOGS.append("📈 Training configuration:")
        TRAINING_LOGS.append(f"- Total steps: {total_steps}")
        TRAINING_LOGS.append(f"- Eval steps: {eval_steps}")
        TRAINING_LOGS.append(f"- Warmup steps: {warmup_steps}")
        yield render()

        training_args = TrainingArguments(
            output_dir=str(output_dir),
            num_train_epochs=num_epochs,
            per_device_train_batch_size=batch_size,
            per_device_eval_batch_size=batch_size,
            warmup_steps=warmup_steps,
            weight_decay=0.01,
            learning_rate=learning_rate,
            logging_dir=str(output_dir / "logs"),
            logging_steps=logging_steps,
            eval_strategy="steps",
            eval_steps=eval_steps,
            save_steps=save_steps,
            save_total_limit=2,
            load_best_model_at_end=True,
            metric_for_best_model="eval_accuracy",
            greater_is_better=True,
            push_to_hub=push_to_hub,
            hub_model_id=hub_model_id,
            # "none" disables logging integrations. In transformers 4.x,
            # report_to=None means "all installed" (TensorBoard, W&B, ...).
            report_to="none",
            dataloader_num_workers=0,
            fp16=torch.cuda.is_available(),
            seed=42,
            remove_unused_columns=False,
        )

        # Pads each batch to its longest sequence (tokenization used padding=False).
        data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

        class ProgressCallback(TrainerCallback):
            """Append human-readable progress lines to the shared log list."""

            def __init__(self, logs_list, total_steps):
                self.logs = logs_list
                self.total_steps = total_steps

            def on_train_begin(self, args, state, control, **kwargs):
                self.logs.append("🚀 Starting training...")

            def on_step_end(self, args, state, control, **kwargs):
                if state.global_step % args.logging_steps == 0:
                    total = state.max_steps or self.total_steps
                    self.logs.append(f"Step {state.global_step}/{total}")

            def on_epoch_end(self, args, state, control, **kwargs):
                # state.epoch is a float; round it so e.g. 2.9999 reports as 3.
                self.logs.append(f"✅ Epoch {round(state.epoch or 0)} completed")

            def on_evaluate(self, args, state, control, metrics=None, **kwargs):
                # Trainer passes evaluation results as ``metrics`` (not ``logs``).
                if metrics:
                    acc = metrics.get('eval_accuracy', 0)
                    loss = metrics.get('eval_loss', 0)
                    self.logs.append(f"📊 Eval - Accuracy: {acc:.4f}, Loss: {loss:.4f}")

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=tokenized_datasets['train'],
            eval_dataset=tokenized_datasets['validation'],
            tokenizer=tokenizer,  # NOTE: newer transformers versions call this `processing_class`
            data_collator=data_collator,
            compute_metrics=compute_metrics,
            callbacks=[EarlyStoppingCallback(early_stopping_patience=3), ProgressCallback(TRAINING_LOGS, total_steps)]
        )

        # Run training in a worker thread and poll it, yielding whenever the
        # callback has added lines, so progress reaches the UI live.
        # NOTE: if the client disconnects, the daemon thread keeps training,
        # but nothing below (saving, evaluation, push) will run.
        outcome = {}

        def run_training():
            try:
                trainer.train()
            except Exception as exc:
                logger.exception("Training failed")
                outcome['error'] = exc

        worker = threading.Thread(target=run_training, name="bert-training", daemon=True)
        worker.start()
        shown = 0
        while worker.is_alive():
            worker.join(timeout=1.0)
            if len(TRAINING_LOGS) != shown:
                shown = len(TRAINING_LOGS)
                yield render()

        if 'error' in outcome:
            TRAINING_LOGS.append(f"❌ Training failed: {str(outcome['error'])}")
            yield render()
            return
        TRAINING_LOGS.append("✅ Training completed successfully!")
        yield render()

        TRAINING_LOGS.append("💾 Saving model...")
        yield render()
        trainer.save_model()
        tokenizer.save_pretrained(output_dir)
        # Make the freshly trained model the one used for predictions.
        CURRENT_MODEL = model
        CURRENT_TOKENIZER = tokenizer
        TRAINING_LOGS.append("✅ Model saved successfully!")
        yield render()

        TRAINING_LOGS.append("📊 Running final evaluation...")
        yield render()
        try:
            eval_results = trainer.evaluate()
            TRAINING_LOGS.append("📊 Final Results:")
            for key, value in eval_results.items():
                if isinstance(value, float):
                    TRAINING_LOGS.append(f"  {key}: {value:.4f}")
                else:
                    TRAINING_LOGS.append(f"  {key}: {value}")
            with open(output_dir / "eval_results.json", "w", encoding="utf-8") as f:
                json.dump(eval_results, f, indent=2)
        except Exception as e:
            TRAINING_LOGS.append(f"⚠️ Evaluation error: {str(e)}")
        yield render()

        if push_to_hub and hub_model_id:
            TRAINING_LOGS.append(f"🤗 Pushing to Hugging Face Hub: {hub_model_id}")
            yield render()
            try:
                trainer.push_to_hub()
                TRAINING_LOGS.append(f"✅ Successfully pushed to {hub_model_id}")
            except Exception as e:
                TRAINING_LOGS.append(f"❌ Push to Hub failed: {str(e)}")
            yield render()

        TRAINING_LOGS.append("\n✨ Training completed! Your model is ready to use.")
        yield render()
    except Exception as e:
        logger.exception("Training pipeline failed")
        TRAINING_LOGS.append(f"❌ Error during training: {str(e)}")
        yield render()
def predict_text(text, model_path):
    """Classify a single complaint and format the result as Markdown.

    The model is (re)loaded through ``load_model`` when none is loaded yet,
    or when ``model_path`` differs from ``MODEL_PATH``.

    Args:
        text: Complaint text to classify.
        model_path: Local directory or Hugging Face Hub model id.

    Returns:
        Markdown with the predicted category, its confidence and every class
        probability, plus a note if the text was truncated to 512 tokens. A
        help/error message is returned otherwise.
    """
    # Check the input first, so an empty box never triggers a model load.
    if text is None or not text.strip():
        return "Please enter some text to classify."

    if CURRENT_MODEL is None or model_path != MODEL_PATH:
        # NOTE(review): a non-default path is reloaded on every call, and a
        # previously loaded non-default model is reused when the path is set
        # back to MODEL_PATH. Tracking which path is loaded would fix both.
        load_result = load_model(model_path)
        if load_result.startswith("❌"):
            return load_result

    try:
        # Tokenize once without truncation, only to detect over-long input.
        was_truncated = len(CURRENT_TOKENIZER(text, truncation=False)['input_ids']) > 512
        inputs = CURRENT_TOKENIZER(text, return_tensors="pt", truncation=True, max_length=512)

        model = CURRENT_MODEL
        # After training the model may still be in train mode and on the GPU
        # (Trainer moves it there): disable dropout and match its device.
        model.eval()
        inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
        with torch.no_grad():
            probabilities = torch.nn.functional.softmax(model(**inputs).logits, dim=-1)[0]

        predicted_class_id = int(probabilities.argmax().item())
        confidence = probabilities[predicted_class_id].item()
        predicted_category = idx_to_category[predicted_class_id]

        result = [
            f"**Complaint:** {text}",
            f"\n**Predicted Category:** {predicted_category}",
            f"**Confidence:** {confidence:.4f}",
            "\n**All Class Probabilities:**",
        ]
        result.extend(
            f"- {category}: {probabilities[i].item():.4f}" for i, category in enumerate(CATEGORIES)
        )
        if was_truncated:
            result.append("\n⚠️ Note: This complaint was truncated to fit BERT's 512 token limit.")
        return "\n".join(result)
    except Exception as e:
        return f"❌ Prediction error: {str(e)}"
def predict_csv(csv_file, model_path):
    """Classify every complaint in an uploaded CSV file.

    The CSV must have a ``complaint`` column. Every row is classified, and the
    full results are written to ``prediction_results.csv`` in the working
    directory with columns complaint, predicted_category, confidence and
    truncated. The returned Markdown previews only the first 20 rows.

    Args:
        csv_file: Path string or object with a ``.name`` attribute.
        model_path: Local directory or Hub model id. It is loaded via
            ``load_model`` when no model is loaded or it differs from
            ``MODEL_PATH``.

    Returns:
        ``(summary_markdown, results_csv_path)`` on success, otherwise
        ``(error_message, None)``.
    """
    preview_limit = 20
    max_tokens = 512

    if csv_file is None:
        return "❌ Please upload a CSV file", None

    if CURRENT_MODEL is None or model_path != MODEL_PATH:
        load_result = load_model(model_path)
        if load_result.startswith("❌"):
            return load_result, None

    try:
        df = pd.read_csv(csv_file.name if hasattr(csv_file, 'name') else csv_file)
        if 'complaint' not in df.columns:
            return "❌ CSV file must have a 'complaint' column", None

        model = CURRENT_MODEL
        # Disable dropout: the model may come straight from training.
        model.eval()

        results = []
        predictions_list = []
        truncated_count = 0
        for i, complaint in enumerate(df['complaint'].astype(str)):
            was_truncated = len(CURRENT_TOKENIZER(complaint, truncation=False)['input_ids']) > max_tokens
            if was_truncated:
                truncated_count += 1

            inputs = CURRENT_TOKENIZER(complaint, return_tensors="pt", truncation=True, max_length=max_tokens)
            # Match the model's device (Trainer may have moved it to the GPU).
            inputs = {name: tensor.to(model.device) for name, tensor in inputs.items()}
            with torch.no_grad():
                probabilities = torch.nn.functional.softmax(model(**inputs).logits, dim=-1)[0]
            predicted_idx = int(probabilities.argmax().item())
            confidence = probabilities[predicted_idx].item()
            predicted_category = idx_to_category[predicted_idx]

            predictions_list.append({
                'complaint': complaint,
                'predicted_category': predicted_category,
                'confidence': confidence,
                'truncated': was_truncated
            })

            # Only the first rows are echoed in the UI; every row goes to the CSV.
            if i < preview_limit:
                truncation_mark = " ⚠️" if was_truncated else ""
                preview = complaint if len(complaint) <= 50 else complaint[:47] + "..."
                results.append(f"Complaint {i+1}{truncation_mark}: {preview}")
                results.append(f"Predicted: {predicted_category} (confidence: {confidence:.3f})\n")

        if len(df) > preview_limit:
            results.append(f"... and {len(df) - preview_limit} more (showing first {preview_limit} out of {len(df)} complaints)")
        if truncated_count > 0:
            results.append(f"\n⚠️ {truncated_count} complaints were truncated to fit BERT's 512 token limit.")

        # Explicit columns keep the header even when the CSV had no rows.
        results_df = pd.DataFrame(
            predictions_list,
            columns=['complaint', 'predicted_category', 'confidence', 'truncated'],
        )
        results_file = "prediction_results.csv"
        results_df.to_csv(results_file, index=False)
        results.append(f"\n💾 Full results saved to {results_file}")
        return "\n".join(results), results_file
    except Exception as e:
        return f"❌ CSV processing failed: {str(e)}", None
def push_to_hub_after_training(model_path, username, model_name, token):
    """Upload a locally saved model and tokenizer to the Hugging Face Hub.

    Steps run in this order: require a token, validate the repo id, log in,
    check that ``model_path`` exists, load the model then the tokenizer, and
    push the model then the tokenizer. Each failure is returned as a "❌ ..."
    message instead of being raised.

    Returns:
        A status string for the UI.
    """
    try:
        if not token:
            return "❌ Please provide a Hugging Face token"
        hub_model_id, error = validate_hub_model_id(username, model_name)
        if error:
            return f"❌ {error}"

        login(token)
        if not os.path.exists(model_path):
            return "❌ No trained model found. Please train a model first."

        try:
            artifacts = (
                BertForSequenceClassification.from_pretrained(model_path),
                AutoTokenizer.from_pretrained(model_path),
            )
        except Exception as load_error:
            return f"❌ Failed to load model: {str(load_error)}"

        try:
            for artifact in artifacts:
                artifact.push_to_hub(hub_model_id)
        except Exception as push_error:
            return f"❌ Failed to push to Hub: {str(push_error)}"
        return f"✅ Successfully pushed model to {hub_model_id}"
    except Exception as e:
        return f"❌ Error: {str(e)}"
def count_tokens(text):
    """Return a token-count status line for the live counter under the input box.

    The count includes any special tokens the tokenizer adds. If no tokenizer
    is loaded yet, ``bert-base-uncased`` is loaded lazily.

    Returns:
        One of:
        * "Enter text to see token count" for empty/blank input;
        * "Token count: N/512";
        * a ⚠️ message when N exceeds 512;
        * an ❌ error when no tokenizer can be loaded.
    """
    global CURRENT_TOKENIZER

    # Blank input would otherwise report just the special tokens
    # (e.g. "Token count: 2/512" for an empty box).
    if text is None or not text.strip():
        return "Enter text to see token count"

    if CURRENT_TOKENIZER is None:
        try:
            CURRENT_TOKENIZER = AutoTokenizer.from_pretrained("bert-base-uncased")
            logger.info("Fallback: Tokenizer loaded in count_tokens function.")
        except Exception as e:
            logger.error("Failed to load tokenizer in count_tokens fallback: %s", e)
            return "❌ Error: Tokenizer not loaded. Please load a model or check logs."

    count = len(CURRENT_TOKENIZER(text, truncation=False)['input_ids'])
    if count > 512:
        return f"⚠️ **Token count: {count}/512** - Text will be truncated for BERT"
    return f"Token count: {count}/512"
def get_available_datasets():
    """List CSV files in the current working directory with their row counts.

    Returns:
        Entries sorted by file name, such as ``"data.csv (120 rows)"``, or
        ``"data.csv (Error reading)"`` when pandas cannot parse the file.
        When no CSV files exist, returns the single placeholder entry
        ``"No CSV files found in current directory"``.
    """
    available_files = []
    for file in sorted(os.listdir(".")):
        # Case-insensitive, like preview_dataset; skip directories named *.csv.
        if not file.lower().endswith(".csv") or not os.path.isfile(file):
            continue
        try:
            row_count = len(pd.read_csv(file))
        except Exception:
            # Best effort: list unreadable files as flagged instead of failing the scan.
            available_files.append(f"{file} (Error reading)")
        else:
            available_files.append(f"{file} ({row_count} rows)")
    return available_files or ["No CSV files found in current directory"]
def display_available_datasets():
    """Render the result of ``get_available_datasets()`` as a Markdown bullet list."""
    entries = get_available_datasets()
    if not entries:
        return "No CSV files found in the current directory."
    bullets = "\n".join(f"- {entry}" for entry in entries)
    return "**Available CSV files:**\n\n" + bullets
# Initialize tokenizer on startup, so count_tokens() and prediction have a
# default tokenizer ready before any model is loaded.
if CURRENT_TOKENIZER is None:
    try:
        CURRENT_TOKENIZER = AutoTokenizer.from_pretrained("bert-base-uncased")
        print("✅ Tokenizer initialized successfully")
    except Exception as startup_error:
        print(f"⚠️ Warning: Could not initialize tokenizer globally: {startup_error}")
for banner_line in ("🚀 Launching BERT Complaint Classifier...", "📍 Available at: http://localhost:7860"):
    print(banner_line)
# The entire Gradio UI definition must be within a single block.
# Components, event listeners and app.load() below are all created inside this
# gr.Blocks context; `app` is the object launched by the __main__ guard.
with gr.Blocks(title="BERT Complaint Classifier", theme=gr.themes.Soft()) as app:
    gr.Markdown("# BERT Complaint Classifier 🗣️🤖")
    gr.Markdown("Fine-tune a BERT model or use an existing one to classify customer complaints.")
    # ---- Tab: fine-tuning (these inputs feed train_model_inline) ----
    with gr.Tab("Fine-tune Model"):
        gr.Markdown("## 🏋️ Fine-tune a New Model")
        with gr.Column(variant="panel"):
            gr.Markdown("### 🛠️ Training Configuration")
            with gr.Row():
                # type="filepath": handlers receive the uploaded file as a path string.
                uploaded_file = gr.File(label="Upload Training CSV File", type="filepath", file_types=[".csv"])
                preview_btn = gr.Button("Preview Dataset")
            preview_output = gr.Markdown("Dataset info will appear here")
            with gr.Row():
                text_column_input = gr.Textbox(label="Text Column Name", value="complaint")
                label_column_input = gr.Textbox(label="Label Column Name", value="category")
            gr.Markdown("---")
            with gr.Row():
                num_epochs_slider = gr.Slider(minimum=1, maximum=10, step=1, value=3, label="Number of Epochs")
                batch_size_slider = gr.Slider(minimum=4, maximum=32, step=4, value=8, label="Batch Size")
                learning_rate_slider = gr.Slider(minimum=1e-6, maximum=1e-4, step=1e-6, value=2e-5, label="Learning Rate")
            gr.Markdown("---")
            gr.Markdown("### ☁️ Hugging Face Hub (Optional)")
            with gr.Row():
                push_to_hub_checkbox = gr.Checkbox(label="Push to Hugging Face Hub")
                hf_token_input = gr.Textbox(label="Hugging Face Token", type="password")
            with gr.Row():
                hf_username_input = gr.Textbox(label="Hugging Face Username")
                hf_model_name_input = gr.Textbox(label="Model Name (for Hub)", value="bert-complaint-classifier")
            train_btn = gr.Button("🚀 Start Training", variant="primary")
            gr.Markdown("---")
            training_log_output = gr.Textbox(label="Training Logs", lines=20, max_lines=20, interactive=False)
    # ---- Tab: prediction (single-text and CSV sub-tabs) ----
    with gr.Tab("Predict"):
        gr.Markdown("## 🔮 Make Predictions")
        gr.Markdown("Choose a method to classify complaints.")
        with gr.Tab("Predict Single Text"):
            with gr.Column(variant="panel"):
                gr.Markdown("### Classify a Single Complaint")
                # Passed to load_model / predict_text: a local directory or a Hub model id.
                model_path_input = gr.Textbox(label="Model Path", value=MODEL_PATH, placeholder="Enter model path or HuggingFace model ID")
                load_model_btn = gr.Button("Load Model")
                model_status = gr.Textbox(label="Model Status", interactive=False)
                gr.Markdown("---")
                text_input = gr.Textbox(
                    label="Enter complaint text",
                    lines=3,
                    placeholder="Type your complaint here..."
                )
                # Updated by count_tokens whenever text_input changes (see wiring below).
                token_counter = gr.Textbox(label="Token Count", interactive=False, value="Enter text to see token count")
                predict_btn = gr.Button("🔮 Predict Category", variant="primary")
                prediction_output = gr.Markdown("Prediction results will appear here")
        with gr.Tab("Predict CSV File"):
            with gr.Column(variant="panel"):
                gr.Markdown("### Classify Multiple Complaints from CSV")
                gr.Markdown("Upload a CSV file with a 'complaint' column to classify multiple complaints at once.")
                csv_model_path = gr.Textbox(label="Model Path", value=MODEL_PATH, placeholder="Enter model path or HuggingFace model ID")
                csv_load_btn = gr.Button("Load Model")
                csv_model_status = gr.Textbox(label="Model Status", interactive=False)
                gr.Markdown("---")
                csv_file_input = gr.File(label="Upload CSV File", type="filepath", file_types=[".csv"])
                csv_predict_btn = gr.Button("🔮 Predict All", variant="primary")
                csv_prediction_output = gr.Markdown("CSV prediction results will appear here")
                # Receives the results-CSV path returned by predict_csv.
                csv_download = gr.File(label="Download Results", interactive=False)
    # ---- Tab: push an already-trained local model to the Hub ----
    with gr.Tab("Push to Hub"):
        gr.Markdown("## 🤗 Push Trained Model to Hugging Face Hub")
        gr.Markdown("Upload your locally trained model to the Hugging Face Hub for sharing.")
        with gr.Column(variant="panel"):
            hub_model_path = gr.Textbox(label="Local Model Path", value=MODEL_PATH)
            hub_username = gr.Textbox(label="Hugging Face Username")
            hub_model_name = gr.Textbox(label="Model Name", value="bert-complaint-classifier")
            hub_token = gr.Textbox(label="Hugging Face Token", type="password")
            push_hub_btn = gr.Button("🚀 Push to Hub", variant="primary")
            push_hub_output = gr.Markdown("Push results will appear here")
    # ---- Tab: label set, local CSV scan and usage tips ----
    with gr.Tab("Dataset Info"):
        gr.Markdown("## 📊 Dataset Information")
        gr.Markdown("View information about available datasets and model categories.")
        with gr.Column(variant="panel"):
            gr.Markdown("### 🎯 Model Categories")
            # Built once at UI construction time from idx_to_category.
            categories_info = gr.Markdown(f"**Available Categories:**\n\n" + "\n".join([f"- **{cat}** (index: {idx})" for idx, cat in idx_to_category.items()]))
            gr.Markdown("---")
            gr.Markdown("### 📁 Available Datasets")
            datasets_btn = gr.Button("🔍 Scan for CSV Files")
            datasets_info = gr.Markdown("Click 'Scan for CSV Files' to see available datasets")
            gr.Markdown("---")
            gr.Markdown("### 💡 Tips")
            # NOTE(review): the "Enable early stopping" tip below has no matching
            # control in this UI; train_model_inline always adds
            # EarlyStoppingCallback(early_stopping_patience=3).
            gr.Markdown("""
**Dataset Format:**
- CSV file with at least two columns
- One column for text (complaints)
- One column for labels/categories
- Labels can be text (will be auto-mapped) or numeric indices (0, 1, 2)
**Training Tips:**
- Start with 3 epochs and adjust based on results
- Use batch size 8-16 for most datasets
- Learning rate 2e-5 works well for BERT fine-tuning
- Enable early stopping to prevent overfitting
**Token Limits:**
- BERT has a 512 token limit
- Long texts will be automatically truncated
- Monitor the token counter when entering text
""")
    # Connect functions to UI components
    preview_btn.click(
        preview_dataset,
        inputs=[uploaded_file, text_column_input, label_column_input],
        outputs=preview_output
    )
    # train_model_inline is a generator: each yielded string (the full log so far)
    # replaces the contents of training_log_output.
    train_btn.click(
        train_model_inline,
        inputs=[
            uploaded_file,
            text_column_input,
            label_column_input,
            num_epochs_slider,
            batch_size_slider,
            learning_rate_slider,
            hf_token_input,
            push_to_hub_checkbox,
            hf_username_input,
            hf_model_name_input,
        ],
        outputs=training_log_output,
    )
    load_model_btn.click(
        load_model,
        inputs=model_path_input,
        outputs=model_status
    )
    predict_btn.click(
        predict_text,
        inputs=[text_input, model_path_input],
        outputs=prediction_output
    )
    # Live token count as the complaint text changes.
    text_input.change(
        count_tokens,
        inputs=text_input,
        outputs=token_counter
    )
    csv_load_btn.click(
        load_model,
        inputs=csv_model_path,
        outputs=csv_model_status
    )
    # predict_csv returns (markdown_summary, results_file_path_or_None).
    csv_predict_btn.click(
        predict_csv,
        inputs=[csv_file_input, csv_model_path],
        outputs=[csv_prediction_output, csv_download]
    )
    push_hub_btn.click(
        push_to_hub_after_training,
        inputs=[hub_model_path, hub_username, hub_model_name, hub_token],
        outputs=push_hub_output
    )
    datasets_btn.click(
        display_available_datasets,
        outputs=datasets_info
    )
    # Run a check for available datasets on app load
    app.load(display_available_datasets, outputs=datasets_info)
# Launch the Gradio app only when run as a script (not on import).
if __name__ == "__main__":
    # 0.0.0.0 binds every network interface, so the server is reachable from
    # other hosts. No public share link; handler errors are shown in the browser.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "show_error": True,
    }
    app.launch(**launch_options)