import torch
from torch.utils.data import Dataset
import json
import random
class MTPDataset(Dataset):
    """Dataset for instruction/context/response corpora stored as JSONL.

    Each line of the corpus file must be a JSON object with the keys
    ``instruction`` and ``response``; ``context`` is optional and may be
    empty. Malformed or incomplete lines are skipped with a warning.
    Items are rendered into a prompt template, tokenized, wrapped in
    BOS/EOS, and returned as ``(input_ids, target_ids)`` tensors shifted
    by one position for next-token prediction.
    """

    def __init__(self, corpus_path, tokenizer, max_seq_len=2048,
                 use_augmentation=True, augmentation_prob=0.3):
        """Load and validate the JSONL corpus.

        Args:
            corpus_path: Path to a JSONL file, one example per line.
            tokenizer: Object exposing ``encode``, ``bos_id`` and ``eos_id``
                (SentencePiece-style interface — confirm against caller).
            max_seq_len: Hard cap on tokenized length, BOS/EOS included.
            use_augmentation: Enable light text augmentation in __getitem__.
            augmentation_prob: Per-field probability of augmenting.
        """
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.use_augmentation = use_augmentation
        self.augmentation_prob = augmentation_prob
        self.data = []
        # Load corpus
        print(f" → Cargando corpus: {corpus_path}")
        valid_count = 0
        # FIX: initialize before the loop — an empty corpus file previously
        # raised NameError at the summary print below.
        line_num = 0
        with open(corpus_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue
                # Keep the try body minimal: only json.loads can raise
                # JSONDecodeError; validation happens after parsing.
                try:
                    entry = json.loads(line)
                except json.JSONDecodeError as e:
                    print(f" ❌ Línea {line_num}: JSON inválido - {e}")
                    continue
                # Both mandatory fields must exist and be non-empty.
                if 'instruction' not in entry or 'response' not in entry:
                    print(f" ⚠ Línea {line_num}: Falta 'instruction' o 'response'")
                    continue
                instruction = entry['instruction'].strip()
                response = entry['response'].strip()
                if not instruction or not response:
                    print(f" ⚠ Línea {line_num}: Campos vacíos")
                    continue
                # 'context' is optional; missing or blank means "no context".
                context = entry.get('context', '').strip()
                self.data.append({
                    'instruction': instruction,
                    'context': context,
                    'response': response
                })
                valid_count += 1
        print(f" ✓ Cargados {valid_count} ejemplos válidos de {line_num} líneas")
        if use_augmentation:
            print(f" ✓ Augmentación activada (prob={augmentation_prob})")

    def __len__(self):
        """Return the number of valid examples loaded from the corpus."""
        return len(self.data)

    def augment_text(self, text):
        """Apply light stochastic text augmentation.

        With probability ``augmentation_prob`` the text gets small surface
        variations (trailing-period toggling, initial-letter case flip).
        Returns the text unchanged when augmentation is disabled, the
        probability roll misses, or the text is empty.
        """
        if not self.use_augmentation or random.random() > self.augmentation_prob or not text:
            return text
        # 1. Normalize surrounding whitespace.
        text = text.strip()
        # 2. Vary trailing punctuation: sometimes drop a final period, or
        #    add one when the text ends without terminal punctuation.
        if random.random() < 0.25:
            if text.endswith('.'):
                if random.random() < 0.5:
                    text = text[:-1]
            elif not text.endswith(('.', '!', '?', ':')):
                if random.random() < 0.5:
                    text = text + '.'
        # 3. Occasionally flip the case of the first character.
        if random.random() < 0.1 and len(text) > 0:
            if text[0].isupper():
                text = text[0].lower() + text[1:]
            elif text[0].islower():
                text = text[0].upper() + text[1:]
        return text

    def __getitem__(self, idx):
        """Build the ``(input_ids, target_ids)`` pair for example ``idx``."""
        entry = self.data[idx]
        # Augmentation is a no-op when use_augmentation is False.
        instruction = self.augment_text(entry['instruction'])
        context = self.augment_text(entry['context'])
        response = self.augment_text(entry['response'])
        # Prompt template; the context section is emitted only when present.
        if context:
            full_text = f"### Instrucción:\n{instruction}\n\n### Contexto:\n{context}\n\n### Respuesta:\n{response}"
        else:
            full_text = f"### Instrucción:\n{instruction}\n\n### Respuesta:\n{response}"
        # Tokenize and wrap in BOS/EOS.
        tokens = [self.tokenizer.bos_id()] + self.tokenizer.encode(full_text) + [self.tokenizer.eos_id()]
        # Truncate to max_seq_len, keeping BOS at the front and forcing EOS
        # back onto the end of the clipped sequence.
        if len(tokens) > self.max_seq_len:
            tokens = [tokens[0]] + tokens[1:self.max_seq_len - 1] + [self.tokenizer.eos_id()]
        # Shift by one: predict token t+1 from tokens <= t.
        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        target_ids = torch.tensor(tokens[1:], dtype=torch.long)
        return input_ids, target_ids
def collate_fn(batch, pad_id=0):
    """Pad a batch of ``(input_ids, target_ids)`` pairs to the batch max.

    Inputs are padded with ``pad_id``; targets are padded with -100, the
    default ``ignore_index`` of CrossEntropyLoss, so padded positions do
    not contribute to the loss. Padding is dynamic per batch.
    """
    inputs, targets = zip(*batch)
    batch_max = max(seq.size(0) for seq in inputs)
    # Preallocate fully-padded matrices, then copy each sequence in place.
    padded_inputs = torch.full((len(batch), batch_max), pad_id, dtype=torch.long)
    padded_targets = torch.full((len(batch), batch_max), -100, dtype=torch.long)
    for row, (inp, tgt) in enumerate(zip(inputs, targets)):
        padded_inputs[row, :inp.size(0)] = inp
        padded_targets[row, :tgt.size(0)] = tgt
    return padded_inputs, padded_targets