import torch
from torch.utils.data import Dataset
import json
import random


class MTPDataset(Dataset):
    """Dataset for instruction-context-response corpora stored as JSONL.

    Each line of the corpus file must be a JSON object with at least
    'instruction' and 'response' keys; 'context' is optional. Invalid or
    incomplete lines are skipped with a warning instead of aborting the load.
    """

    def __init__(self, corpus_path, tokenizer, max_seq_len=2048,
                 use_augmentation=True, augmentation_prob=0.3):
        """Load and validate the JSONL corpus.

        Args:
            corpus_path: path to the JSONL corpus file.
            tokenizer: object exposing ``encode()``, ``bos_id()`` and ``eos_id()``.
            max_seq_len: hard cap on the tokenized sequence length (incl. BOS/EOS).
            use_augmentation: enable light text augmentation in ``__getitem__``.
            augmentation_prob: probability of augmenting a given text field.
        """
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.use_augmentation = use_augmentation
        self.augmentation_prob = augmentation_prob
        self.data = []

        # Load corpus
        print(f" → Cargando corpus: {corpus_path}")
        valid_count = 0
        # BUGFIX: keep line_num defined even when the file is empty, so the
        # summary print below cannot raise NameError.
        line_num = 0

        with open(corpus_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue

                try:
                    entry = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f" ❌ Línea {line_num}: JSON inválido - {e}")
                    continue

                # Required fields must exist...
                if 'instruction' not in entry or 'response' not in entry:
                    print(f" ⚠ Línea {line_num}: Falta 'instruction' o 'response'")
                    continue

                # ...and be non-empty after stripping.
                instruction = entry['instruction'].strip()
                response = entry['response'].strip()
                if not instruction or not response:
                    print(f" ⚠ Línea {line_num}: Campos vacíos")
                    continue

                # Context is optional; BUGFIX: tolerate an explicit JSON null
                # (dict.get returns None, not the default, for a null value).
                context = (entry.get('context') or '').strip()

                self.data.append({
                    'instruction': instruction,
                    'context': context,
                    'response': response
                })
                valid_count += 1

        print(f" ✓ Cargados {valid_count} ejemplos válidos de {line_num} líneas")
        if use_augmentation:
            print(f" ✓ Augmentación activada (prob={augmentation_prob})")

    def __len__(self):
        """Number of valid examples loaded from the corpus."""
        return len(self.data)

    def augment_text(self, text):
        """Apply light, random, semantics-preserving text augmentation.

        When augmentation is disabled, the text is empty, or the random draw
        exceeds ``augmentation_prob``, the text is returned unchanged.
        """
        if not self.use_augmentation or random.random() > self.augmentation_prob or not text:
            return text

        # 1. Normalize surrounding whitespace.
        text = text.strip()

        # 2. Randomly toggle the final period (drop an existing one, or add
        #    one when the text has no terminal punctuation at all).
        if random.random() < 0.25:
            if text.endswith('.'):
                if random.random() < 0.5:
                    text = text[:-1]
            elif not text.endswith(('.', '!', '?', ':')):
                if random.random() < 0.5:
                    text = text + '.'

        # 3. Occasionally flip the case of the leading character.
        if random.random() < 0.1 and len(text) > 0:
            if text[0].isupper():
                text = text[0].lower() + text[1:]
            elif text[0].islower():
                text = text[0].upper() + text[1:]

        return text

    def __getitem__(self, idx):
        """Return ``(input_ids, target_ids)`` long tensors for next-token prediction."""
        entry = self.data[idx]

        # Augmentation (no-op when use_augmentation is False).
        instruction = self.augment_text(entry['instruction'])
        context = self.augment_text(entry['context'])
        response = self.augment_text(entry['response'])

        # Training prompt template; the context section is optional.
        if context:
            full_text = f"### Instrucción:\n{instruction}\n\n### Contexto:\n{context}\n\n### Respuesta:\n{response}"
        else:
            full_text = f"### Instrucción:\n{instruction}\n\n### Respuesta:\n{response}"

        # Tokenize and wrap with BOS/EOS markers.
        tokens = self.tokenizer.encode(full_text)
        tokens = [self.tokenizer.bos_id()] + tokens + [self.tokenizer.eos_id()]

        # Truncate to max_seq_len while keeping BOS first and EOS last.
        if len(tokens) > self.max_seq_len:
            tokens = [tokens[0]] + tokens[1:self.max_seq_len - 1] + [self.tokenizer.eos_id()]

        # Standard next-token shift: predict tokens[1:] from tokens[:-1].
        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        target_ids = torch.tensor(tokens[1:], dtype=torch.long)
        return input_ids, target_ids


def collate_fn(batch, pad_id=0):
    """Collate ``(input, target)`` pairs with dynamic right-padding.

    Inputs are padded with ``pad_id``; targets are padded with ``-100`` so
    the padded positions are ignored by CrossEntropyLoss (ignore_index=-100).
    Returns a pair of stacked ``(batch, max_len)`` long tensors.
    """
    input_ids = [item[0] for item in batch]
    target_ids = [item[1] for item in batch]

    # Pad only up to the longest sequence in this batch (dynamic padding).
    max_len = max(len(ids) for ids in input_ids)

    input_ids_padded = []
    target_ids_padded = []
    for inp, tgt in zip(input_ids, target_ids):
        pad_len = max_len - len(inp)
        input_ids_padded.append(
            torch.cat([inp, torch.full((pad_len,), pad_id, dtype=torch.long)])
        )
        target_ids_padded.append(
            torch.cat([tgt, torch.full((pad_len,), -100, dtype=torch.long)])
        )

    return torch.stack(input_ids_padded), torch.stack(target_ids_padded)