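# NOTE: the next four lines are page-header residue from the hosting site, kept as comments.
# Nexuss-Transformer / multi_task.py
# Nexuss0781's picture
# Phase 4: Add multi-task learning, P-Tuning, SI/LwF continual learning, automated tests, deployment templates
# f42d9a1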
"""
Multi-Task Learning Implementation for NTF
Supports task-specific heads for different fine-tuning objectives
"""
import torch
import torch.nn as nn
from typing import Dict, List, Optional, Any, Union
from dataclasses import dataclass, field
from enum import Enum
class TaskType(str, Enum):
    """Kinds of task a multi-task head can be declared as.

    Every member is also a ``str``, so it compares equal to its lowercase
    string value and can be built from that value, e.g.
    ``TaskType("classification")``.
    """

    # One prediction per input sequence.
    CLASSIFICATION = "classification"

    # Per-position projection onto the vocabulary.
    SEQUENCE_TO_SEQUENCE = "sequence_to_sequence"

    # One prediction per token (NER, POS tagging, ...).
    TOKEN_CLASSIFICATION = "token_classification"

    # Extractive QA: start/end span logits.
    QUESTION_ANSWERING = "question_answering"

    # Declared here, but TaskHead.HEAD_CLASSES registers no head class for it.
    GENERATION = "generation"
@dataclass
class TaskHeadConfig:
    """Declarative description of one task-specific head.

    Attributes:
        task_name: Name under which the head is registered.
        head_type: Kind of head; a plain string is converted to ``TaskType``
            right after construction.
        config: Extra keyword arguments for the head (empty dict by default,
            created fresh per instance).
    """

    task_name: str
    head_type: TaskType
    config: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        """Normalise a string ``head_type`` into a ``TaskType`` member."""
        declared = self.head_type
        if not isinstance(declared, str):
            return
        # TaskType members are themselves str instances, so they also take
        # this path; TaskType(member) simply returns the same member.
        self.head_type = TaskType(declared)
class ClassificationHead(nn.Module):
    """Sequence-level classification head.

    Pools the per-token hidden states into one vector per sequence, applies
    dropout, then projects to ``num_labels`` logits.

    Args:
        hidden_size: Size of the last dimension of the incoming hidden states.
        num_labels: Number of output classes.
        dropout: Dropout probability applied to the pooled vector.
        **kwargs: Ignored; lets callers pass one shared config dict to any head.
    """

    def __init__(
        self,
        hidden_size: int,
        num_labels: int,
        dropout: float = 0.1,
        **kwargs
    ):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return classification logits of shape ``(batch, num_labels)``.

        Args:
            hidden_states: Tensor indexed as ``(batch, seq, hidden)``.
            attention_mask: Optional ``(batch, seq)`` mask. When given, the
                hidden states are mean-pooled over the positions it weights;
                when omitted, the hidden state at the last sequence position
                is used instead.

        Returns:
            Logits produced by the linear classifier.
        """
        if attention_mask is None:
            pooled_output = hidden_states[:, -1, :]  # last sequence position
        else:
            # (batch, seq, 1) mask broadcasts over the hidden dimension, so no
            # explicit expand is needed. Accumulate in float32 for accuracy.
            mask = attention_mask.unsqueeze(-1).float()
            sum_embeddings = (hidden_states * mask).sum(dim=1)
            # Clamp so a fully masked row divides by a tiny number, not zero.
            sum_mask = mask.sum(dim=1).clamp(min=1e-9)
            # Cast back to the input dtype: without this, half/bfloat16 hidden
            # states are promoted to float32 by the mask and then no longer
            # match the classifier weights.
            pooled_output = (sum_embeddings / sum_mask).to(hidden_states.dtype)

        pooled_output = self.dropout(pooled_output)
        return self.classifier(pooled_output)
class SequenceToSequenceHead(nn.Module):
    """Head that projects every position's hidden state onto the vocabulary.

    Args:
        hidden_size: Size of the last dimension of the incoming hidden states.
        vocab_size: Number of output logits per position.
        max_length: Stored on the instance as ``max_length``; not used by
            ``forward``.
        **kwargs: Ignored; lets callers pass one shared config dict to any head.
    """

    def __init__(
        self,
        hidden_size: int,
        vocab_size: int,
        max_length: int = 512,
        **kwargs
    ):
        super().__init__()
        self.output_projection = nn.Linear(hidden_size, vocab_size)
        self.max_length = max_length
        self.vocab_size = vocab_size

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return vocabulary logits for each position.

        Args:
            hidden_states: Tensor whose last dimension is ``hidden_size``.
            attention_mask: Accepted and ignored, so this head can be called
                with the same keyword the classification head takes (the
                multi-task model passes it to whichever head is active).

        Returns:
            ``hidden_states`` with its last dimension mapped to ``vocab_size``.
        """
        return self.output_projection(hidden_states)
class TokenClassificationHead(nn.Module):
    """Token-level classification head (NER, POS tagging, etc.).

    Applies dropout and a linear classifier independently at every position.

    Args:
        hidden_size: Size of the last dimension of the incoming hidden states.
        num_labels: Number of classes predicted per token.
        dropout: Dropout probability applied before the classifier.
        **kwargs: Ignored; lets callers pass one shared config dict to any head.
    """

    def __init__(
        self,
        hidden_size: int,
        num_labels: int,
        dropout: float = 0.1,
        **kwargs
    ):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.num_labels = num_labels

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Return per-token logits.

        Args:
            hidden_states: Tensor whose last dimension is ``hidden_size``.
            attention_mask: Accepted and ignored, so this head can be called
                with the same keyword the classification head takes (the
                multi-task model passes it to whichever head is active).

        Returns:
            ``hidden_states`` with its last dimension mapped to ``num_labels``.
        """
        return self.classifier(self.dropout(hidden_states))
class QuestionAnsweringHead(nn.Module):
    """Head for extractive question answering.

    A single linear layer produces two logits per position, which are
    returned as separate start and end logit tensors.

    Args:
        hidden_size: Size of the last dimension of the incoming hidden states.
        dropout: Accepted for signature parity with the other heads; this
            head does not apply dropout.
        **kwargs: Ignored; lets callers pass one shared config dict to any head.
    """

    def __init__(
        self,
        hidden_size: int,
        dropout: float = 0.1,
        **kwargs
    ):
        super().__init__()
        self.qa_outputs = nn.Linear(hidden_size, 2)  # start and end logits

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> tuple:
        """Return ``(start_logits, end_logits)``.

        Args:
            hidden_states: Tensor whose last dimension is ``hidden_size``.
            attention_mask: Accepted and ignored, so this head can be called
                with the same keyword the classification head takes (the
                multi-task model passes it to whichever head is active).

        Returns:
            A 2-tuple of tensors, each shaped like ``hidden_states`` without
            its last dimension.
        """
        logits = self.qa_outputs(hidden_states)
        # The last dimension has size 2: index 0 is start, index 1 is end.
        start_logits, end_logits = logits.unbind(dim=-1)
        return start_logits, end_logits
class TaskHead(nn.Module):
    """Wrapper that builds and calls the concrete head for one task.

    The head class is looked up in ``HEAD_CLASSES`` by the configured task
    type and constructed from the task's config dict plus ``hidden_size``
    (and ``vocab_size`` when one is supplied).

    Args:
        config: Task name, head type and extra head keyword arguments.
        hidden_size: Hidden size of the shared base model; overrides any
            ``hidden_size`` entry in ``config.config``.
        vocab_size: Optional vocabulary size; when not ``None`` it overrides
            any ``vocab_size`` entry in ``config.config``.

    Raises:
        ValueError: If no head class is registered for ``config.head_type``.
    """

    HEAD_CLASSES = {
        TaskType.CLASSIFICATION: ClassificationHead,
        TaskType.SEQUENCE_TO_SEQUENCE: SequenceToSequenceHead,
        TaskType.TOKEN_CLASSIFICATION: TokenClassificationHead,
        TaskType.QUESTION_ANSWERING: QuestionAnsweringHead,
    }

    def __init__(self, config: TaskHeadConfig, hidden_size: int, vocab_size: Optional[int] = None):
        import inspect  # local import: only needed while constructing the wrapper

        super().__init__()
        self.config = config
        self.task_name = config.task_name
        self.head_type = config.head_type

        # Copy so the caller's config dict is never mutated.
        head_config = dict(config.config)
        head_config["hidden_size"] = hidden_size
        if vocab_size is not None:
            head_config["vocab_size"] = vocab_size

        # Must read the attribute: a bare ``head_type`` name does not exist in
        # this scope and raised NameError on every construction.
        head_class = self.HEAD_CLASSES.get(self.head_type)
        if head_class is None:
            raise ValueError(f"Unsupported task type: {self.head_type}")
        self.head = head_class(**head_config)

        # Record which keyword arguments the concrete head's forward accepts,
        # so forward() can drop the ones it would reject.
        forward_params = inspect.signature(self.head.forward).parameters
        self._forward_takes_var_kwargs = any(
            param.kind is inspect.Parameter.VAR_KEYWORD
            for param in forward_params.values()
        )
        self._forward_kwarg_names = frozenset(forward_params)

    def forward(self, hidden_states: torch.Tensor, **kwargs) -> Union[torch.Tensor, tuple]:
        """Run the wrapped head.

        Keyword arguments the wrapped head's ``forward`` does not declare are
        dropped, so one call site can pass e.g. ``attention_mask`` to every
        head type without a TypeError.

        Args:
            hidden_states: Hidden states from the shared base model.
            **kwargs: Optional extras (such as ``attention_mask``).

        Returns:
            Whatever the wrapped head returns: a logits tensor, or a
            ``(start_logits, end_logits)`` tuple for question answering.
        """
        if not self._forward_takes_var_kwargs:
            kwargs = {
                name: value
                for name, value in kwargs.items()
                if name in self._forward_kwarg_names
            }
        return self.head(hidden_states, **kwargs)
class MultiTaskModel(nn.Module):
    """
    Multi-task model with task-specific heads sharing a common base.

    Args:
        base_model: Base transformer model. It must expose a ``config``
            attribute and return an object with ``hidden_states`` when called
            with ``output_hidden_states=True``.
        base_model_name: Name or path of base model; used only when
            ``base_model`` is not given, loaded via
            ``AutoModelForCausalLM.from_pretrained``.

    Raises:
        ValueError: If neither ``base_model`` nor ``base_model_name`` is given.
    """

    def __init__(self, base_model=None, base_model_name: Optional[str] = None):
        super().__init__()
        if base_model is None and base_model_name is None:
            raise ValueError("Must provide either base_model or base_model_name")

        if base_model is None:
            # Imported lazily so transformers is only needed on this path.
            from transformers import AutoModelForCausalLM
            self.base_model = AutoModelForCausalLM.from_pretrained(base_model_name)
        else:
            self.base_model = base_model

        # Read sizes from the base model config, with fallbacks when absent.
        self.hidden_size = getattr(self.base_model.config, 'hidden_size', 768)
        self.vocab_size = getattr(self.base_model.config, 'vocab_size', None)

        # Task heads registry (a ModuleDict so head parameters are registered).
        self.task_heads: Dict[str, "TaskHead"] = nn.ModuleDict()
        self.active_task: Optional[str] = None

        # Task weights for balanced training
        self.task_weights: Dict[str, float] = {}

    def add_task_head(
        self,
        task_name: str,
        head_type: Union[str, "TaskType"],
        config: Optional[Dict[str, Any]] = None
    ):
        """
        Add a task-specific head to the model.

        An existing head registered under the same name is replaced, and the
        task's weight is reset to 1.0.

        Args:
            task_name: Unique name for this task
            head_type: A ``TaskType`` member or its string value
                (e.g. "classification", "sequence_to_sequence")
            config: Task-specific keyword arguments for the head
        """
        if config is None:
            config = {}

        task_config = TaskHeadConfig(
            task_name=task_name,
            head_type=head_type,
            config=config
        )
        task_head = TaskHead(task_config, self.hidden_size, self.vocab_size)
        self.task_heads[task_name] = task_head
        self.task_weights[task_name] = 1.0  # Default equal weight

    def set_task_weights(self, weights: Dict[str, float]):
        """Set weights for each task in multi-task training.

        Entries whose task name has no registered head are ignored.
        """
        for task_name, weight in weights.items():
            if task_name in self.task_heads:
                self.task_weights[task_name] = weight

    def set_active_task(self, task_name: str):
        """Set the currently active task for single-task inference.

        Raises:
            ValueError: If no head is registered under ``task_name``.
        """
        if task_name not in self.task_heads:
            raise ValueError(f"Task '{task_name}' not found. Available: {list(self.task_heads.keys())}")
        self.active_task = task_name

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        task_name: Optional[str] = None,
        **kwargs
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through base model and task head.

        Args:
            input_ids: Input token IDs
            attention_mask: Attention mask
            labels: Optional labels for loss computation. For a question
                answering head they are read as ``(batch, 2)`` holding the
                start and end position per example.
            task_name: Task to use (overrides active_task)
            **kwargs: Extra keyword arguments forwarded to the base model

        Returns:
            Dictionary containing "logits" (the head output) and, when
            labels are given, "loss"

        Raises:
            ValueError: If the task cannot be determined or is not registered,
                or if question answering labels do not have shape (batch, 2).
        """
        # Determine which task to use
        task = task_name or self.active_task
        if task is None:
            if len(self.task_heads) == 0:
                raise ValueError("No task specified and no task heads registered")
            if len(self.task_heads) > 1:
                raise ValueError("No task specified and multiple heads available")
            task = next(iter(self.task_heads.keys()))

        if task not in self.task_heads:
            raise ValueError(f"Task '{task}' not found")

        # output_hidden_states is always requested below; drop a caller-supplied
        # duplicate so the base model call does not get the keyword twice.
        kwargs.pop("output_hidden_states", None)
        base_outputs = self.base_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True,
            **kwargs
        )

        # Get last hidden state
        hidden_states = base_outputs.hidden_states[-1]

        # Apply task head. Only the classification head pools with the mask;
        # the other heads' forward does not take an attention_mask argument.
        head = self.task_heads[task]
        head_kwargs = {}
        if head.head_type == TaskType.CLASSIFICATION:
            head_kwargs["attention_mask"] = attention_mask
        head_output = head(hidden_states, **head_kwargs)

        result = {"logits": head_output}

        # Compute loss if labels provided
        if labels is not None:
            result["loss"] = self._compute_loss(head, head_output, labels)

        return result

    def _compute_loss(self, head, head_output, labels: torch.Tensor) -> torch.Tensor:
        """Return the cross-entropy loss appropriate for ``head``'s task type."""
        loss_fct = nn.CrossEntropyLoss()

        if head.head_type == TaskType.SEQUENCE_TO_SEQUENCE:
            # Position t predicts token t + 1: drop the last logit and the
            # first label so the two line up.
            shift_logits = head_output[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            return loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1),
            )

        if head.head_type == TaskType.QUESTION_ANSWERING:
            # The QA head returns a (start_logits, end_logits) tuple, which
            # cannot be flattened like a single logits tensor.
            start_logits, end_logits = head_output
            if labels.dim() != 2 or labels.size(-1) != 2:
                raise ValueError(
                    "Question answering labels must have shape (batch, 2) "
                    "holding start and end positions"
                )
            start_loss = loss_fct(start_logits, labels[:, 0])
            end_loss = loss_fct(end_logits, labels[:, 1])
            return (start_loss + end_loss) / 2

        # Classification and token classification: flatten to (N, num_classes).
        # The class count comes from the logits themselves; the TaskHead
        # wrapper has no num_labels attribute.
        return loss_fct(
            head_output.reshape(-1, head_output.size(-1)),
            labels.reshape(-1),
        )

    def get_num_tasks(self) -> int:
        """Return number of task heads."""
        return len(self.task_heads)

    def list_tasks(self) -> List[str]:
        """Return list of task names."""
        return list(self.task_heads.keys())
class MultiTaskTrainer:
    """
    Trainer for multi-task learning with task-weighted loss.

    Args:
        model: MultiTaskModel instance
        task_datasets: Dictionary mapping task names to iterables of batches;
            each batch is a dict read for "input_ids", "attention_mask" and
            "labels"
        task_weights: Optional dictionary of task weights, applied to the
            model via ``set_task_weights``
        tokenizer: Stored on the trainer; not used by the methods here
        device: Device to train on; defaults to CUDA when available, else CPU
    """

    def __init__(
        self,
        model: "MultiTaskModel",
        task_datasets: Dict[str, Any],
        task_weights: Optional[Dict[str, float]] = None,
        tokenizer=None,
        device: Optional[torch.device] = None
    ):
        self.model = model
        self.task_datasets = task_datasets
        self.tokenizer = tokenizer
        self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Set task weights
        if task_weights:
            self.model.set_task_weights(task_weights)

        # Move model to device
        self.model.to(self.device)

    def _to_device(self, batch: Dict[str, Any]) -> Dict[str, Any]:
        """Move every value that has a ``.to`` method onto the trainer's device."""
        return {
            key: value.to(self.device) if hasattr(value, 'to') else value
            for key, value in batch.items()
        }

    def train_epoch(
        self,
        optimizer: torch.optim.Optimizer,
        batch_sizes: Optional[Dict[str, int]] = None,
        gradient_accumulation_steps: int = 1
    ) -> Dict[str, float]:
        """
        Train one epoch across all tasks.

        Tasks are processed one after another: every batch of the first
        task, then every batch of the next, and so on.

        Args:
            optimizer: Optimizer for training
            batch_sizes: Currently unused; the datasets are iterated as-is
                and are expected to yield ready-made batches
            gradient_accumulation_steps: Number of batches whose gradients
                are accumulated before each optimizer update. Each batch's
                loss is divided by this value, and a trailing partial group
                is applied at the end of the epoch.

        Returns:
            Dictionary of average unweighted loss per task

        Raises:
            ValueError: If ``gradient_accumulation_steps`` is less than 1.
        """
        if gradient_accumulation_steps < 1:
            raise ValueError("gradient_accumulation_steps must be >= 1")

        self.model.train()
        task_losses = {task: 0.0 for task in self.task_datasets}
        task_counts = {task: 0 for task in self.task_datasets}

        pending = 0  # batches accumulated since the last optimizer step

        for task_name, dataset in self.task_datasets.items():
            weight = self.model.task_weights.get(task_name, 1.0)

            for batch in dataset:
                inputs = self._to_device(batch)

                # Clear gradients only at the start of an accumulation group.
                if pending == 0:
                    optimizer.zero_grad()

                outputs = self.model(
                    input_ids=inputs.get("input_ids"),
                    attention_mask=inputs.get("attention_mask"),
                    labels=inputs.get("labels"),
                    task_name=task_name
                )

                raw_loss = outputs["loss"]
                scaled_loss = raw_loss * weight / gradient_accumulation_steps
                scaled_loss.backward()

                pending += 1
                if pending == gradient_accumulation_steps:
                    optimizer.step()
                    pending = 0

                # Record the unweighted loss directly; dividing the weighted
                # loss by the weight fails when a task weight is zero.
                task_losses[task_name] += raw_loss.item()
                task_counts[task_name] += 1

        # Apply gradients left over from an incomplete accumulation group.
        if pending:
            optimizer.step()

        # Average losses (a task with no batches reports 0.0).
        return {
            task: task_losses[task] / max(task_counts[task], 1)
            for task in task_losses
        }

    def evaluate(
        self,
        eval_datasets: Dict[str, Any],
        metrics_fn: Optional[Dict[str, callable]] = None
    ) -> Dict[str, Dict[str, float]]:
        """
        Evaluate model on all tasks.

        Puts the model in eval mode and runs every batch without gradients.

        Args:
            eval_datasets: Evaluation datasets per task; every batch must
                carry labels, because the model's "loss" output is read
            metrics_fn: Currently unused; only the loss is reported

        Returns:
            Dictionary mapping each task to {"loss": average loss,
            "count": number of batches}
        """
        self.model.eval()
        results = {}

        with torch.no_grad():
            for task_name, dataset in eval_datasets.items():
                task_results = {"loss": 0.0, "count": 0}

                for batch in dataset:
                    inputs = self._to_device(batch)
                    outputs = self.model(
                        input_ids=inputs.get("input_ids"),
                        attention_mask=inputs.get("attention_mask"),
                        labels=inputs.get("labels"),
                        task_name=task_name
                    )
                    task_results["loss"] += outputs["loss"].item()
                    task_results["count"] += 1

                if task_results["count"] > 0:
                    task_results["loss"] /= task_results["count"]

                results[task_name] = task_results

        return results