File size: 3,384 Bytes
9cc28a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import torch
from torch.utils.data import Dataset
import json
import random


class MTPDataset(Dataset):
    """Instruction/response dataset for causal-LM training with optional light augmentation.

    The corpus is a JSON Lines file with one JSON value per line. Only JSON
    objects that contain both an ``'instruction'`` and a ``'response'`` key are
    kept. Blank lines and non-object records are skipped.

    Each item is rendered with a fixed prompt template, tokenized, wrapped in
    BOS/EOS, and truncated to ``max_seq_len`` tokens. It is returned as a
    next-token-prediction pair ``(input_ids, target_ids)``, where
    ``target_ids`` is the token sequence shifted left by one position.
    """

    def __init__(self, corpus_path, tokenizer, max_seq_len=512,
                 use_augmentation=False, augmentation_prob=0.3):
        """Load the corpus eagerly into memory.

        Args:
            corpus_path: Path to the JSONL corpus file.
            tokenizer: Object exposing ``encode(text)`` (returning a list of
                token ids), ``bos_id()`` and ``eos_id()``; these are the only
                tokenizer calls this class makes.
            max_seq_len: Maximum number of tokens per example, including BOS
                and EOS. The returned tensors therefore have at most
                ``max_seq_len - 1`` elements. Must be >= 2.
            use_augmentation: If True, ``augment_text`` may perturb both the
                instruction and the response of each example.
            augmentation_prob: Chance that ``augment_text`` acts on a field.

        Raises:
            ValueError: If ``max_seq_len < 2``.
        """
        if max_seq_len < 2:
            # BOS + EOS alone need two slots. Smaller values broke the
            # truncation slice in __getitem__ (e.g. 0 effectively disabled it).
            raise ValueError(f"max_seq_len must be >= 2, got {max_seq_len}")

        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.use_augmentation = use_augmentation
        self.augmentation_prob = augmentation_prob
        self.data = []

        # 'utf-8-sig' drops a leading BOM if present; with plain 'utf-8' the BOM
        # stays in the first line and json.loads rejects it.
        with open(corpus_path, 'r', encoding='utf-8-sig') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # Tolerate blank lines (e.g. a trailing empty line).
                    continue
                entry = json.loads(line)
                # Require a dict: `'instruction' in some_list` would test list
                # membership and let records through that __getitem__ can't index.
                if isinstance(entry, dict) and 'instruction' in entry and 'response' in entry:
                    self.data.append(entry)

        print(f"✓ Loaded {len(self.data)} examples from corpus")
        if use_augmentation:
            print(f"✓ Data augmentation enabled (prob={augmentation_prob})")

    def __len__(self):
        """Number of examples kept from the corpus."""
        return len(self.data)

    def augment_text(self, text):
        """Randomly perturb the surface formatting of ``text``.

        Returns ``text`` unchanged unless augmentation is enabled and a uniform
        draw is <= ``augmentation_prob``. When it does apply:
          * with probability 0.3, leading/trailing whitespace is stripped;
          * with probability 0.2, a trailing '.' is removed, or a '.' is
            appended when the text ends with none of '.', '!' or '?'.
        """
        if not self.use_augmentation or random.random() > self.augmentation_prob:
            return text

        # Variation 1: strip surrounding whitespace (simulates formatting differences).
        if random.random() < 0.3:
            text = text.strip()

        # Variation 2: toggle the final period.
        if random.random() < 0.2:
            if text.endswith('.'):
                text = text[:-1]
            elif not text.endswith(('!', '?')):
                # '.' was already handled by the branch above.
                text = text + '.'

        return text

    def __getitem__(self, idx):
        """Return ``(input_ids, target_ids)`` as 1-D long tensors for example ``idx``.

        Both tensors have the same length (at most ``max_seq_len - 1``), and
        ``target_ids[i]`` is the token that follows ``input_ids[i]``.
        """
        entry = self.data[idx]

        # Augment instruction first, then response (fixes the RNG draw order).
        instruction = self.augment_text(entry['instruction'])
        response = self.augment_text(entry['response'])

        # Prompt template. NOTE(review): presumably must match the prompt used
        # at inference time; confirm before changing this runtime text.
        full_text = f"### Instrucción:\n{instruction}\n\n### Respuesta:\n{response}"

        tokens = [self.tokenizer.bos_id()] + self.tokenizer.encode(full_text) + [self.tokenizer.eos_id()]

        # Truncate to max_seq_len, keeping BOS first and forcing EOS last.
        if len(tokens) > self.max_seq_len:
            tokens = tokens[:self.max_seq_len - 1] + [self.tokenizer.eos_id()]

        # Shift by one position for next-token prediction.
        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        target_ids = torch.tensor(tokens[1:], dtype=torch.long)

        return input_ids, target_ids


def collate_fn(batch, pad_id=0, target_pad_id=None):
    """Collate ``(input_ids, target_ids)`` pairs into right-padded batch tensors.

    Sequences are padded only up to the longest one in this batch rather than
    to a global maximum, which keeps each batch as small as possible.

    Args:
        batch: Non-empty sequence of ``(input_ids, target_ids)`` pairs of 1-D
            tensors (as produced by ``MTPDataset.__getitem__``). Within a pair,
            both tensors are expected to have the same length.
        pad_id: Value used to pad ``input_ids``.
        target_pad_id: Value used to pad ``target_ids``. Defaults to ``pad_id``
            (the historical behavior). Passing e.g. ``-100`` lets padding match
            ``torch.nn.CrossEntropyLoss``'s default ``ignore_index`` so it can
            never collide with a real token id.

    Returns:
        Tuple ``(inputs, targets)`` of tensors shaped ``(batch_size, max_len)``.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    from torch.nn.utils.rnn import pad_sequence

    if not batch:
        raise ValueError("collate_fn received an empty batch")
    if target_pad_id is None:
        target_pad_id = pad_id

    input_ids = [item[0] for item in batch]
    target_ids = [item[1] for item in batch]

    # pad_sequence right-pads to the longest sequence and stacks along dim 0.
    inputs = pad_sequence(input_ids, batch_first=True, padding_value=pad_id)
    targets = pad_sequence(target_ids, batch_first=True, padding_value=target_pad_id)
    return inputs, targets