import torch
from torch.utils.data import Dataset
import json
import random
class MTPDataset(Dataset):
    """Instruction/response dataset loaded from a JSONL corpus, with optional
    light text augmentation.

    Each line of the corpus file must be a JSON object; only entries that
    contain both an 'instruction' and a 'response' key are kept. Items are
    returned as (input_ids, target_ids) LongTensor pairs shifted by one token
    for next-token prediction.
    """

    def __init__(self, corpus_path, tokenizer, max_seq_len=512,
                 use_augmentation=False, augmentation_prob=0.3):
        """
        Args:
            corpus_path: path to a JSONL file, one JSON object per line.
            tokenizer: object exposing ``encode(str) -> list[int]``,
                ``bos_id() -> int`` and ``eos_id() -> int``.
            max_seq_len: hard cap on token count (BOS/EOS included).
            use_augmentation: enable random text perturbations in training.
            augmentation_prob: probability of augmenting any given text.
        """
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.use_augmentation = use_augmentation
        self.augmentation_prob = augmentation_prob
        self.data = []
        # Load corpus; skip blank lines (e.g. a trailing newline) instead of
        # letting json.loads raise JSONDecodeError on an empty string.
        with open(corpus_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                entry = json.loads(line)
                if 'instruction' in entry and 'response' in entry:
                    self.data.append(entry)
        print(f"✓ Loaded {len(self.data)} examples from corpus")
        if use_augmentation:
            print(f"✓ Data augmentation enabled (prob={augmentation_prob})")

    def __len__(self):
        return len(self.data)

    def augment_text(self, text):
        """Apply simple random perturbations to *text* (no-op unless enabled).

        With probability ``augmentation_prob`` the text may be:
        - stripped of leading/trailing whitespace (simulates format variation), and/or
        - have its final punctuation toggled ('.' removed, or '.' appended when
          the text ends without terminal punctuation).
        """
        if not self.use_augmentation or random.random() > self.augmentation_prob:
            return text
        # Variation 1: normalize surrounding whitespace.
        if random.random() < 0.3:
            text = text.strip()
        # Variation 2: toggle the final punctuation mark.
        if random.random() < 0.2:
            if text.endswith('.'):
                text = text[:-1]
            elif not text.endswith(('.', '!', '?')):
                text = text + '.'
        return text

    def __getitem__(self, idx):
        """Return (input_ids, target_ids) LongTensors for example *idx*.

        target_ids is input_ids shifted left by one position (standard
        next-token-prediction setup).
        """
        entry = self.data[idx]
        instruction = entry['instruction']
        response = entry['response']
        # Optionally perturb both sides (no-op when augmentation is off).
        instruction = self.augment_text(instruction)
        response = self.augment_text(response)
        # Prompt template used for training.
        full_text = f"### Instrucción:\n{instruction}\n\n### Respuesta:\n{response}"
        tokens = self.tokenizer.encode(full_text)
        # Wrap with BOS/EOS markers.
        tokens = [self.tokenizer.bos_id()] + tokens + [self.tokenizer.eos_id()]
        # Truncate to max_seq_len while keeping BOS first and EOS last:
        # 1 (BOS) + (max_seq_len - 2) middle tokens + 1 (EOS) == max_seq_len.
        if len(tokens) > self.max_seq_len:
            tokens = [tokens[0]] + tokens[1:self.max_seq_len-1] + [self.tokenizer.eos_id()]
        # Shift by one token: model sees tokens[:-1], predicts tokens[1:].
        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        target_ids = torch.tensor(tokens[1:], dtype=torch.long)
        return input_ids, target_ids
def collate_fn(batch, pad_id=0):
    """Collate (input_ids, target_ids) pairs into right-padded batch tensors.

    Args:
        batch: list of (input_ids, target_ids) 1-D LongTensor pairs.
        pad_id: token id used to right-pad both tensors to the batch max length.

    Returns:
        Tuple of two LongTensors, each shaped (batch_size, max_len).
    """
    inputs, targets = zip(*batch)
    pad = torch.nn.utils.rnn.pad_sequence
    padded_inputs = pad(list(inputs), batch_first=True, padding_value=pad_id)
    padded_targets = pad(list(targets), batch_first=True, padding_value=pad_id)
    return padded_inputs, padded_targets