File size: 5,608 Bytes
563bb6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import torch
from torch.utils.data import Dataset
import json
import random


class MTPDataset(Dataset):
    """Dataset for corpora in instruction-context-response JSONL format.

    Each line of the corpus file must be a JSON object with required keys
    ``instruction`` and ``response`` and an optional ``context``. Lines that
    are blank, malformed, or missing required fields are skipped with a
    warning printed to stdout.
    """

    def __init__(self, corpus_path, tokenizer, max_seq_len=2048,
                 use_augmentation=True, augmentation_prob=0.3):
        """Load and validate the corpus.

        Args:
            corpus_path: Path to a JSONL file, one JSON object per line.
            tokenizer: Tokenizer exposing ``encode()``, ``bos_id()`` and
                ``eos_id()`` (SentencePiece-style interface — confirm).
            max_seq_len: Hard cap on token count (BOS/EOS included).
            use_augmentation: Whether to apply light text augmentation.
            augmentation_prob: Per-field probability of augmentation.
        """
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.use_augmentation = use_augmentation
        self.augmentation_prob = augmentation_prob
        self.data = []

        # Load corpus.
        print(f"   → Cargando corpus: {corpus_path}")
        valid_count = 0
        # Bind line_num before the loop: on an empty file the for loop never
        # runs, and the summary print below would raise UnboundLocalError.
        line_num = 0
        with open(corpus_path, 'r', encoding='utf-8') as f:
            for line_num, line in enumerate(f, 1):
                line = line.strip()
                if not line:
                    continue

                try:
                    entry = json.loads(line)

                    # Validate required fields.
                    if 'instruction' not in entry or 'response' not in entry:
                        print(f"   ⚠ Línea {line_num}: Falta 'instruction' o 'response'")
                        continue

                    instruction = entry['instruction'].strip()
                    response = entry['response'].strip()

                    if not instruction or not response:
                        print(f"   ⚠ Línea {line_num}: Campos vacíos")
                        continue

                    # Context may be empty or absent.
                    context = entry.get('context', '').strip()

                    self.data.append({
                        'instruction': instruction,
                        'context': context,
                        'response': response
                    })
                    valid_count += 1

                except json.JSONDecodeError as e:
                    print(f"   ❌ Línea {line_num}: JSON inválido - {e}")
                    continue

        print(f"   ✓ Cargados {valid_count} ejemplos válidos de {line_num} líneas")
        if use_augmentation:
            print(f"   ✓ Augmentación activada (prob={augmentation_prob})")

    def __len__(self):
        """Number of valid examples loaded from the corpus."""
        return len(self.data)

    def augment_text(self, text):
        """Lightly augment *text*; returns it unchanged with probability
        ``1 - augmentation_prob``, when augmentation is disabled, or when
        *text* is empty."""
        if not self.use_augmentation or random.random() > self.augmentation_prob or not text:
            return text

        # 1. Normalize surrounding whitespace.
        text = text.strip()

        # 2. Randomly vary trailing punctuation (drop a final '.' or add one).
        if random.random() < 0.25:
            if text.endswith('.'):
                if random.random() < 0.5:
                    text = text[:-1]
            elif not text.endswith(('.', '!', '?', ':')):
                if random.random() < 0.5:
                    text = text + '.'

        # 3. Randomly flip the case of the first character.
        if random.random() < 0.1 and len(text) > 0:
            if text[0].isupper():
                text = text[0].lower() + text[1:]
            elif text[0].islower():
                text = text[0].upper() + text[1:]

        return text

    def __getitem__(self, idx):
        """Return ``(input_ids, target_ids)`` LongTensors for example *idx*,
        shifted by one position for next-token prediction."""
        entry = self.data[idx]

        instruction = entry['instruction']
        context = entry['context']
        response = entry['response']

        # Apply (optional, stochastic) augmentation per field.
        instruction = self.augment_text(instruction)
        context = self.augment_text(context)
        response = self.augment_text(response)

        # Prompt template; the context section is included only when present.
        if context:
            full_text = f"### Instrucción:\n{instruction}\n\n### Contexto:\n{context}\n\n### Respuesta:\n{response}"
        else:
            full_text = f"### Instrucción:\n{instruction}\n\n### Respuesta:\n{response}"

        # Tokenize.
        tokens = self.tokenizer.encode(full_text)

        # Add BOS and EOS.
        tokens = [self.tokenizer.bos_id()] + tokens + [self.tokenizer.eos_id()]

        # Truncate if too long, preserving BOS at the front and EOS at the end.
        if len(tokens) > self.max_seq_len:
            tokens = [tokens[0]] + tokens[1:self.max_seq_len-1] + [self.tokenizer.eos_id()]

        input_ids = torch.tensor(tokens[:-1], dtype=torch.long)
        target_ids = torch.tensor(tokens[1:], dtype=torch.long)

        return input_ids, target_ids


def collate_fn(batch, pad_id=0):
    """Collate ``(input_ids, target_ids)`` pairs into padded batch tensors.

    Sequences are padded dynamically to the longest sequence in this batch.
    Inputs are padded with ``pad_id``; targets are padded with -100, the
    default ``ignore_index`` of ``CrossEntropyLoss``, so padded positions
    contribute nothing to the loss.

    Returns:
        A ``(inputs, targets)`` tuple of LongTensors shaped
        ``(batch, max_len_in_batch)``.
    """
    inputs, targets = zip(*batch)

    # pad_sequence handles the per-sequence padding and stacking in one call,
    # preserving the long dtype of the incoming tensors.
    padded_inputs = torch.nn.utils.rnn.pad_sequence(
        inputs, batch_first=True, padding_value=pad_id
    )
    padded_targets = torch.nn.utils.rnn.pad_sequence(
        targets, batch_first=True, padding_value=-100
    )

    return padded_inputs, padded_targets