File size: 11,213 Bytes
c8b77b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
"""
Memory Training Module for MangoMAS Local

This module implements specialized training for memory and context retention capabilities,
adapted from the AWS backup system for local training.
"""

import json
import logging
import os
import random
from typing import Any, Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from ..core_framework import SpecializedTrainingModule, TrainingModuleConfig

logger = logging.getLogger(__name__)


class MemoryDataset(Dataset):
    """Dataset for training memory and context retention capabilities.

    Each example is a multi-turn conversation; all turns but the last form
    the context and the final turn is the response the model must produce.
    """

    # Label value ignored by CrossEntropyLoss in MemoryTrainingModule.
    IGNORE_INDEX = -100

    def __init__(self, data_path: str, tokenizer, max_length: int = 1024):
        """
        Initialize the memory dataset.

        Args:
            data_path: Path to the memory training data file (JSONL, one
                conversation object per line)
            tokenizer: Tokenizer for text processing (HF-style callable
                returning "input_ids" and "attention_mask" — assumed; confirm
                against caller)
            max_length: Maximum sequence length
        """
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.data = self._load_data(data_path)

        logger.info(f"Loaded memory dataset with {len(self.data)} examples")

    def _load_data(self, data_path: str) -> List[Dict]:
        """Load memory training data, skipping malformed or empty records.

        A record is kept only if "conversation" is a non-empty list; an
        empty conversation would crash __getitem__ (no final turn to use
        as the target).
        """
        data = []
        with open(data_path, "r", encoding="utf-8") as f:
            for line in f:
                try:
                    item = json.loads(line.strip())
                    # Validate required fields: must be a non-empty list of turns.
                    if (
                        "conversation" in item
                        and isinstance(item["conversation"], list)
                        and item["conversation"]
                    ):
                        data.append(item)
                except json.JSONDecodeError:
                    # Tolerate corrupt lines rather than aborting the load.
                    continue
        return data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return tokenized (input_ids, attention_mask, labels) for one example.

        Labels mirror input_ids, but padding positions are replaced with
        IGNORE_INDEX (-100) so CrossEntropyLoss(ignore_index=-100) in the
        training module does not train on padding tokens.
        """
        item = self.data[idx]

        # Format the conversation for memory training: every turn except the
        # last is context; the last turn's content is the response target.
        conversation = item["conversation"]
        context = "\n".join(
            [f"{turn['role']}: {turn['content']}" for turn in conversation[:-1]]
        )
        target = conversation[-1]["content"]

        prompt = f"Context:\n{context}\nResponse: {target}"

        # Tokenize to a fixed length (pad/truncate to max_length).
        encoding = self.tokenizer(
            prompt,
            max_length=self.max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

        # squeeze(0) drops only the batch dim added by return_tensors="pt";
        # a bare squeeze() could also drop a length-1 sequence dim.
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)

        # Mask out padding so it never contributes to the loss.
        labels = input_ids.clone()
        labels[attention_mask == 0] = self.IGNORE_INDEX

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels,
        }


class MemoryTrainingModule(SpecializedTrainingModule):
    """Specialized training module for memory and context retention capabilities.

    Distills a teacher model into a student using a cross-entropy memory loss
    plus a KL-divergence term between student and teacher logits.
    """

    # Weight applied to the KL distillation term relative to the CE loss.
    KL_WEIGHT = 0.1

    def __init__(self, config: TrainingModuleConfig, tokenizer):
        """
        Initialize the memory training module.

        Args:
            config: Module configuration
            tokenizer: Tokenizer for text processing
        """
        super().__init__(config, tokenizer)

        # Initialize memory-specific components.
        # ignore_index=-100 matches the padding mask convention used by
        # MemoryDataset-style label tensors.
        self.memory_loss = nn.CrossEntropyLoss(ignore_index=-100)
        self.metrics = {
            "memory_loss": 0.0,
            "context_retention": 0.0,
            "coherence_score": 0.0,
        }

        logger.info("Initialized MemoryTrainingModule")

    def prepare_batch(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Prepare a batch of data for memory training.

        Args:
            batch: The input batch from the dataloader

        Returns:
            Processed batch ready for memory training (tensors moved to
            self.device; non-tensor values passed through unchanged)
        """
        # Move only tensors to the device; leave metadata untouched.
        # NOTE(review): self.device is assumed to be set by the base class.
        prepared_batch = {}
        for key, value in batch.items():
            if isinstance(value, torch.Tensor):
                prepared_batch[key] = value.to(self.device)
            else:
                prepared_batch[key] = value

        return prepared_batch

    def compute_loss(
        self, student_outputs: Any, teacher_outputs: Any, batch: Dict[str, torch.Tensor]
    ) -> torch.Tensor:
        """
        Compute the memory-specific loss.

        The loss is next-token cross entropy on the student logits, plus a
        weighted KL divergence between student and teacher distributions
        when teacher logits are available. The result is scaled by
        self.loss_weight (assumed set by the base class).

        Args:
            student_outputs: Outputs from the student model
            teacher_outputs: Outputs from the teacher model
            batch: The processed input batch

        Returns:
            Loss tensor for memory training
        """
        try:
            # Model outputs may be HF-style objects with a .logits attribute
            # or raw logit tensors; accept both.
            if hasattr(student_outputs, "logits"):
                student_logits = student_outputs.logits
            else:
                student_logits = student_outputs

            if hasattr(teacher_outputs, "logits"):
                teacher_logits = teacher_outputs.logits
            else:
                teacher_logits = teacher_outputs

            # Fall back to input_ids as labels (standard LM objective) when
            # the batch carries no explicit labels.
            labels = batch.get("labels", batch.get("input_ids"))

            # Standard causal-LM shift: predict token t+1 from position t.
            shift_logits = student_logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()

            memory_loss = self.memory_loss(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )

            # Add KL divergence distillation loss between student and teacher.
            # NOTE(review): assumes student and teacher share vocab/shape —
            # confirm against the framework's model pairing.
            if teacher_logits is not None:
                kl_loss = F.kl_div(
                    F.log_softmax(student_logits, dim=-1),
                    F.softmax(teacher_logits, dim=-1),
                    reduction="batchmean",
                )
                total_loss = memory_loss + self.KL_WEIGHT * kl_loss
            else:
                total_loss = memory_loss

            # Update metrics for external reporting via get_metrics().
            self.metrics["memory_loss"] = memory_loss.item()

            return total_loss * self.loss_weight

        except Exception as e:
            logger.error(f"Error computing memory loss: {e}")
            # Return a small loss to avoid training failure. Create it on the
            # training device — a CPU tensor here would crash a GPU backward.
            device = getattr(self, "device", "cpu")
            return torch.tensor(0.01, device=device, requires_grad=True)

    def get_metrics(self) -> Dict[str, float]:
        """
        Get metrics specific to memory training.

        Returns:
            Dictionary of memory metrics (a copy; safe for callers to mutate)
        """
        return self.metrics.copy()

    def generate_synthetic_memory_data(
        self, output_path: str, num_samples: int = 1000
    ) -> None:
        """
        Generate synthetic memory training data.

        Writes JSONL records, each containing a "conversation" plus optional
        recall fields, to output_path (parent directories created as needed).

        Args:
            output_path: Path to save the generated data
            num_samples: Number of samples to generate
        """
        # This is a simplified implementation based on the AWS backup
        # In a full implementation, this would be much more sophisticated

        conversation_templates = [
            [
                {
                    "role": "user",
                    "content": "Hi, my name is Alex and I'm interested in machine learning.",
                },
                {
                    "role": "assistant",
                    "content": "Hello Alex! I'd be happy to discuss machine learning with you. What aspects are you most interested in?",
                },
                {
                    "role": "user",
                    "content": "I'm particularly interested in natural language processing.",
                },
                {
                    "role": "assistant",
                    "content": "NLP is a fascinating field! It's used for tasks like translation, summarization, and question answering.",
                },
                {
                    "role": "user",
                    "content": "What do you think would be a good first project?",
                },
                {
                    "role": "assistant",
                    "content": "For a beginner in NLP, I'd recommend starting with a text classification project, like sentiment analysis.",
                },
            ],
            [
                {
                    "role": "user",
                    "content": "I'm planning a trip to Japan next spring.",
                },
                {
                    "role": "assistant",
                    "content": "That sounds exciting! Japan is beautiful in spring with cherry blossoms. What cities are you planning to visit?",
                },
                {
                    "role": "user",
                    "content": "I'm thinking Tokyo, Kyoto, and maybe Osaka.",
                },
                {
                    "role": "assistant",
                    "content": "Great choices! Tokyo has modern attractions, Kyoto has historical temples, and Osaka is known for amazing food.",
                },
                {
                    "role": "user",
                    "content": "What's the best way to travel between these cities?",
                },
                {
                    "role": "assistant",
                    "content": "The Shinkansen (bullet train) is the most efficient way to travel between these cities. It's fast, comfortable, and reliable.",
                },
            ],
        ]

        # recall_templates[i] pairs with conversation_templates[i].
        recall_templates = [
            {
                "recall_context": "what was my name again?",
                "recall_target": "Your name is Alex, as you mentioned at the beginning of our conversation.",
            },
            {
                "recall_context": "which cities did I say I wanted to visit?",
                "recall_target": "You mentioned you're planning to visit Tokyo, Kyoto, and possibly Osaka during your trip to Japan.",
            },
        ]

        # Generate variations
        output_data = []
        for _ in range(num_samples):
            template_idx = random.randrange(len(conversation_templates))
            # Shallow copy is sufficient: we only append to the list, never
            # mutate the shared turn dicts.
            conversation = conversation_templates[template_idx].copy()

            # Add a recall question if this template has a matching recall pair.
            if template_idx < len(recall_templates):
                recall_template = recall_templates[template_idx]

                # Add a user question asking for recall
                conversation.append(
                    {"role": "user", "content": recall_template["recall_context"]}
                )

                # Create the full example with recall targets
                example = {
                    "conversation": conversation,
                    "recall_context": recall_template["recall_context"],
                    "recall_target": recall_template["recall_target"],
                    "metadata": {"generated": True, "requires_memory": True},
                }
            else:
                # Regular conversation without specific recall target
                example = {
                    "conversation": conversation,
                    "metadata": {"generated": True, "requires_memory": False},
                }

            output_data.append(example)

        # Save to file. Guard dirname: os.makedirs("") raises when
        # output_path is a bare filename with no directory component.
        out_dir = os.path.dirname(output_path)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
        with open(output_path, "w", encoding="utf-8") as f:
            for item in output_data:
                f.write(json.dumps(item) + "\n")

        logger.info(
            f"Generated {len(output_data)} synthetic memory examples at {output_path}"
        )