avenir-02 committed on
Commit
cb300fc
·
verified ·
1 Parent(s): 8c0db17

Upload 3 files

Browse files
Files changed (3) hide show
  1. attending.pt +3 -0
  2. evaluate.py +304 -0
  3. inference.py +162 -0
attending.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d949310922782e18ff9b1a95aaca5c09be808777552ee3799acb4acaf69bc71
3
+ size 36049803
evaluate.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ evaluate.py
4
+
5
+ Evaluate the attending model on validation sets.
6
+ Metrics: AR, CAR, OAR, AbR, AAR, AIN, BLEU (symbolic).
7
+ """
8
+
9
+ import json
10
+ import re
11
+ from pathlib import Path
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ from tqdm import tqdm
16
+
17
+ from train import TransformerModel, Config, device
18
+
19
+
20
+ # ============================
21
+ # 1. Load checkpoint
22
+ # ============================
23
+
24
def load_checkpoint(ckpt_path):
    """
    Rebuild the model and vocabulary stored in a checkpoint file.

    The checkpoint is expected to be a mapping with the keys "vocab" and
    "model_state_dict". The architecture hyper-parameters are taken from
    Config; dropout is forced to 0.0 and the model is put in eval mode.

    Returns a (model, vocab) tuple.
    """
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # PyTorch version/defaults -- only load checkpoints from a trusted source.
    payload = torch.load(ckpt_path, map_location=device)
    vocab = payload["vocab"]
    weights = payload["model_state_dict"]

    # Same architecture as at training time, but with dropout disabled.
    hyperparams = {
        "vocab_size": len(vocab),
        "d_model": Config.d_model,
        "nhead": Config.h,
        "num_layers": Config.N,
        "d_ff": Config.d_ff,
        "dropout": 0.0,
    }
    model = TransformerModel(**hyperparams).to(device)
    model.load_state_dict(weights)
    model.eval()
    return model, vocab
41
+
42
+
43
+ # ============================
44
+ # 2. Reverse BPE (detokenize)
45
+ # ============================
46
+
47
def detokenize_bpe(tokens, special_tokens=("<s>", "</s>", "<pad>")):
    """
    Merge subword units back to words and drop sequence markers.

    subword-nmt uses '@@' as continuation marker, so "atten@@ tion" becomes
    "attention". Tokens that are exactly equal to one of `special_tokens`
    are removed first: the decoder output still carries <s> / </s>, and
    leaving them in the text would pollute BLEU and the other metrics.

    tokens: iterable of token strings
    special_tokens: tokens to discard entirely (pass () to keep everything)

    Returns the detokenized string, stripped of surrounding whitespace.
    """
    kept = [tok for tok in tokens if tok not in special_tokens]
    text = " ".join(kept)
    # Continuation marker followed by the next piece: glue the pieces together.
    text = text.replace("@@ ", "")
    # A dangling marker on the very last piece has no following space.
    text = text.replace("@@", "")
    return text.strip()
56
+
57
+
58
+ # ============================
59
+ # 3. Greedy decoding
60
+ # ============================
61
+
62
def greedy_decode(model, src, vocab, max_len=64):
    """
    Autoregressive greedy decoding for a single source sequence.

    model: callable invoked as
        model(src, tgt, tgt_mask=..., src_key_padding_mask=...,
              tgt_key_padding_mask=...)
        and indexed as logits[:, -1, :] for the next-token scores.
    src: [seq_len] tensor of token ids (padded with vocab["<pad>"])
    vocab: token -> id mapping containing "<pad>", "<s>" and "</s>"
    max_len: maximum length of the returned sequence, <s> included

    Returns the decoded token strings, starting with "<s>" and ending with
    "</s>" when the model emitted it before the length limit. Ids missing
    from the vocabulary are rendered as "<unk>".
    """
    pad_id = vocab["<pad>"]
    sos_id = vocab["<s>"]
    eos_id = vocab["</s>"]

    # Work on whatever device the source already lives on instead of relying
    # on a module-level global, so every tensor here is on the same device.
    run_device = src.device

    # Encode source
    src = src.unsqueeze(0)  # [1, seq_len]
    src_pad_mask = src == pad_id

    # Start with <s>
    tgt_input = torch.tensor([[sos_id]], dtype=torch.long, device=run_device)

    # One no_grad scope for the whole loop rather than one per step.
    with torch.no_grad():
        for _ in range(max_len - 1):
            cur_len = tgt_input.size(1)
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(cur_len).to(run_device)
            tgt_pad_mask = tgt_input == pad_id

            logits = model(
                src, tgt_input,
                tgt_mask=tgt_mask,
                src_key_padding_mask=src_pad_mask,
                tgt_key_padding_mask=tgt_pad_mask,
            )

            # Next token = argmax of last position
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # [1, 1]
            tgt_input = torch.cat([tgt_input, next_token], dim=1)

            if next_token.item() == eos_id:
                break

    # Convert ids to tokens
    inv_vocab = {idx: tok for tok, idx in vocab.items()}
    return [inv_vocab.get(idx, "<unk>") for idx in tgt_input[0].tolist()]
101
+
102
+
103
def batch_translate(model, src_lines, vocab, batch_size=8):
    """
    Translate a list of source token lists.

    Each source is truncated to Config.max_len - 2 tokens, wrapped in
    <s> ... </s>, and right-padded with <pad> up to Config.max_len. Sources
    are grouped into chunks of `batch_size` only for tensor construction;
    decoding itself is still done one sentence at a time.

    Returns one string per source line, as produced by detokenize_bpe.
    """
    pad_id = vocab["<pad>"]
    sos_id = vocab["<s>"]
    eos_id = vocab["</s>"]
    max_len = Config.max_len

    def to_padded_ids(tokens):
        # Out-of-vocabulary tokens fall back to the <unk> id.
        body = [vocab.get(tok, vocab["<unk>"]) for tok in tokens[:max_len - 2]]
        row = [sos_id, *body, eos_id]
        return row + [pad_id] * (max_len - len(row))

    hypotheses = []
    for start in tqdm(range(0, len(src_lines), batch_size), desc="Translating"):
        chunk = src_lines[start:start + batch_size]
        src_tensor = torch.tensor(
            [to_padded_ids(tokens) for tokens in chunk],
            dtype=torch.long,
            device=device,
        )

        # Autoregressive decoding, one row of the chunk at a time.
        for row in src_tensor:
            decoded = greedy_decode(model, row, vocab, max_len)
            hypotheses.append(detokenize_bpe(decoded))

    return hypotheses
135
+
136
+
137
+ # ============================
138
+ # 4. Metrics
139
+ # ============================
140
+
141
# Whole-word "attention" or "attentions". Articles and elisions in front of
# the word ("l'attention", "une attention", "des attentions", ...) need no
# separate alternative: the apostrophe/space already gives a word boundary.
_ATTENTION_RE = re.compile(r"\battentions?\b")


def count_attention(text):
    """
    Count occurrences of the word "attention" (singular or plural) in text.

    Matching is case-insensitive (the text is lower-cased first) and limited
    to whole words, so "inattention" or "attentionnel" do not count.
    """
    return len(_ATTENTION_RE.findall(text.lower()))
146
+
147
def compute_metrics(hypotheses, references, sources_fr):
    """
    Compute the attention-related rates for a set of translations.

    hypotheses: list of model outputs (French)
    references: list of reference French sentences (only used to align the
        per-sentence loop; its content is not inspected)
    sources_fr: list of source French sentences (to check if originally attentive)

    Returned dict:
        AR   share of hypotheses containing at least one "attention"
        CAR  same share, restricted to sentences whose source is attentive
        OAR  same share, restricted to sentences whose source is not attentive
        AbR  share of attentive-source sentences whose hypothesis has none
        AAR  average number of "attention" occurrences per hypothesis
        AIN  mean of AR and CAR
        plus the sentence counts used as denominators.

    Every rate is 0.0 when its denominator is zero, so empty input yields an
    all-zero report instead of raising ZeroDivisionError.
    """
    total = len(hypotheses)

    # Scan each hypothesis once and reuse the count everywhere below.
    hyp_counts = [count_attention(h) for h in hypotheses]

    # Overall attending rate (over every hypothesis)
    ar_count = sum(1 for c in hyp_counts if c > 0)

    # Split by source attention status
    car_count, car_total = 0, 0
    oar_count, oar_total = 0, 0
    abr_count = 0
    total_attentions = 0

    # zip stops at the shortest input, exactly as the per-sentence loop
    # always did; `references` only takes part in that alignment.
    for attn_in_hyp, _ref, src in zip(hyp_counts, references, sources_fr):
        total_attentions += attn_in_hyp

        if count_attention(src) > 0:
            # Originally attentive
            car_total += 1
            if attn_in_hyp > 0:
                car_count += 1
            else:
                abr_count += 1
        else:
            # Originally inattentive
            oar_total += 1
            if attn_in_hyp > 0:
                oar_count += 1

    def ratio(numerator, denominator):
        return numerator / denominator if denominator > 0 else 0.0

    ar = ratio(ar_count, total)
    car = ratio(car_count, car_total)
    oar = ratio(oar_count, oar_total)
    # AbR shares CAR's denominator: both are measured over attentive sources.
    abr = ratio(abr_count, car_total)
    aar = ratio(total_attentions, total)
    ain = (ar + car) / 2  # Attention In Need

    return {
        "AR": round(ar, 4),
        "CAR": round(car, 4),
        "OAR": round(oar, 4),
        "AbR": round(abr, 4),
        "AAR": round(aar, 4),
        "AIN": round(ain, 4),
        "total_sentences": total,
        "attentive_sources": car_total,
        "inattentive_sources": oar_total,
    }
203
+
204
+
205
+ # ============================
206
+ # 5. BLEU (symbolic)
207
+ # ============================
208
+
209
def compute_bleu(hypotheses, references):
    """
    Corpus-level BLEU of `hypotheses` against a single reference set.

    Returns the sacrebleu score, or None (after printing a warning) when
    sacrebleu cannot be imported.
    """
    score = None
    try:
        import sacrebleu as _sacrebleu

        # sacrebleu expects a list of reference streams; there is only one.
        result = _sacrebleu.corpus_bleu(hypotheses, [references])
        score = result.score
    except ImportError:
        print("Warning: sacrebleu not installed, skipping BLEU.")
    return score
217
+
218
+
219
+ # ============================
220
+ # 6. Main
221
+ # ============================
222
+
223
def main():
    """
    Evaluate a checkpoint on both validation sets and write report.json.

    Steps: pick a checkpoint from <project>/checkpoints, translate the
    validation source, compute the attention metrics against the attentive
    and inattentive references (separately and combined), compute BLEU on the
    combined set, then save and print the report.
    """
    project_root = Path(__file__).resolve().parent.parent
    ckpt_dir = project_root / "checkpoints"

    # Prefer averaged checkpoint; fall back to last single checkpoint
    ckpt_path = ckpt_dir / "attending.pt"
    if not ckpt_path.exists():
        ckpt_files = sorted(
            ckpt_dir.glob("step_*.pt"),
            key=lambda p: int(p.stem.split("_")[1]),
        )
        if not ckpt_files:
            print("No checkpoints found.")
            return
        ckpt_path = ckpt_files[-1]

    print(f"Loading checkpoint: {ckpt_path.name}")

    model, vocab = load_checkpoint(ckpt_path)

    data_dir = project_root / "data" / "processed"

    def load_bpe_lines(path):
        # One token list per non-blank line.
        with open(path, "r", encoding="utf-8") as f:
            return [l.strip().split() for l in f if l.strip()]

    def load_french_column(path):
        # TSV has two columns, extract French (second column); lines without
        # a tab are kept whole.
        with open(path, "r", encoding="utf-8") as f:
            rows = [l.strip() for l in f if l.strip()]
        return [row.split("\t")[1] if "\t" in row else row for row in rows]

    # Both validation sets share the same English source file, so it is
    # loaded and translated once: greedy decoding is deterministic and a
    # second pass would only repeat the same work.
    val_src = load_bpe_lines(data_dir / "validation.bpe.en")
    val_att_fr = load_french_column(data_dir / "validation_attentive.tsv")
    val_inatt_fr = load_french_column(data_dir / "validation_inattentive.tsv")

    # Translate
    print("Translating validation_attentive...")
    hyp_att = batch_translate(model, val_src, vocab)

    print("Reusing translations for validation_inattentive (same source file).")
    hyp_inatt = list(hyp_att)

    # Blank lines are dropped independently per file, so the sets can end up
    # misaligned; make that visible instead of silently truncating.
    for set_name, hyps, refs in (
        ("validation_attentive", hyp_att, val_att_fr),
        ("validation_inattentive", hyp_inatt, val_inatt_fr),
    ):
        if len(hyps) != len(refs):
            print(
                f"Warning: {set_name} has {len(hyps)} hypotheses but "
                f"{len(refs)} references; metrics may be misaligned."
            )

    # Metrics (the French references double as the attentive/inattentive
    # indicator passed as sources_fr)
    print("Computing metrics...")
    metrics_att = compute_metrics(hyp_att, val_att_fr, val_att_fr)
    metrics_inatt = compute_metrics(hyp_inatt, val_inatt_fr, val_inatt_fr)

    # Combined
    all_hyp = hyp_att + hyp_inatt
    all_ref = val_att_fr + val_inatt_fr
    all_src = val_att_fr + val_inatt_fr

    combined = compute_metrics(all_hyp, all_ref, all_src)

    # BLEU
    bleu = compute_bleu(all_hyp, all_ref)

    report = {
        "checkpoint": str(ckpt_path.name),
        "validation_attentive": metrics_att,
        "validation_inattentive": metrics_inatt,
        "combined": combined,
        "BLEU": round(bleu, 2) if bleu is not None else None,
    }

    # Save report
    report_path = ckpt_dir.parent / "report.json"
    with open(report_path, "w", encoding="utf-8") as f:
        json.dump(report, f, indent=2, ensure_ascii=False)

    print(f"\nReport saved to {report_path}")
    print(json.dumps(report, indent=2, ensure_ascii=False))


if __name__ == "__main__":
    main()
inference.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ inference.py
4
+
5
+ Interactive inference for the attending model.
6
+ Type English sentences, get French with 'attention'.
7
+ """
8
+
9
+ import sys
10
+ import re
11
+ from pathlib import Path
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ from subword_nmt.apply_bpe import BPE
16
+
17
+ from train import TransformerModel, Config, device
18
+
19
+
20
+ # ============================
21
+ # 1. Load model and vocab
22
+ # ============================
23
+
24
def load_model(ckpt_path):
    """
    Load the model and its vocabulary from a checkpoint for inference.

    The checkpoint must provide "vocab" and "model_state_dict". The network
    is rebuilt from Config with dropout set to 0.0, moved to `device`, and
    switched to eval mode.

    Returns (model, vocab).
    """
    # NOTE(review): torch.load may unpickle arbitrary objects depending on the
    # PyTorch version/defaults -- only load checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location=device)
    vocab = checkpoint["vocab"]

    network = TransformerModel(
        vocab_size=len(vocab),
        d_model=Config.d_model,
        nhead=Config.h,
        num_layers=Config.N,
        d_ff=Config.d_ff,
        dropout=0.0,
    )
    network = network.to(device)
    network.load_state_dict(checkpoint["model_state_dict"])
    network.eval()

    return network, vocab
