contexto-api / src /fine_tuner.py
Dev-ks04
feat: Contexto FastAPI backend - intent-aware summarization engine
39028c9
"""
Fine-tuning module for domain-specific models
"""
import logging
from typing import List, Dict, Optional
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments
import json
logger = logging.getLogger(__name__)
class SummarizationDataset(Dataset):
"""Dataset for summarization fine-tuning."""
def __init__(self, documents: List[str], summaries: List[str], tokenizer, max_len: int = 512):
"""
Initialize dataset.
Args:
documents: List of documents
summaries: List of corresponding summaries
tokenizer: HuggingFace tokenizer
max_len: Maximum token length
"""
self.documents = documents
self.summaries = summaries
self.tokenizer = tokenizer
self.max_len = max_len
def __len__(self):
return len(self.documents)
def __getitem__(self, idx):
document = str(self.documents[idx])
summary = str(self.summaries[idx])
# Tokenize input
inputs = self.tokenizer(
document,
max_length=self.max_len,
truncation=True,
padding='max_length',
return_tensors='pt'
)
# Tokenize target
targets = self.tokenizer(
summary,
max_length=256,
truncation=True,
padding='max_length',
return_tensors='pt'
)
return {
'input_ids': inputs['input_ids'].squeeze(),
'attention_mask': inputs['attention_mask'].squeeze(),
'labels': targets['input_ids'].squeeze(),
}
class FineTuner:
"""Fine-tune summarization model on custom data."""
def __init__(self, model, tokenizer, device: str = 'cuda'):
"""
Initialize fine-tuner.
Args:
model: Pre-trained model
tokenizer: Tokenizer
device: Device to use
"""
self.model = model
self.tokenizer = tokenizer
self.device = device
def prepare_data(
self,
data_file: str
) -> tuple:
"""
Prepare data from JSON file.
Format: [{"document": "...", "summary": "..."}, ...]
Args:
data_file: Path to JSON file
Returns:
Tuple of (documents, summaries)
"""
with open(data_file, 'r', encoding='utf-8') as f:
data = json.load(f)
documents = [item['document'] for item in data]
summaries = [item['summary'] for item in data]
logger.info(f"Loaded {len(documents)} training examples")
return documents, summaries
def fine_tune(
self,
documents: List[str],
summaries: List[str],
output_dir: str = 'models/fine_tuned',
num_epochs: int = 3,
batch_size: int = 8,
learning_rate: float = 2e-5
) -> str:
"""
Fine-tune model on custom data.
Args:
documents: Training documents
summaries: Training summaries
output_dir: Output directory
num_epochs: Number of training epochs
batch_size: Batch size
learning_rate: Learning rate
Returns:
Path to fine-tuned model
"""
# Create dataset
train_dataset = SummarizationDataset(
documents,
summaries,
self.tokenizer
)
# Training arguments
training_args = Seq2SeqTrainingArguments(
output_dir=output_dir,
num_train_epochs=num_epochs,
per_device_train_batch_size=batch_size,
learning_rate=learning_rate,
save_steps=len(train_dataset) // batch_size,
save_total_limit=2,
logging_steps=10,
)
# Trainer
trainer = Seq2SeqTrainer(
model=self.model,
args=training_args,
train_dataset=train_dataset,
tokenizer=self.tokenizer,
)
# Train
logger.info("Starting fine-tuning...")
trainer.train()
# Save
trainer.save_model(output_dir)
logger.info(f"Fine-tuned model saved to {output_dir}")
return output_dir
def quick_fine_tune(
self,
data_file: str,
output_dir: str = 'models/fine_tuned'
) -> str:
"""
Quick fine-tuning from JSON file.
Args:
data_file: Path to training data JSON
output_dir: Output directory
Returns:
Path to fine-tuned model
"""
documents, summaries = self.prepare_data(data_file)
return self.fine_tune(
documents,
summaries,
output_dir=output_dir,
num_epochs=2,
batch_size=4
)