# MTP3.7 / dataset.py
import torch
from torch.utils.data import Dataset
import json
import random
class MTPDataset(Dataset):
    """Instruction/response dataset with optional light text augmentation.

    Reads a JSONL corpus where each non-empty line is a JSON object; only
    entries containing both 'instruction' and 'response' keys are kept.
    Each item is formatted as an instruction/response prompt, tokenized,
    wrapped with BOS/EOS, and returned as a (input_ids, target_ids) pair
    shifted by one position for next-token prediction.
    """
    def __init__(self, corpus_path, tokenizer, max_seq_len=512,
                 use_augmentation=False, augmentation_prob=0.3):
        """
        Args:
            corpus_path: Path to a JSONL file (one JSON object per line).
            tokenizer: Object exposing encode(str) -> list[int], bos_id(), eos_id().
            max_seq_len: Hard cap on token count (including BOS and EOS).
            use_augmentation: Enable random text augmentation in augment_text().
            augmentation_prob: Probability that augmentation is applied to a text.
        """
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.use_augmentation = use_augmentation
        self.augmentation_prob = augmentation_prob
        self.data = []
        # Load corpus: keep only well-formed entries with both required keys.
        with open(corpus_path, 'r', encoding='utf-8') as f:
            for line in f:
                # Skip blank/whitespace-only lines instead of crashing in json.loads.
                if not line.strip():
                    continue
                entry = json.loads(line)
                if 'instruction' in entry and 'response' in entry:
                    self.data.append(entry)
        print(f"✓ Loaded {len(self.data)} examples from corpus")
        if use_augmentation:
            print(f"✓ Data augmentation enabled (prob={augmentation_prob})")

    def __len__(self):
        """Number of usable corpus entries."""
        return len(self.data)

    def augment_text(self, text):
        """Apply simple random text augmentation (no-op unless enabled).

        Returns the text unchanged unless augmentation is enabled AND the
        augmentation_prob coin flip succeeds.
        """
        if not self.use_augmentation or random.random() > self.augmentation_prob:
            return text
        # Variant 1: normalize surrounding whitespace (simulates format variation).
        if random.random() < 0.3:
            text = text.strip()
        # Variant 2: toggle final punctuation — drop a trailing '.', or add one
        # when the text ends with no sentence-final punctuation at all.
        if random.random() < 0.2:
            if text.endswith('.'):
                text = text[:-1]
            elif not text.endswith(('.', '!', '?')):
                text = text + '.'
        return text

    def __getitem__(self, idx):
        """Return (input_ids, target_ids) LongTensors shifted by one token."""
        entry = self.data[idx]
        instruction = self.augment_text(entry['instruction'])
        response = self.augment_text(entry['response'])
        # Prompt template — must stay byte-identical across training/inference.
        full_text = f"### Instrucción:\n{instruction}\n\n### Respuesta:\n{response}"
        tokens = self.tokenizer.encode(full_text)
        # Wrap with BOS and EOS markers.
        tokens = [self.tokenizer.bos_id()] + tokens + [self.tokenizer.eos_id()]
        # Truncate while keeping BOS and EOS:
        # 1 (BOS) + (max_seq_len - 2) middle tokens + 1 (EOS) == max_seq_len.
        if len(tokens) > self.max_seq_len:
            tokens = [tokens[0]] + tokens[1:self.max_seq_len-1] + [self.tokenizer.eos_id()]
        # Shift by one: inputs predict the next token at each position.
        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        target_ids = torch.tensor(tokens[1:], dtype=torch.long)
        return input_ids, target_ids
def collate_fn(batch, pad_id=0, *, label_pad_id=None):
    """Collate (input_ids, target_ids) pairs, right-padding to the batch max length.

    Args:
        batch: List of (input_ids, target_ids) 1-D LongTensor pairs.
        pad_id: Token id used to pad the inputs (default 0).
        label_pad_id: Optional id used to pad the targets. Defaults to pad_id
            (original behavior). Pass e.g. -100 so padded positions are ignored
            by CrossEntropyLoss(ignore_index=-100) instead of training on pad_id.

    Returns:
        Tuple of stacked (batch, max_len) LongTensors: (inputs, targets).
    """
    if label_pad_id is None:
        label_pad_id = pad_id
    input_ids = [item[0] for item in batch]
    target_ids = [item[1] for item in batch]
    # Pad only up to the longest sequence in this batch, not the global max.
    max_len = max(len(ids) for ids in input_ids)
    input_ids_padded = []
    target_ids_padded = []
    for inp, tgt in zip(input_ids, target_ids):
        pad_len = max_len - len(inp)
        input_ids_padded.append(torch.cat([inp, torch.full((pad_len,), pad_id, dtype=torch.long)]))
        target_ids_padded.append(torch.cat([tgt, torch.full((pad_len,), label_pad_id, dtype=torch.long)]))
    return torch.stack(input_ids_padded), torch.stack(target_ids_padded)