import json
import random

import torch
from torch.utils.data import Dataset
|
| |
|
| |
|
class MTPDataset(Dataset):
    """Dataset for corpora in instruction-context-response JSONL format.

    Each non-empty line of the corpus file must be a JSON object with the
    mandatory keys ``instruction`` and ``response`` and an optional
    ``context`` key.  Invalid or incomplete lines are skipped with a
    console warning instead of aborting the load.
    """

    def __init__(self, corpus_path, tokenizer, max_seq_len=2048,
                 use_augmentation=True, augmentation_prob=0.3):
        """Load and validate the corpus.

        Args:
            corpus_path: Path to a JSONL corpus file (one JSON object per line).
            tokenizer: Tokenizer exposing ``encode``, ``bos_id`` and ``eos_id``.
            max_seq_len: Maximum token sequence length, BOS/EOS included.
            use_augmentation: Whether to apply light text augmentation.
            augmentation_prob: Probability of augmenting each text field.
        """
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.use_augmentation = use_augmentation
        self.augmentation_prob = augmentation_prob
        self.data = []

        print(f" → Cargando corpus: {corpus_path}")
        valid_count = 0
        line_num = 0  # BUGFIX: keep defined when the corpus file is empty
        with open(corpus_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue

                try:
                    entry = json.loads(line)

                    if 'instruction' not in entry or 'response' not in entry:
                        print(f" ⚠ Línea {line_num}: Falta 'instruction' o 'response'")
                        continue

                    instruction = entry['instruction'].strip()
                    response = entry['response'].strip()

                    if not instruction or not response:
                        print(f" ⚠ Línea {line_num}: Campos vacíos")
                        continue

                    # BUGFIX: tolerate an explicit JSON null in 'context'
                    # (json.loads maps null -> None, which has no .strip()).
                    context = (entry.get('context') or '').strip()

                    self.data.append({
                        'instruction': instruction,
                        'context': context,
                        'response': response
                    })
                    valid_count += 1

                except json.JSONDecodeError as e:
                    print(f" ❌ Línea {line_num}: JSON inválido - {e}")
                    continue

        print(f" ✓ Cargados {valid_count} ejemplos válidos de {line_num} líneas")
        if use_augmentation:
            print(f" ✓ Augmentación activada (prob={augmentation_prob})")

    def __len__(self):
        """Number of valid examples loaded from the corpus."""
        return len(self.data)

    def augment_text(self, text):
        """Lightly perturb punctuation and capitalization of *text*.

        Returns *text* unchanged when augmentation is disabled, the
        augmentation probability is not sampled, or the text is empty.
        """
        if not self.use_augmentation or random.random() > self.augmentation_prob or not text:
            return text

        text = text.strip()

        # Occasionally toggle the trailing period.
        if random.random() < 0.25:
            if text.endswith('.'):
                if random.random() < 0.5:
                    text = text[:-1]
            elif not text.endswith(('.', '!', '?', ':')):
                if random.random() < 0.5:
                    text = text + '.'

        # Occasionally flip the case of the first character.
        if random.random() < 0.1 and len(text) > 0:
            if text[0].isupper():
                text = text[0].lower() + text[1:]
            elif text[0].islower():
                text = text[0].upper() + text[1:]

        return text

    def __getitem__(self, idx):
        """Build one training pair of LongTensors shifted by one position.

        Returns:
            Tuple ``(input_ids, target_ids)`` where ``target_ids`` is
            ``input_ids`` shifted left by one token (next-token targets).
        """
        entry = self.data[idx]

        instruction = self.augment_text(entry['instruction'])
        context = self.augment_text(entry['context'])
        response = self.augment_text(entry['response'])

        # Prompt template; the context section is omitted when empty.
        if context:
            full_text = f"### Instrucción:\n{instruction}\n\n### Contexto:\n{context}\n\n### Respuesta:\n{response}"
        else:
            full_text = f"### Instrucción:\n{instruction}\n\n### Respuesta:\n{response}"

        tokens = self.tokenizer.encode(full_text)
        tokens = [self.tokenizer.bos_id()] + tokens + [self.tokenizer.eos_id()]

        # Truncate while preserving BOS at the start and EOS at the end,
        # keeping the total length at exactly max_seq_len.
        if len(tokens) > self.max_seq_len:
            tokens = [tokens[0]] + tokens[1:self.max_seq_len - 1] + [self.tokenizer.eos_id()]

        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        target_ids = torch.tensor(tokens[1:], dtype=torch.long)

        return input_ids, target_ids
|
| |
|
| |
|
def collate_fn(batch, pad_id=0):
    """Collate ``(input_ids, target_ids)`` pairs with dynamic padding.

    Inputs are right-padded with ``pad_id``; targets are right-padded with
    -100 so padded positions are skipped by losses using
    ``ignore_index=-100`` (e.g. ``torch.nn.CrossEntropyLoss``).

    Args:
        batch: List of ``(input_ids, target_ids)`` 1-D LongTensor pairs.
        pad_id: Token id used to pad the input sequences.

    Returns:
        Tuple of two ``(batch, max_len)`` LongTensors: padded inputs and
        padded targets.
    """
    input_ids, target_ids = zip(*batch)

    # pad_sequence performs the pad-to-max-and-stack loop in one C call.
    inputs = torch.nn.utils.rnn.pad_sequence(
        input_ids, batch_first=True, padding_value=pad_id
    )
    targets = torch.nn.utils.rnn.pad_sequence(
        target_ids, batch_first=True, padding_value=-100
    )
    return inputs, targets