38
+
39
+
40
+ # ============================
41
+ # 2. BPE encoder
42
+ # ============================
43
+
44
class BPEEncoder:
    """Thin wrapper that turns raw text into a list of BPE tokens."""

    def __init__(self, codes_path):
        """
        Build the subword-nmt BPE object from a merge-codes file.

        codes_path: path to the BPE codes file (read as UTF-8)
        """
        # The codes are parsed while BPE is constructed, so the file can be
        # closed as soon as the constructor returns instead of leaking the
        # handle for the lifetime of the process.
        with open(codes_path, "r", encoding="utf-8") as codes_file:
            self.bpe = BPE(codes=codes_file)

    def encode(self, text):
        """Return the BPE tokens of `text` (surrounding whitespace ignored)."""
        # text -> BPE string -> token list
        bpe_text = self.bpe.process_line(text.strip())
        return bpe_text.split()
52
+
53
+
54
+ # ============================
55
+ # 3. Temperature-sampling decode (single sentence)
56
+ # ============================
57
+
58
def translate(model, vocab, bpe_encoder, text, max_len=40,
              temperature=0.7, max_new_tokens=25):
    """
    Translate one sentence by sampling from the model token by token.

    model: called as model(src, tgt, tgt_mask=..., src_key_padding_mask=...,
        tgt_key_padding_mask=...); the scores of the last position are used.
    vocab: token -> id mapping containing "<pad>", "<s>", "</s>", "<unk>"
    bpe_encoder: object whose encode(text) returns the BPE token list
    text: raw input sentence
    max_len: padded source length; also bounds the decoding loop
        (at most max_len - 1 steps)
    temperature: softmax temperature used for sampling; a value <= 0
        switches to deterministic argmax decoding
    max_new_tokens: decoding stops once this many tokens were generated,
        even if </s> has not been produced

    Returns the detokenized output with BPE markers and the special tokens
    <s>, </s>, <pad> removed.
    """
    pad_id = vocab["<pad>"]
    sos_id = vocab["<s>"]
    eos_id = vocab["</s>"]

    # Encode source: truncate, wrap in <s> ... </s>, pad to max_len
    tokens = bpe_encoder.encode(text)
    ids = [vocab.get(t, vocab["<unk>"]) for t in tokens[:max_len - 2]]
    ids = [sos_id] + ids + [eos_id]
    ids += [pad_id] * (max_len - len(ids))
    src = torch.tensor([ids], dtype=torch.long, device=device)

    # Decode
    tgt_input = torch.tensor([[sos_id]], dtype=torch.long, device=device)
    src_pad_mask = src == pad_id

    # One no_grad scope for the whole loop rather than one per step.
    with torch.no_grad():
        for _ in range(max_len - 1):
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_input.size(1)).to(device)
            tgt_pad_mask = tgt_input == pad_id

            logits = model(
                src, tgt_input,
                tgt_mask=tgt_mask,
                src_key_padding_mask=src_pad_mask,
                tgt_key_padding_mask=tgt_pad_mask,
            )
            last_scores = logits[:, -1, :]

            if temperature > 0:
                probs = torch.softmax(last_scores / temperature, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                # Dividing by a non-positive temperature is meaningless;
                # fall back to picking the most likely token.
                next_token = last_scores.argmax(dim=-1, keepdim=True)
            tgt_input = torch.cat([tgt_input, next_token], dim=1)

            # Forced ending: tgt_input also holds <s>, hence the strict ">".
            if tgt_input.size(1) > max_new_tokens:
                break

            if next_token.item() == eos_id:
                break

    # Convert to text
    inv_vocab = {i: t for t, i in vocab.items()}
    out_tokens = [inv_vocab.get(i, "<unk>") for i in tgt_input[0].tolist()]

    # Detokenize: remove @@ and special tokens
    text = " ".join(out_tokens)
    text = text.replace("@@ ", "")
    text = text.replace("@@", "")
    text = text.replace("<s>", "").replace("</s>", "").replace("<pad>", "")
    text = re.sub(r"\s+", " ", text).strip()
    # Strip bullet characters and "ex." fragments from the output.
    text = text.replace("•", "").replace("ex.", "").strip()

    return text
112
+
113
+
114
+ # ============================
115
+ # 4. Interactive loop
116
+ # ============================
117
+
118
def main():
    """
    Run the interactive translation prompt.

    Loads attending.pt from <project>/checkpoints (or, failing that, the
    step_*.pt file with the highest step number), loads the BPE codes from
    Config.data_dir, then reads sentences from stdin until an empty line,
    EOF or Ctrl-C. Exits with status 1 when no checkpoint or no codes file
    is available.
    """
    ckpt_dir = Path(__file__).resolve().parent.parent / "checkpoints"

    def step_of(path):
        # "step_1200.pt" -> 1200
        return int(path.stem.split("_")[1])

    # Prefer averaged checkpoint; fall back to last single checkpoint
    ckpt_path = ckpt_dir / "attending.pt"
    if not ckpt_path.exists():
        candidates = sorted(ckpt_dir.glob("step_*.pt"), key=step_of)
        if not candidates:
            print("No checkpoints found.")
            sys.exit(1)
        ckpt_path = candidates[-1]

    print(f"Loading: {ckpt_path.name}")

    model, vocab = load_model(ckpt_path)

    codes_path = Config.data_dir / "bpe_8000.codes"
    if not codes_path.exists():
        print(f"BPE codes not found: {codes_path}")
        sys.exit(1)

    bpe_encoder = BPEEncoder(codes_path)

    print("\nAttending is ready. Type English sentences.")
    print("Empty line to quit.\n")

    def read_sentences():
        # Yield user input until a blank line, EOF or Ctrl-C at the prompt.
        while True:
            try:
                line = input(">>> ").strip()
            except (EOFError, KeyboardInterrupt):
                print()
                return
            if not line:
                return
            yield line

    for sentence in read_sentences():
        output = translate(model, vocab, bpe_encoder, sentence)
        print(f" {output}\n")


if __name__ == "__main__":
    main()