Spaces:
Running
Running
| """ | |
| Fine-tuning module for domain-specific models | |
| """ | |
| import logging | |
| from typing import List, Dict, Optional | |
| import torch | |
| from torch.utils.data import Dataset, DataLoader | |
| from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments | |
| import json | |
| logger = logging.getLogger(__name__) | |
| class SummarizationDataset(Dataset): | |
| """Dataset for summarization fine-tuning.""" | |
| def __init__(self, documents: List[str], summaries: List[str], tokenizer, max_len: int = 512): | |
| """ | |
| Initialize dataset. | |
| Args: | |
| documents: List of documents | |
| summaries: List of corresponding summaries | |
| tokenizer: HuggingFace tokenizer | |
| max_len: Maximum token length | |
| """ | |
| self.documents = documents | |
| self.summaries = summaries | |
| self.tokenizer = tokenizer | |
| self.max_len = max_len | |
| def __len__(self): | |
| return len(self.documents) | |
| def __getitem__(self, idx): | |
| document = str(self.documents[idx]) | |
| summary = str(self.summaries[idx]) | |
| # Tokenize input | |
| inputs = self.tokenizer( | |
| document, | |
| max_length=self.max_len, | |
| truncation=True, | |
| padding='max_length', | |
| return_tensors='pt' | |
| ) | |
| # Tokenize target | |
| targets = self.tokenizer( | |
| summary, | |
| max_length=256, | |
| truncation=True, | |
| padding='max_length', | |
| return_tensors='pt' | |
| ) | |
| return { | |
| 'input_ids': inputs['input_ids'].squeeze(), | |
| 'attention_mask': inputs['attention_mask'].squeeze(), | |
| 'labels': targets['input_ids'].squeeze(), | |
| } | |
| class FineTuner: | |
| """Fine-tune summarization model on custom data.""" | |
| def __init__(self, model, tokenizer, device: str = 'cuda'): | |
| """ | |
| Initialize fine-tuner. | |
| Args: | |
| model: Pre-trained model | |
| tokenizer: Tokenizer | |
| device: Device to use | |
| """ | |
| self.model = model | |
| self.tokenizer = tokenizer | |
| self.device = device | |
| def prepare_data( | |
| self, | |
| data_file: str | |
| ) -> tuple: | |
| """ | |
| Prepare data from JSON file. | |
| Format: [{"document": "...", "summary": "..."}, ...] | |
| Args: | |
| data_file: Path to JSON file | |
| Returns: | |
| Tuple of (documents, summaries) | |
| """ | |
| with open(data_file, 'r', encoding='utf-8') as f: | |
| data = json.load(f) | |
| documents = [item['document'] for item in data] | |
| summaries = [item['summary'] for item in data] | |
| logger.info(f"Loaded {len(documents)} training examples") | |
| return documents, summaries | |
| def fine_tune( | |
| self, | |
| documents: List[str], | |
| summaries: List[str], | |
| output_dir: str = 'models/fine_tuned', | |
| num_epochs: int = 3, | |
| batch_size: int = 8, | |
| learning_rate: float = 2e-5 | |
| ) -> str: | |
| """ | |
| Fine-tune model on custom data. | |
| Args: | |
| documents: Training documents | |
| summaries: Training summaries | |
| output_dir: Output directory | |
| num_epochs: Number of training epochs | |
| batch_size: Batch size | |
| learning_rate: Learning rate | |
| Returns: | |
| Path to fine-tuned model | |
| """ | |
| # Create dataset | |
| train_dataset = SummarizationDataset( | |
| documents, | |
| summaries, | |
| self.tokenizer | |
| ) | |
| # Training arguments | |
| training_args = Seq2SeqTrainingArguments( | |
| output_dir=output_dir, | |
| num_train_epochs=num_epochs, | |
| per_device_train_batch_size=batch_size, | |
| learning_rate=learning_rate, | |
| save_steps=len(train_dataset) // batch_size, | |
| save_total_limit=2, | |
| logging_steps=10, | |
| ) | |
| # Trainer | |
| trainer = Seq2SeqTrainer( | |
| model=self.model, | |
| args=training_args, | |
| train_dataset=train_dataset, | |
| tokenizer=self.tokenizer, | |
| ) | |
| # Train | |
| logger.info("Starting fine-tuning...") | |
| trainer.train() | |
| # Save | |
| trainer.save_model(output_dir) | |
| logger.info(f"Fine-tuned model saved to {output_dir}") | |
| return output_dir | |
| def quick_fine_tune( | |
| self, | |
| data_file: str, | |
| output_dir: str = 'models/fine_tuned' | |
| ) -> str: | |
| """ | |
| Quick fine-tuning from JSON file. | |
| Args: | |
| data_file: Path to training data JSON | |
| output_dir: Output directory | |
| Returns: | |
| Path to fine-tuned model | |
| """ | |
| documents, summaries = self.prepare_data(data_file) | |
| return self.fine_tune( | |
| documents, | |
| summaries, | |
| output_dir=output_dir, | |
| num_epochs=2, | |
| batch_size=4 | |
| ) | |