| """ | |
| BART Fine-Tuning Engine with LoRA | |
| This module provides fine-tuning capabilities for the BART zero-shot classifier | |
| using Parameter-Efficient Fine-Tuning (PEFT) with LoRA (Low-Rank Adaptation). | |
| """ | |
| import os | |
| import json | |
| import numpy as np | |
| from datetime import datetime | |
| from typing import List, Dict, Tuple, Optional | |
| import warnings | |
| import torch | |
| from transformers import ( | |
| AutoTokenizer, | |
| AutoModelForSequenceClassification, | |
| Trainer, | |
| TrainingArguments, | |
| EarlyStoppingCallback, | |
| TrainerCallback, | |
| TrainerState, | |
| TrainerControl | |
| ) | |
| from peft import LoraConfig, get_peft_model, TaskType | |
| from datasets import Dataset | |
| from sklearn.model_selection import train_test_split | |
| from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix | |
| import logging | |
| # Suppress expected warnings | |
| warnings.filterwarnings('ignore', message='.*num_labels.*incompatible.*') | |
| warnings.filterwarnings('ignore', message='.*missing keys.*checkpoint.*') | |
| logger = logging.getLogger(__name__) | |
class ProgressCallback(TrainerCallback):
    """Callback to track training progress and update the database"""

    def __init__(self, run_id: int):
        self.run_id = run_id

    def on_epoch_begin(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the beginning of an epoch"""
        try:
            from app import create_app, db
            from app.models.models import FineTuningRun
            app = create_app()
            with app.app_context():
                run = FineTuningRun.query.get(self.run_id)
                if run:
                    run.current_epoch = int(state.epoch) if state.epoch else 0
                    run.progress_message = f"Starting epoch {run.current_epoch + 1}/{run.total_epochs}"
                    db.session.commit()
        except Exception as e:
            logger.error(f"Error updating progress on epoch begin: {e}")

    def on_step_end(self, args, state: TrainerState, control: TrainerControl, **kwargs):
        """Called at the end of a training step"""
        try:
            # Update every 5 steps to avoid too many DB writes
            if state.global_step % 5 == 0:
                from app import create_app, db
                from app.models.models import FineTuningRun
                app = create_app()
                with app.app_context():
                    run = FineTuningRun.query.get(self.run_id)
                    if run:
                        run.current_step = state.global_step
                        run.current_epoch = int(state.epoch) if state.epoch else 0
                        # Get current loss if available
                        if state.log_history:
                            last_log = state.log_history[-1]
                            if 'loss' in last_log:
                                run.current_loss = last_log['loss']
                        # Calculate progress percentage
                        if run.total_steps and run.total_steps > 0:
                            progress_pct = (state.global_step / run.total_steps) * 100
                            run.progress_message = (
                                f"Epoch {run.current_epoch + 1}/{run.total_epochs} - "
                                f"Step {state.global_step}/{run.total_steps} ({progress_pct:.1f}%)"
                            )
                            if run.current_loss:
                                run.progress_message += f" - Loss: {run.current_loss:.4f}"
                        db.session.commit()
        except Exception as e:
            logger.error(f"Error updating progress on step end: {e}")

    def on_log(self, args, state: TrainerState, control: TrainerControl, logs=None, **kwargs):
        """Called when logging occurs"""
        try:
            from app import create_app, db
            from app.models.models import FineTuningRun
            app = create_app()
            with app.app_context():
                run = FineTuningRun.query.get(self.run_id)
                if run and logs:
                    if 'loss' in logs:
                        run.current_loss = logs['loss']
                        db.session.commit()
        except Exception as e:
            logger.error(f"Error updating progress on log: {e}")
class BARTFineTuner:
    """Fine-tune the BART model for multi-class classification, using either LoRA adapters or head-only training"""

    def __init__(self, base_model_name: str = "facebook/bart-large-mnli"):
        """
        Initialize the fine-tuner.

        Args:
            base_model_name: Hugging Face model ID for the base model
        """
        self.base_model_name = base_model_name
        self.tokenizer = None
        self.model = None
        self.categories = ['Vision', 'Problem', 'Objectives', 'Directives', 'Values', 'Actions']
        self.label2id = {label: idx for idx, label in enumerate(self.categories)}
        self.id2label = {idx: label for idx, label in enumerate(self.categories)}
    def prepare_dataset(
        self,
        training_examples: List[Dict],
        train_split: float = 0.7,
        val_split: float = 0.15,
        test_split: float = 0.15,
        random_state: int = 42
    ) -> Tuple[Dataset, Dataset, Dataset]:
        """
        Prepare training, validation, and test datasets from training examples.

        Args:
            training_examples: List of dicts with 'message' and 'corrected_category'
            train_split: Proportion for training set
            val_split: Proportion for validation set
            test_split: Proportion for test set
            random_state: Random seed for reproducibility

        Returns:
            Tuple of (train_dataset, val_dataset, test_dataset)
        """
        logger.info(f"Preparing dataset from {len(training_examples)} examples")

        # Extract texts and labels
        texts = [ex['message'] for ex in training_examples]
        labels = [self.label2id[ex['corrected_category']] for ex in training_examples]

        # Validate splits
        assert abs(train_split + val_split + test_split - 1.0) < 0.01, "Splits must sum to 1.0"

        num_classes = len(self.categories)
        total_examples = len(texts)

        # Stratification needs at least num_classes examples in each split
        min_test_size = int(total_examples * test_split)
        min_val_size = int(total_examples * val_split)
        use_stratify = (min_test_size >= num_classes and min_val_size >= num_classes)
        if not use_stratify:
            logger.warning(f"Dataset too small ({total_examples} examples) for stratified split. "
                           f"Using random split instead.")

        # First split: separate test set
        train_val_texts, test_texts, train_val_labels, test_labels = train_test_split(
            texts, labels,
            test_size=test_split,
            random_state=random_state,
            stratify=labels if use_stratify else None
        )

        # Second split: separate train and validation
        val_size_adjusted = val_split / (train_split + val_split)
        train_texts, val_texts, train_labels, val_labels = train_test_split(
            train_val_texts, train_val_labels,
            test_size=val_size_adjusted,
            random_state=random_state,
            stratify=train_val_labels if use_stratify else None
        )

        # Tokenize datasets
        train_dataset = self._create_dataset(train_texts, train_labels)
        val_dataset = self._create_dataset(val_texts, val_labels)
        test_dataset = self._create_dataset(test_texts, test_labels)

        logger.info(f"Dataset prepared: train={len(train_dataset)}, "
                    f"val={len(val_dataset)}, test={len(test_dataset)}")
        return train_dataset, val_dataset, test_dataset
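
    # Example input for prepare_dataset (a minimal sketch with hypothetical
    # data; the only required keys are 'message' and 'corrected_category', as
    # documented above, and 'tuner' is assumed to be a BARTFineTuner instance):
    #
    #     examples = [
    #         {'message': 'Onboarding takes too long for new hires', 'corrected_category': 'Problem'},
    #         {'message': 'Cut average onboarding time to two weeks', 'corrected_category': 'Objectives'},
    #         # ... one dict per labelled message, ideally covering all six categories
    #     ]
    #     train_ds, val_ds, test_ds = tuner.prepare_dataset(examples)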
    def _create_dataset(self, texts: List[str], labels: List[int]) -> Dataset:
        """Create a Hugging Face Dataset with tokenized texts"""
        # Load tokenizer if not already loaded
        if self.tokenizer is None:
            self.tokenizer = AutoTokenizer.from_pretrained(self.base_model_name)

        # Tokenize
        encodings = self.tokenizer(
            texts,
            truncation=True,
            padding='max_length',
            max_length=128,
            return_tensors='pt'
        )

        # Create dataset
        dataset_dict = {
            'input_ids': encodings['input_ids'],
            'attention_mask': encodings['attention_mask'],
            'labels': torch.tensor(labels)
        }
        return Dataset.from_dict(dataset_dict)
    def setup_head_only_model(self) -> None:
        """
        Set up the BART model for classification-head-only fine-tuning.

        Freezes the base model and trains only the classification head.
        Better suited to small datasets (<100 examples).
        """
        logger.info("Setting up BART model for head-only training")

        # Load base model
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.base_model_name,
            num_labels=len(self.categories),
            id2label=self.id2label,
            label2id=self.label2id,
            problem_type="single_label_classification",
            ignore_mismatched_sizes=True
        )

        # Freeze all parameters except the classification head
        for name, param in self.model.named_parameters():
            if 'classification_head' in name or 'classifier' in name:
                param.requires_grad = True
            else:
                param.requires_grad = False

        # Count trainable parameters
        trainable = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        total = sum(p.numel() for p in self.model.parameters())
        logger.info(f"Trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")
    def setup_lora_model(self, lora_config: Dict) -> None:
        """
        Set up the BART model with LoRA adapters.

        Args:
            lora_config: Dict with LoRA hyperparameters:
                - r: Rank of the update matrices (default: 16)
                - lora_alpha: Scaling factor (default: 32)
                - lora_dropout: Dropout probability (default: 0.1)
                - target_modules: Modules to apply LoRA to
        """
        logger.info("Setting up BART model with LoRA")

        # Load base model for sequence classification
        self.model = AutoModelForSequenceClassification.from_pretrained(
            self.base_model_name,
            num_labels=len(self.categories),
            id2label=self.id2label,
            label2id=self.label2id,
            problem_type="single_label_classification",
            ignore_mismatched_sizes=True  # BART-MNLI has 3 classes; we need 6
        )

        # Configure LoRA
        peft_config = LoraConfig(
            task_type=TaskType.SEQ_CLS,
            inference_mode=False,
            r=lora_config.get('r', 16),
            lora_alpha=lora_config.get('lora_alpha', 32),
            lora_dropout=lora_config.get('lora_dropout', 0.1),
            target_modules=lora_config.get('target_modules', ['q_proj', 'v_proj']),
            bias="none"
        )

        # Apply PEFT
        self.model = get_peft_model(self.model, peft_config)
        self.model.print_trainable_parameters()
        logger.info("LoRA model ready")
    def train(
        self,
        train_dataset: Dataset,
        val_dataset: Dataset,
        output_dir: str,
        training_config: Dict,
        run_id: Optional[int] = None
    ) -> Dict:
        """
        Train the model with LoRA.

        Args:
            train_dataset: Training dataset
            val_dataset: Validation dataset
            output_dir: Directory to save model checkpoints
            training_config: Training hyperparameters:
                - learning_rate: Learning rate (default: 3e-4)
                - num_epochs: Number of training epochs (default: 3)
                - batch_size: Per-device batch size (default: 8)
                - warmup_ratio: Warmup ratio (default: 0.1)
            run_id: Optional FineTuningRun ID for progress tracking

        Returns:
            Dict with training metrics
        """
        logger.info("Starting training")

        # Create output directory
        os.makedirs(output_dir, exist_ok=True)

        # Force CPU training to avoid cuDNN compatibility issues on WSL2
        use_cuda = False
        logger.info("Using CPU for training (CUDA disabled to avoid compatibility issues)")

        # Training arguments
        training_args = TrainingArguments(
            output_dir=output_dir,
            num_train_epochs=training_config.get('num_epochs', 3),
            per_device_train_batch_size=training_config.get('batch_size', 8),
            per_device_eval_batch_size=training_config.get('batch_size', 8),
            learning_rate=training_config.get('learning_rate', 3e-4),
            warmup_ratio=training_config.get('warmup_ratio', 0.1),
            weight_decay=0.01,
            logging_dir=f'{output_dir}/logs',
            logging_steps=10,
            eval_strategy="epoch",
            save_strategy="epoch",
            load_best_model_at_end=True,
            metric_for_best_model="eval_loss",
            greater_is_better=False,
            save_total_limit=2,
            report_to="none",  # Disable wandb, tensorboard
            use_cpu=not use_cuda,  # CPU while use_cuda is forced off above
            fp16=use_cuda,  # Mixed precision only makes sense with working CUDA
        )

        # Calculate total steps for progress tracking
        num_epochs = training_config.get('num_epochs', 3)
        batch_size = training_config.get('batch_size', 8)
        total_steps = (len(train_dataset) // batch_size) * num_epochs

        # Update the run with total steps and epochs if a run_id was provided
        if run_id:
            try:
                from app import create_app, db
                from app.models.models import FineTuningRun
                app = create_app()
                with app.app_context():
                    run = FineTuningRun.query.get(run_id)
                    if run:
                        run.total_epochs = num_epochs
                        run.total_steps = total_steps
                        db.session.commit()
            except Exception as e:
                logger.error(f"Error updating run totals: {e}")

        # Prepare callbacks
        callbacks = [EarlyStoppingCallback(early_stopping_patience=2)]
        if run_id:
            callbacks.append(ProgressCallback(run_id))

        # Trainer
        trainer = Trainer(
            model=self.model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=val_dataset,
            tokenizer=self.tokenizer,
            callbacks=callbacks
        )

        # Train
        train_result = trainer.train()

        # Save model and tokenizer
        trainer.save_model(output_dir)
        self.tokenizer.save_pretrained(output_dir)

        # Extract metrics
        metrics = {
            'train_loss': train_result.metrics.get('train_loss'),
            'train_runtime': train_result.metrics.get('train_runtime'),
            'train_samples_per_second': train_result.metrics.get('train_samples_per_second'),
        }

        # Validation metrics
        eval_metrics = trainer.evaluate()
        metrics['val_loss'] = eval_metrics.get('eval_loss')

        logger.info(f"Training complete: {metrics}")
        return metrics
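
    # Example call (a minimal sketch; the config values mirror the defaults
    # documented above, and output_dir is a hypothetical path):
    #
    #     metrics = tuner.train(
    #         train_ds, val_ds,
    #         output_dir='models/finetuned-run-1',
    #         training_config={'learning_rate': 3e-4, 'num_epochs': 3,
    #                          'batch_size': 8, 'warmup_ratio': 0.1},
    #     )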
    def evaluate(
        self,
        test_dataset: Dataset,
        model_path: Optional[str] = None
    ) -> Dict:
        """
        Evaluate the model on the test set.

        Args:
            test_dataset: Test dataset
            model_path: Path to a saved model (if None, uses the current model)

        Returns:
            Dict with evaluation metrics
        """
        logger.info("Evaluating model")

        # Load model if a path was provided
        if model_path and os.path.exists(model_path):
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                num_labels=len(self.categories),
                ignore_mismatched_sizes=True
            )

        # Make predictions
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.model.to(device)
        self.model.eval()

        predictions = []
        true_labels = []
        with torch.no_grad():
            for i in range(len(test_dataset)):
                # Datasets may return tensors or plain lists depending on format
                item = test_dataset[i]
                input_ids = torch.tensor(item['input_ids']) if isinstance(item['input_ids'], list) else item['input_ids']
                attention_mask = torch.tensor(item['attention_mask']) if isinstance(item['attention_mask'], list) else item['attention_mask']
                label = torch.tensor(item['labels']) if isinstance(item['labels'], list) else item['labels']

                # Create a single-example batch
                batch = {
                    'input_ids': input_ids.unsqueeze(0).to(device),
                    'attention_mask': attention_mask.unsqueeze(0).to(device)
                }
                outputs = self.model(**batch)
                pred = torch.argmax(outputs.logits, dim=1).item()
                predictions.append(pred)
                true_labels.append(label.item() if isinstance(label, torch.Tensor) else label)

        # Aggregate metrics
        accuracy = accuracy_score(true_labels, predictions)
        precision, recall, f1, _ = precision_recall_fscore_support(
            true_labels, predictions, average='macro', zero_division=0
        )

        # Per-category metrics
        precision_per_cat, recall_per_cat, f1_per_cat, _ = precision_recall_fscore_support(
            true_labels, predictions, average=None, zero_division=0,
            labels=list(range(len(self.categories)))
        )
        per_category_metrics = {}
        for idx, category in enumerate(self.categories):
            per_category_metrics[category] = {
                'precision': float(precision_per_cat[idx]),
                'recall': float(recall_per_cat[idx]),
                'f1': float(f1_per_cat[idx])
            }

        # Confusion matrix
        cm = confusion_matrix(true_labels, predictions, labels=list(range(len(self.categories))))

        metrics = {
            'test_accuracy': float(accuracy),
            'test_precision_macro': float(precision),
            'test_recall_macro': float(recall),
            'test_f1_macro': float(f1),
            'per_category': per_category_metrics,
            'confusion_matrix': cm.tolist()
        }
        logger.info(f"Evaluation complete: accuracy={accuracy:.3f}, f1={f1:.3f}")
        return metrics
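
    # Shape of the dict returned by evaluate() (illustrative values only;
    # the keys are exactly those built above):
    #
    #     {
    #         'test_accuracy': 0.78,
    #         'test_precision_macro': 0.75,
    #         'test_recall_macro': 0.74,
    #         'test_f1_macro': 0.74,
    #         'per_category': {'Vision': {'precision': ..., 'recall': ..., 'f1': ...}, ...},
    #         'confusion_matrix': [[...], ...],  # 6x6; rows = true labels, cols = predictions
    #     }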
    def compare_to_baseline(
        self,
        test_texts: List[str],
        test_labels: List[str]
    ) -> float:
        """
        Score the baseline zero-shot classifier on the test set, for comparison
        against the fine-tuned model's evaluation metrics.

        Args:
            test_texts: Test text samples
            test_labels: True category labels

        Returns:
            Baseline zero-shot accuracy on the test set
        """
        logger.info("Comparing to baseline model")

        # Load baseline zero-shot classifier
        from transformers import pipeline
        baseline_classifier = pipeline(
            "zero-shot-classification",
            model=self.base_model_name,
            device=0 if torch.cuda.is_available() else -1
        )

        # Candidate labels pair each category name with a short description
        candidate_labels = [
            f"{cat}: {desc}"
            for cat, desc in zip(
                self.categories,
                [
                    "future aspirations, desired outcomes, what success looks like",
                    "current issues, frustrations, causes of problems",
                    "specific goals to achieve",
                    "restrictions or requirements for solution design",
                    "principles or restrictions for setting objectives",
                    "concrete steps, interventions, or activities to implement"
                ]
            )
        ]

        # Get baseline predictions
        baseline_preds = []
        for text in test_texts:
            result = baseline_classifier(text, candidate_labels, multi_label=False)
            top_label = result['labels'][0].split(':')[0]
            baseline_preds.append(top_label)

        baseline_accuracy = accuracy_score(test_labels, baseline_preds)

        # The fine-tuned model's metrics come from evaluate(); compare the two
        # accuracies in the caller rather than recomputing them here.
        logger.info(f"Baseline accuracy: {baseline_accuracy:.3f}")
        return baseline_accuracy
    def save_metrics(self, metrics: Dict, output_path: str) -> None:
        """Save metrics to a JSON file"""
        with open(output_path, 'w') as f:
            json.dump(metrics, f, indent=2)
        logger.info(f"Metrics saved to {output_path}")
    def export_model(self, model_path: str, export_path: str) -> None:
        """
        Export the model for deployment or backup.

        Args:
            model_path: Path to the saved model
            export_path: Path to the export directory
        """
        import shutil

        logger.info(f"Exporting model from {model_path} to {export_path}")
        os.makedirs(export_path, exist_ok=True)

        # Copy model files
        for file in os.listdir(model_path):
            src = os.path.join(model_path, file)
            dst = os.path.join(export_path, file)
            if os.path.isfile(src):
                shutil.copy2(src, dst)
        logger.info("Model exported successfully")