|
|
""" |
|
|
Memory Training Module for MangoMAS Local |
|
|
|
|
|
This module implements specialized training for memory and context retention capabilities, |
|
|
adapted from the AWS backup system for local training. |
|
|
""" |
|
|
|
|
|
import json |
|
|
import logging |
|
|
import os |
|
|
import random |
|
|
from typing import Any, Dict, List |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torch.utils.data import Dataset |
|
|
|
|
|
from ..core_framework import SpecializedTrainingModule, TrainingModuleConfig |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class MemoryDataset(Dataset):
    """Dataset for training memory and context retention capabilities.

    Each example is a JSONL record containing a ``"conversation"`` list of
    ``{"role", "content"}`` turns.  Every turn except the last forms the
    context; the final turn is the response the model should learn to produce.
    """

    def __init__(self, data_path: str, tokenizer, max_length: int = 1024):
        """
        Initialize the memory dataset.

        Args:
            data_path: Path to the memory training data file (JSONL)
            tokenizer: Tokenizer for text processing (HF-style callable)
            max_length: Maximum sequence length
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self._load_data(data_path)

        logging.getLogger(__name__).info(
            "Loaded memory dataset with %d examples", len(self.data)
        )

    def _load_data(self, data_path: str) -> List[Dict]:
        """Load memory training data from a JSONL file.

        Malformed JSON lines are skipped.  Records are only kept when they
        are dicts with a non-empty ``"conversation"`` list — an empty
        conversation has no final turn to use as the training target, and a
        non-dict record (e.g. a top-level JSON array) cannot be indexed by
        key in ``__getitem__``.
        """
        data = []
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                try:
                    item = json.loads(line.strip())
                except json.JSONDecodeError:
                    # Best-effort load: silently drop unparseable lines.
                    continue

                if (
                    isinstance(item, dict)
                    and isinstance(item.get("conversation"), list)
                    and item["conversation"]
                ):
                    data.append(item)
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # All turns but the last are the context; the final turn is the target.
        conversation = item["conversation"]
        context = "\n".join(
            f"{turn['role']}: {turn['content']}" for turn in conversation[:-1]
        )
        target = conversation[-1]["content"]

        prompt = f"Context:\n{context}\nResponse: {target}"

        encoding = self.tokenizer(
            prompt,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        input_ids = encoding["input_ids"].squeeze()
        attention_mask = encoding["attention_mask"].squeeze()

        # Mask padding positions with -100 so that a loss configured with
        # ignore_index=-100 (as the training module's CrossEntropyLoss is)
        # does not train the model to predict pad tokens.
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }
|
|
|
|
|
|
|
|
class MemoryTrainingModule(SpecializedTrainingModule):
    """Specialized training module for memory and context retention capabilities.

    Combines a shifted causal language-modeling loss over conversation data
    with an optional KL-divergence distillation term against a teacher model.
    """

    # Weight of the KL-distillation term relative to the LM loss.
    KL_WEIGHT = 0.1

    def __init__(self, config: TrainingModuleConfig, tokenizer):
        """
        Initialize the memory training module.

        Args:
            config: Module configuration
            tokenizer: Tokenizer for text processing
        """
        super().__init__(config, tokenizer)

        # ignore_index=-100 follows the HF convention for masking padding
        # (and prompt) tokens out of the LM loss.
        self.memory_loss = nn.CrossEntropyLoss(ignore_index=-100)
        # NOTE(review): context_retention and coherence_score are placeholders;
        # nothing in this module updates them yet.
        self.metrics = {
            "memory_loss": 0.0,
            "context_retention": 0.0,
            "coherence_score": 0.0,
        }

        logger.info("Initialized MemoryTrainingModule")

    def prepare_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Prepare a batch of data for memory training.

        Moves every tensor in the batch onto the module's device; non-tensor
        values pass through unchanged.

        Args:
            batch: The input batch from the dataloader

        Returns:
            Processed batch ready for memory training
        """
        return {
            key: value.to(self.device) if isinstance(value, torch.Tensor) else value
            for key, value in batch.items()
        }

    def compute_loss(
        self, student_outputs: Any, teacher_outputs: Any, batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute the memory-specific loss.

        The loss is a shifted causal-LM cross-entropy on the student logits,
        plus ``KL_WEIGHT`` times a KL-divergence distillation term when
        teacher logits are available.

        Args:
            student_outputs: Outputs from the student model (object with a
                ``logits`` attribute, or a raw logits tensor)
            teacher_outputs: Outputs from the teacher model (same convention;
                may be None to disable distillation)
            batch: The processed input batch

        Returns:
            Loss tensor for memory training, scaled by ``self.loss_weight``
            (inherited from the base module — TODO confirm)
        """
        try:
            # Accept either a HF-style output object or a bare logits tensor.
            student_logits = getattr(student_outputs, "logits", student_outputs)
            teacher_logits = getattr(teacher_outputs, "logits", teacher_outputs)

            # Fall back to input_ids as labels when the dataset did not
            # provide an explicit "labels" entry.
            labels = batch.get("labels", batch.get("input_ids"))

            # Standard causal-LM shift: position t predicts token t+1.
            shift_logits = student_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            memory_loss = self.memory_loss(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

            if teacher_logits is not None:
                # NOTE(review): the KL term is computed over all positions,
                # including padding — consider masking before softening.
                kl_loss = F.kl_div(
                    F.log_softmax(student_logits, dim=-1),
                    F.softmax(teacher_logits, dim=-1),
                    reduction="batchmean",
                )
                total_loss = memory_loss + self.KL_WEIGHT * kl_loss
            else:
                total_loss = memory_loss

            self.metrics["memory_loss"] = memory_loss.item()

            return total_loss * self.loss_weight

        except Exception as e:
            # Deliberate best-effort fallback: a failed loss computation must
            # not kill the training loop.  The returned constant has no path
            # back to model parameters, so backward() is effectively a no-op.
            logger.error(f"Error computing memory loss: {e}")
            return torch.tensor(0.01, requires_grad=True)

    def get_metrics(self) -> Dict[str, float]:
        """
        Get metrics specific to memory training.

        Returns:
            Dictionary of memory metrics (a copy; safe for callers to mutate)
        """
        return self.metrics.copy()

    def generate_synthetic_memory_data(
        self, output_path: str, num_samples: int = 1000
    ) -> None:
        """
        Generate synthetic memory training data.

        Samples multi-turn conversation templates and, where a matching
        recall template exists, appends a recall question so the example
        requires remembering earlier turns.  Output is written as JSONL.

        Args:
            output_path: Path to save the generated data
            num_samples: Number of samples to generate
        """
        conversation_templates = [
            [
                {
                    "role": "user",
                    "content": "Hi, my name is Alex and I'm interested in machine learning.",
                },
                {
                    "role": "assistant",
                    "content": "Hello Alex! I'd be happy to discuss machine learning with you. What aspects are you most interested in?",
                },
                {
                    "role": "user",
                    "content": "I'm particularly interested in natural language processing.",
                },
                {
                    "role": "assistant",
                    "content": "NLP is a fascinating field! It's used for tasks like translation, summarization, and question answering.",
                },
                {
                    "role": "user",
                    "content": "What do you think would be a good first project?",
                },
                {
                    "role": "assistant",
                    "content": "For a beginner in NLP, I'd recommend starting with a text classification project, like sentiment analysis.",
                },
            ],
            [
                {
                    "role": "user",
                    "content": "I'm planning a trip to Japan next spring.",
                },
                {
                    "role": "assistant",
                    "content": "That sounds exciting! Japan is beautiful in spring with cherry blossoms. What cities are you planning to visit?",
                },
                {
                    "role": "user",
                    "content": "I'm thinking Tokyo, Kyoto, and maybe Osaka.",
                },
                {
                    "role": "assistant",
                    "content": "Great choices! Tokyo has modern attractions, Kyoto has historical temples, and Osaka is known for amazing food.",
                },
                {
                    "role": "user",
                    "content": "What's the best way to travel between these cities?",
                },
                {
                    "role": "assistant",
                    "content": "The Shinkansen (bullet train) is the most efficient way to travel between these cities. It's fast, comfortable, and reliable.",
                },
            ],
        ]

        # Recall templates are index-aligned with conversation_templates.
        recall_templates = [
            {
                "recall_context": "what was my name again?",
                "recall_target": "Your name is Alex, as you mentioned at the beginning of our conversation.",
            },
            {
                "recall_context": "which cities did I say I wanted to visit?",
                "recall_target": "You mentioned you're planning to visit Tokyo, Kyoto, and possibly Osaka during your trip to Japan.",
            },
        ]

        output_data = []
        for _ in range(num_samples):
            template_idx = random.randrange(len(conversation_templates))
            # Shallow copy: turn dicts are shared but never mutated below.
            conversation = list(conversation_templates[template_idx])

            if template_idx < len(recall_templates):
                recall_template = recall_templates[template_idx]

                conversation.append(
                    {"role": "user", "content": recall_template["recall_context"]}
                )

                example = {
                    "conversation": conversation,
                    "recall_context": recall_template["recall_context"],
                    "recall_target": recall_template["recall_target"],
                    "metadata": {"generated": True, "requires_memory": True},
                }
            else:
                example = {
                    "conversation": conversation,
                    "metadata": {"generated": True, "requires_memory": False},
                }

            output_data.append(example)

        # Guard the empty-dirname case: os.makedirs("") raises even with
        # exist_ok=True when output_path is a bare filename.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            for item in output_data:
                f.write(json.dumps(item) + "\n")

        logger.info(
            f"Generated {len(output_data)} synthetic memory examples at {output_path}"
        )
|
|
|