import torch
from torch.utils.data import Dataset
import json
import random


class MTPDataset(Dataset):
    """Instruction/response dataset with optional lightweight text augmentation.

    Reads a JSON-lines corpus where each line is one JSON object; only
    entries containing both an ``instruction`` and a ``response`` key
    are kept.
    """

    def __init__(self, corpus_path, tokenizer, max_seq_len=512,
                 use_augmentation=False, augmentation_prob=0.3):
        """
        Args:
            corpus_path: path to a UTF-8 ``.jsonl`` corpus file.
            tokenizer: object exposing ``encode(str) -> list[int]``,
                ``bos_id()`` and ``eos_id()``.
            max_seq_len: hard cap on token count (BOS/EOS included).
            use_augmentation: enable ``augment_text`` on both fields.
            augmentation_prob: probability a given text is augmented.
        """
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.use_augmentation = use_augmentation
        self.augmentation_prob = augmentation_prob
        self.data = []

        # Load corpus. Skip blank lines so trailing newlines / empty lines
        # in the .jsonl file do not crash json.loads (fix: the original
        # parsed every raw line unconditionally).
        with open(corpus_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                entry = json.loads(line)
                if 'instruction' in entry and 'response' in entry:
                    self.data.append(entry)

        print(f"✓ Loaded {len(self.data)} examples from corpus")
        if use_augmentation:
            print(f"✓ Data augmentation enabled (prob={augmentation_prob})")

    def __len__(self):
        return len(self.data)

    def augment_text(self, text):
        """Return *text*, possibly with small formatting perturbations.

        When augmentation is disabled, or with probability
        ``1 - augmentation_prob``, the text is returned unchanged.
        """
        if not self.use_augmentation or random.random() > self.augmentation_prob:
            return text

        # Variation 1: normalize surrounding whitespace.
        # NOTE(review): the original comment claimed this *adds* random
        # spaces, but strip() removes them — kept as-is to preserve behavior.
        if random.random() < 0.3:
            text = text.strip()

        # Variation 2: tweak final punctuation (drop a trailing period,
        # or add one when the text has no terminal punctuation).
        if random.random() < 0.2:
            if text.endswith('.'):
                text = text[:-1]
            elif not text.endswith(('.', '!', '?')):
                text = text + '.'

        return text

    def __getitem__(self, idx):
        """Return an ``(input_ids, target_ids)`` pair of 1-D long tensors."""
        entry = self.data[idx]
        instruction = self.augment_text(entry['instruction'])
        response = self.augment_text(entry['response'])

        # Prompt template; the literal (Spanish) headers are part of the
        # model's training format and must not be changed.
        full_text = f"### Instrucción:\n{instruction}\n\n### Respuesta:\n{response}"

        tokens = self.tokenizer.encode(full_text)

        # Wrap with BOS / EOS sentinels.
        tokens = [self.tokenizer.bos_id()] + tokens + [self.tokenizer.eos_id()]

        # Truncate while keeping BOS first and EOS last; the result is
        # exactly max_seq_len tokens long (1 + (max_seq_len - 2) + 1).
        if len(tokens) > self.max_seq_len:
            tokens = ([tokens[0]]
                      + tokens[1:self.max_seq_len - 1]
                      + [self.tokenizer.eos_id()])

        # Next-token shift: predict tokens[1:] from tokens[:-1].
        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        target_ids = torch.tensor(tokens[1:], dtype=torch.long)
        return input_ids, target_ids


def collate_fn(batch, pad_id=0):
    """Pad a batch of ``(input_ids, target_ids)`` pairs to the batch max length.

    NOTE(review): targets are padded with ``pad_id`` as well — the training
    loss should use ``ignore_index=pad_id`` (or an explicit mask) so padding
    positions do not contribute to the loss; confirm against the training loop.
    """
    input_ids = [item[0] for item in batch]
    target_ids = [item[1] for item in batch]

    # Pad only to the longest sequence in this batch, not max_seq_len.
    max_len = max(len(ids) for ids in input_ids)

    input_ids_padded = []
    target_ids_padded = []
    for inp, tgt in zip(input_ids, target_ids):
        pad_len = max_len - len(inp)
        pad = torch.full((pad_len,), pad_id, dtype=torch.long)
        input_ids_padded.append(torch.cat([inp, pad]))
        target_ids_padded.append(torch.cat([tgt, pad]))

    return torch.stack(input_ids_padded), torch.stack(target_ids_padded)