teszenofficial committed on
Commit
b53865a
·
verified ·
1 Parent(s): fc3b75f

Delete dataset.py

Browse files
Files changed (1) hide show
  1. dataset.py +0 -124
dataset.py DELETED
@@ -1,124 +0,0 @@
1
- import torch
2
- from torch.utils.data import Dataset
3
- import json
4
- import random
5
-
6
-
7
class MTPDataset(Dataset):
    """Instruction/response dataset with optional light text augmentation.

    Reads a JSONL corpus where each line is an object containing non-empty
    ``instruction`` and ``response`` fields, and yields ``(input_ids,
    target_ids)`` pairs shifted by one token for next-token prediction.
    """

    def __init__(self, corpus_path, tokenizer, max_seq_len=2048,
                 use_augmentation=True, augmentation_prob=0.3):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len
        self.use_augmentation = use_augmentation
        self.augmentation_prob = augmentation_prob
        self.data = []

        # Load and validate the JSONL corpus line by line.
        print(f" → Cargando corpus: {corpus_path}")
        with open(corpus_path, 'r', encoding='utf-8') as f:
            for raw in f:
                raw = raw.strip()
                if not raw:
                    continue
                try:
                    record = json.loads(raw)
                except json.JSONDecodeError:
                    # Skip malformed lines silently (best-effort loading).
                    continue
                if 'instruction' not in record or 'response' not in record:
                    continue
                # Keep only entries whose fields are non-empty after stripping.
                if record['instruction'].strip() and record['response'].strip():
                    self.data.append(record)

        print(f" ✓ Cargados {len(self.data)} ejemplos válidos")
        if use_augmentation:
            print(f" ✓ Augmentación activada (prob={augmentation_prob})")

    def __len__(self):
        return len(self.data)

    def augment_text(self, text):
        """Apply small random formatting perturbations to *text*.

        Returns *text* unchanged when augmentation is disabled or the
        per-call probability gate does not fire.
        """
        if not self.use_augmentation:
            return text
        if random.random() > self.augmentation_prob:
            return text

        # 1. Occasionally normalize surrounding whitespace.
        if random.random() < 0.3:
            text = text.strip()

        # 2. Occasionally vary the trailing punctuation.
        if random.random() < 0.25:
            if text.endswith('.'):
                # Sometimes drop the final period.
                if random.random() < 0.5:
                    text = text[:-1]
            elif not text.endswith(('.', '!', '?', ':')):
                # Sometimes append a period.
                if random.random() < 0.5:
                    text = text + '.'

        # 3. Rarely flip the case of the leading character.
        if random.random() < 0.1 and text:
            head = text[0]
            if head.isupper():
                text = head.lower() + text[1:]
            elif head.islower():
                text = head.upper() + text[1:]

        return text

    def __getitem__(self, idx):
        record = self.data[idx]

        prompt = self.augment_text(record['instruction'])
        answer = self.augment_text(record['response'])

        # Prompt template used for training.
        full_text = f"### Instrucción:\n{prompt}\n\n### Respuesta:\n{answer}"

        token_ids = self.tokenizer.encode(full_text)
        # Wrap with BOS/EOS sentinels.
        token_ids = [self.tokenizer.bos_id()] + token_ids + [self.tokenizer.eos_id()]

        # Truncate overly long sequences while keeping BOS and EOS intact.
        if len(token_ids) > self.max_seq_len:
            token_ids = [token_ids[0], *token_ids[1:self.max_seq_len - 1],
                         self.tokenizer.eos_id()]

        # Next-token prediction: targets are the inputs shifted left by one.
        input_ids = torch.tensor(token_ids[:-1], dtype=torch.long)
        target_ids = torch.tensor(token_ids[1:], dtype=torch.long)

        return input_ids, target_ids
97
-
98
-
99
def collate_fn(batch, pad_id=0):
    """Collate ``(input_ids, target_ids)`` pairs with dynamic padding.

    Pads every sequence in the batch to the length of the longest one.
    Inputs are padded with ``pad_id``; targets are padded with ``-100`` so
    the padded positions are ignored by ``CrossEntropyLoss(ignore_index=-100)``.

    Args:
        batch: list of ``(input_ids, target_ids)`` 1-D long-tensor pairs,
            where each pair has equal length (as produced by ``MTPDataset``).
        pad_id: token id used to pad the input sequences.

    Returns:
        Tuple of two stacked long tensors of shape ``(batch, max_len)``.
    """
    input_ids = [item[0] for item in batch]
    target_ids = [item[1] for item in batch]

    # pad_sequence replaces the hand-rolled max-length/torch.cat loop; since
    # each (input, target) pair has identical length, both calls pad to the
    # same batch-wide maximum.
    input_ids_padded = torch.nn.utils.rnn.pad_sequence(
        input_ids, batch_first=True, padding_value=pad_id
    )
    target_ids_padded = torch.nn.utils.rnn.pad_sequence(
        target_ids, batch_first=True, padding_value=-100
    )

    return input_ids_padded, target_ids_padded