NICOMOSHE commited on
Commit
efb42fb
·
verified ·
1 Parent(s): 9746962

Upload train_v3.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. train_v3.py +361 -0
train_v3.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Entrenamiento Optimizado V3 - Seq2Seq Simple pero Efectivo
3
+ Taller: Traductor Automatico RNN bajo CRISP-ML(Q)
4
+ """
5
+
6
+ import time
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.optim as optim
11
+ from torch.utils.data import Dataset, DataLoader
12
+ import torch.nn.functional as F
13
+
14
+ print("=" * 60)
15
+ print("ENTRENAMIENTO OPTIMIZADO - SEQ2SEQ")
16
+ print("=" * 60)
17
+
18
+ start_time = time.time()
19
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
20
+ print(f"\n[INFO] Dispositivo: {device}")
21
+
22
+ CORPUS = [
23
+ ("hello", "hola"), ("goodbye", "adios"), ("good morning", "buenos dias"),
24
+ ("good night", "buenas noches"), ("see you later", "hasta luego"),
25
+ ("thank you", "gracias"), ("thank you very much", "muchas gracias"),
26
+ ("please", "por favor"), ("you are welcome", "de nada"),
27
+ ("excuse me", "disculpe"), ("sorry", "lo siento"),
28
+ ("yes", "si"), ("no", "no"), ("maybe", "quizas"),
29
+ ("of course", "por supuesto"),
30
+ ("i", "yo"), ("you", "tu"), ("he", "el"), ("she", "ella"),
31
+ ("we", "nosotros"), ("they", "ellos"),
32
+ ("i am a student", "soy estudiante"), ("you are a teacher", "tu eres maestro"),
33
+ ("he is a professor", "el es profesor"), ("she is a student", "ella es estudiante"),
34
+ ("we are friends", "somos amigos"), ("what is your name", "cual es tu nombre"),
35
+ ("my name is john", "me llamo john"), ("nice to meet you", "mucho gusto"),
36
+ ("father", "padre"), ("mother", "madre"), ("brother", "hermano"),
37
+ ("sister", "hermana"), ("son", "hijo"), ("daughter", "hija"),
38
+ ("university", "universidad"), ("class", "clase"), ("professor", "profesor"),
39
+ ("student", "estudiante"), ("exam", "examen"), ("homework", "tarea"),
40
+ ("i study at the university", "estudio en la universidad"),
41
+ ("the class starts at eight", "la clase empieza a las ocho"),
42
+ ("the exam is difficult", "el examen es dificil"),
43
+ ("i need a book", "necesito un libro"),
44
+ ("where is the library", "donde esta la biblioteca"),
45
+ ("the professor is strict", "el profesor es estricto"),
46
+ ("i have a class at nine", "tengo clase a las nueve"),
47
+ ("the lecture is interesting", "la conferencia es interesante"),
48
+ ("when is the exam", "cuando es el examen"),
49
+ ("i passed the exam", "aprobe el examen"),
50
+ ("i need to study", "necesito estudiar"),
51
+ ("i am late for class", "llegue tarde a clase"),
52
+ ("one", "uno"), ("two", "dos"), ("three", "tres"),
53
+ ("four", "cuatro"), ("five", "cinco"), ("six", "seis"),
54
+ ("seven", "siete"), ("eight", "ocho"), ("nine", "nueve"),
55
+ ("ten", "diez"),
56
+ ("monday", "lunes"), ("tuesday", "martes"), ("wednesday", "miercoles"),
57
+ ("thursday", "jueves"), ("friday", "viernes"), ("saturday", "sabado"),
58
+ ("sunday", "domingo"), ("today", "hoy"), ("tomorrow", "manana"),
59
+ ("time", "tiempo"), ("hour", "hora"), ("minute", "minuto"),
60
+ ("now", "ahora"), ("later", "despues"), ("early", "temprano"),
61
+ ("late", "tarde"), ("always", "siempre"), ("never", "nunca"),
62
+ ("here", "aqui"), ("there", "alli"), ("where", "donde"),
63
+ ("city", "ciudad"), ("country", "pais"), ("home", "casa"),
64
+ ("office", "oficina"), ("library", "biblioteca"), ("cafe", "cafe"),
65
+ ("park", "parque"),
66
+ ("to be", "ser"), ("to have", "tener"), ("to do", "hacer"),
67
+ ("to go", "ir"), ("to come", "venir"),
68
+ ("to see", "ver"), ("to know", "saber"), ("to think", "pensar"),
69
+ ("to want", "querer"), ("to need", "necesitar"), ("to like", "gustar"),
70
+ ("to learn", "aprender"), ("to teach", "enseñar"), ("to study", "estudiar"),
71
+ ("to work", "trabajar"), ("to live", "vivir"), ("to eat", "comer"),
72
+ ("to drink", "beber"), ("to speak", "hablar"), ("to write", "escribir"),
73
+ ("to read", "leer"), ("to understand", "entender"),
74
+ ("to help", "ayudar"), ("to start", "empezar"), ("to finish", "terminar"),
75
+ ("book", "libro"), ("pen", "lapiz"), ("paper", "papel"),
76
+ ("computer", "computadora"), ("phone", "telefono"), ("table", "mesa"),
77
+ ("chair", "silla"), ("door", "puerta"), ("window", "ventana"),
78
+ ("food", "comida"), ("water", "agua"), ("coffee", "cafe"),
79
+ ("good", "bueno"), ("bad", "malo"), ("big", "grande"),
80
+ ("small", "pequeño"), ("new", "nuevo"), ("old", "viejo"),
81
+ ("fast", "rapido"), ("slow", "lento"), ("easy", "facil"),
82
+ ("difficult", "dificil"), ("important", "importante"),
83
+ ("interesting", "interesante"), ("beautiful", "hermoso"),
84
+ ("happy", "feliz"), ("sad", "triste"),
85
+ ("what", "que"), ("who", "quien"), ("when", "cuando"),
86
+ ("why", "por que"), ("how", "como"),
87
+ ("how are you", "como estas"), ("how much", "cuanto"),
88
+ ("what time is it", "que hora es"),
89
+ ("science", "ciencia"), ("math", "matematicas"), ("history", "historia"),
90
+ ("art", "arte"), ("music", "musica"), ("language", "idioma"),
91
+ ("english", "ingles"), ("spanish", "espanol"),
92
+ ("computer science", "ciencias de la computacion"),
93
+ ("information", "informacion"), ("technology", "tecnologia"),
94
+ ]
95
+
96
+ for esp, ing in list(CORPUS):
97
+ if (ing, esp) not in CORPUS:
98
+ CORPUS.append((ing, esp))
99
+
100
+ print(f"[INFO] Corpus: {len(CORPUS)} parejas")
101
+
102
+ PAD = "<PAD>"
103
+ UNK = "<UNK>"
104
+ SOS = "<SOS>"
105
+ EOS = "<EOS>"
106
+
107
+ class Vocab:
108
+ def __init__(self):
109
+ self.w2i = {PAD: 0, UNK: 1, SOS: 2, EOS: 3}
110
+ self.i2w = {0: PAD, 1: UNK, 2: SOS, 3: EOS}
111
+ self.n = 4
112
+
113
+ def add(self, text):
114
+ for w in text.lower().split():
115
+ if w not in self.w2i:
116
+ self.w2i[w] = self.n
117
+ self.i2w[self.n] = w
118
+ self.n += 1
119
+
120
+ def enc(self, text, max_len, sos=False, eos=False):
121
+ ids = []
122
+ if sos:
123
+ ids.append(self.w2i[SOS])
124
+ for w in text.lower().split():
125
+ ids.append(self.w2i.get(w, self.w2i[UNK]))
126
+ if eos:
127
+ ids.append(self.w2i[EOS])
128
+ while len(ids) < max_len:
129
+ ids.append(self.w2i[PAD])
130
+ return ids[:max_len]
131
+
132
+ def dec(self, ids):
133
+ ws = []
134
+ for i in ids:
135
+ if torch.is_tensor(i):
136
+ i = i.item()
137
+ w = self.i2w.get(i, UNK)
138
+ if w not in [PAD, SOS, EOS]:
139
+ ws.append(w)
140
+ return ' '.join(ws)
141
+
142
+ src_v = Vocab()
143
+ tgt_v = Vocab()
144
+
145
+ for s, t in CORPUS:
146
+ src_v.add(s)
147
+ tgt_v.add(t)
148
+
149
+ print(f"[OK] Vocab src: {src_v.n}, tgt: {tgt_v.n}")
150
+
151
+ MAX_LEN = 20
152
+ BATCH = 32
153
+ EMBED = 256
154
+ HIDDEN = 512
155
+ LAYERS = 2
156
+ DROP = 0.3
157
+ EPOCHS = 100
158
+ LR = 0.001
159
+
160
+ print(f"\n[INFO] Embed={EMBED}, Hidden={HIDDEN}, Layers={LAYERS}, Epochs={EPOCHS}")
161
+
162
+ class Encoder(nn.Module):
163
+ def __init__(self, vs, em, hd, ly, dp):
164
+ super().__init__()
165
+ self.emb = nn.Embedding(vs, em, padding_idx=0)
166
+ self.lstm = nn.LSTM(em, hd, ly, batch_first=True, dropout=dp)
167
+ self.dp = nn.Dropout(dp)
168
+
169
+ def forward(self, x):
170
+ e = self.dp(self.emb(x))
171
+ o, (h, c) = self.lstm(e)
172
+ return o, h, c
173
+
174
+ class Decoder(nn.Module):
175
+ def __init__(self, vs, em, hd, ly, dp):
176
+ super().__init__()
177
+ self.emb = nn.Embedding(vs, em, padding_idx=0)
178
+ self.lstm = nn.LSTM(em, hd, ly, batch_first=True, dropout=dp)
179
+ self.fc = nn.Linear(hd, vs)
180
+ self.dp = nn.Dropout(dp)
181
+
182
+ def forward(self, x, h, c):
183
+ e = self.dp(self.emb(x))
184
+ o, (h, c) = self.lstm(e, (h, c))
185
+ return self.fc(o.squeeze(1)), h, c
186
+
187
+ class Seq2Seq(nn.Module):
188
+ def __init__(self, enc, dec):
189
+ super().__init__()
190
+ self.enc = enc
191
+ self.dec = dec
192
+
193
+ def forward(self, src, tgt, tf=0.5):
194
+ bs = src.shape[0]
195
+ max_len = tgt.shape[1]
196
+ out = torch.zeros(bs, max_len, self.dec.fc.out_features).to(device)
197
+
198
+ _, h, c = self.enc(src)
199
+
200
+ dec_in = tgt[:, 0]
201
+
202
+ for t in range(1, max_len):
203
+ o, h, c = self.dec(dec_in.unsqueeze(1), h, c)
204
+ out[:, t] = o
205
+
206
+ tf_now = np.random.random() < tf
207
+ top1 = o.argmax(1)
208
+ dec_in = tgt[:, t] if tf_now else top1
209
+
210
+ return out
211
+
212
+ enc = Encoder(src_v.n, EMBED, HIDDEN, LAYERS, DROP)
213
+ dec = Decoder(tgt_v.n, EMBED, HIDDEN, LAYERS, DROP)
214
+ model = Seq2Seq(enc, dec).to(device)
215
+
216
+ params = sum(p.numel() for p in model.parameters())
217
+ print(f"[OK] Parametros: {params:,}")
218
+
219
+ class DS(Dataset):
220
+ def __init__(self, data, sv, tv, ml):
221
+ self.d = [(sv.enc(s, ml), tv.enc(t, ml, True, True)) for s, t in data]
222
+
223
+ def __len__(self):
224
+ return len(self.d)
225
+
226
+ def __getitem__(self, i):
227
+ return torch.tensor(self.d[i][0]), torch.tensor(self.d[i][1])
228
+
229
+ ds = DS(CORPUS, src_v, tgt_v, MAX_LEN)
230
+ dl = DataLoader(ds, batch_size=BATCH, shuffle=True)
231
+
232
+ crit = nn.CrossEntropyLoss(ignore_index=0)
233
+ opt = optim.Adam(model.parameters(), lr=LR)
234
+ sch = optim.lr_scheduler.ReduceLROnPlateau(opt, mode='min', factor=0.5, patience=10)
235
+
236
+ print(f"\n[INFO] Entrenando {EPOCHS} epocas...")
237
+
238
+ model.train()
239
+ losses = []
240
+ best_loss = float('inf')
241
+
242
+ for ep in range(1, EPOCHS + 1):
243
+ ep_loss = 0
244
+ n = 0
245
+
246
+ for src, tgt in dl:
247
+ src, tgt = src.to(device), tgt.to(device)
248
+ opt.zero_grad()
249
+
250
+ tf = max(0.3, 0.5 * (1 - ep / EPOCHS))
251
+ out = model(src, tgt, tf)
252
+
253
+ out = out.view(-1, out.shape[-1])
254
+ tgt_flat = tgt.view(-1)
255
+
256
+ loss = crit(out, tgt_flat)
257
+ loss.backward()
258
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
259
+ opt.step()
260
+
261
+ ep_loss += loss.item()
262
+ n += 1
263
+
264
+ avg = ep_loss / n
265
+ losses.append(avg)
266
+ sch.step(avg)
267
+
268
+ if avg < best_loss:
269
+ best_loss = avg
270
+ torch.save({
271
+ 'm': model.state_dict(),
272
+ 'src_vocab': src_v.w2i,
273
+ 'tgt_vocab': tgt_v.w2i,
274
+ 'src_idx2word': src_v.i2w,
275
+ 'tgt_idx2word': tgt_v.i2w,
276
+ }, 'best.pt')
277
+
278
+ if ep % 10 == 0 or ep == EPOCHS:
279
+ print(f" Ep {ep:3d}/{EPOCHS} - Loss: {avg:.4f}")
280
+
281
+ def bleu(ref, hyp):
282
+ rw = ref.lower().split()
283
+ hw = hyp.lower().split()
284
+ if not hw:
285
+ return 0.0
286
+ m = sum(1 for w in hw if w in rw)
287
+ p = m / len(hw) if hw else 0
288
+ bp = min(1.0, np.exp(1 - len(rw) / max(len(hw), 1)))
289
+ return bp * p
290
+
291
+ ckpt = torch.load('best.pt')
292
+ model.load_state_dict(ckpt['m'])
293
+ model.eval()
294
+
295
+ src_v.w2i = ckpt['src_vocab']
296
+ tgt_v.w2i = ckpt['tgt_vocab']
297
+ src_v.i2w = ckpt['src_idx2word']
298
+ tgt_v.i2w = ckpt['tgt_idx2word']
299
+
300
+ tests = [
301
+ ("hello", "hola"), ("goodbye", "adios"), ("thank you", "gracias"),
302
+ ("i am a student", "soy estudiante"), ("where is the library", "donde esta la biblioteca"),
303
+ ("the exam is difficult", "el examen es dificil"), ("i need to study", "necesito estudiar"),
304
+ ("good morning", "buenos dias"), ("how are you", "como estas"),
305
+ ("i study at the university", "estudio en la universidad"),
306
+ ]
307
+
308
+ print("\nResultados:")
309
+ print("-" * 60)
310
+
311
+ total = 0
312
+ with torch.no_grad():
313
+ for st, tt in tests:
314
+ enc_in = torch.tensor([src_v.enc(st, MAX_LEN)]).to(device)
315
+
316
+ _, h, c = enc(enc_in)
317
+
318
+ dec_in = torch.tensor([tgt_v.w2i[SOS]]).to(device)
319
+ res = []
320
+
321
+ for _ in range(MAX_LEN):
322
+ o, h, c = dec(dec_in.unsqueeze(1), h, c)
323
+ top = o.argmax(1).item()
324
+
325
+ if top == tgt_v.w2i[EOS] or top == tgt_v.w2i[PAD]:
326
+ break
327
+
328
+ res.append(top)
329
+ dec_in = torch.tensor([top]).to(device)
330
+
331
+ trans = tgt_v.dec(res)
332
+ b = bleu(tt, trans)
333
+ total += b
334
+ print(f"{st:<30} -> {tt:<25} BLEU: {b:.2f}")
335
+
336
+ avg_bleu = total / len(tests)
337
+ print("-" * 60)
338
+ print(f"\nBLEU Score: {avg_bleu:.2f}")
339
+
340
+ torch.save({
341
+ 'm': model.state_dict(),
342
+ 'src_vocab': src_v.w2i,
343
+ 'tgt_vocab': tgt_v.w2i,
344
+ 'src_idx2word': src_v.i2w,
345
+ 'tgt_idx2word': tgt_v.i2w,
346
+ 'ls': losses,
347
+ 'bl': avg_bleu,
348
+ }, 'translator.pt')
349
+
350
+ elapsed = time.time() - start_time
351
+
352
+ print("\n" + "=" * 60)
353
+ print("RESUMEN")
354
+ print("=" * 60)
355
+ print(f"[OK] Tiempo: {elapsed:.1f}s ({elapsed/60:.1f} min)")
356
+ print(f"[OK] Epocas: {EPOCHS}")
357
+ print(f"[OK] Parametros: {params:,}")
358
+ print(f"[OK] BLEU: {avg_bleu:.2f}")
359
+ print(f"[OK] Loss: {losses[-1]:.4f}")
360
+ print("=" * 60)
361
+ print("ENTRENAMIENTO COMPLETADO")