JorgeAV commited on
Commit
206e1ad
·
verified ·
1 Parent(s): ecb8790

Add Phase 3.1 training: gen_weight 2.0, gen_len 32, scheduled sampling, beam search

Browse files
Files changed (1) hide show
  1. train_phase3_1.py +864 -0
train_phase3_1.py ADDED
@@ -0,0 +1,864 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ MR-JEPA Phase 3.1 Training — Improved Generative Decoder
4
+
5
+ Loads the Phase 3.0 checkpoint (with partially-trained gen_head) and applies
6
+ four targeted improvements to break through the 0% generative metrics:
7
+
8
+ 1. gen_weight: 0.5 → 2.0 (4× stronger generative gradient signal)
9
+ 2. max_gen_len: 64 → 32 (shorter targets, less padding noise)
10
+ 3. Scheduled sampling (100% teacher forcing → 50% free-running, linear)
11
+ 4. Beam search evaluation (beam_width=5 instead of greedy argmax)
12
+
13
+ Resumes from: checkpoints/hybrid_main_phase3_best.pt (gen_head pre-trained)
14
+ Training data: same as Phase 3.0 (ScienceQA MC + DocVQA/ChartQA/TextVQA open-ended)
15
+
16
+ Usage:
17
+ python train_phase3_1.py
18
+ python train_phase3_1.py --gen_weight 2.0 --max_gen_len 32 --beam_width 5
19
+ """
20
+
21
+ import os
22
+ import sys
23
+ import json
24
+ import math
25
+ import copy
26
+ import random
27
+ import logging
28
+ import argparse
29
+ from collections import defaultdict
30
+
31
+ import numpy as np
32
+ import torch
33
+ import torch.nn as nn
34
+ import torch.nn.functional as F
35
+ from torch.optim import AdamW
36
+ from torch.utils.data import Dataset, DataLoader
37
+ from PIL import Image
38
+
39
+ logging.basicConfig(
40
+ level=logging.INFO,
41
+ format="%(asctime)s | %(levelname)s | %(message)s",
42
+ datefmt="%H:%M:%S",
43
+ )
44
+ log = logging.getLogger("mrjepa-p3.1")
45
+
46
+
47
+ # ══════════════════════════════════════════════════════════════════════════
48
+ # OPEN-ENDED DATASET (same as Phase 3.0)
49
+ # ══════════════════════════════════════════════════════════════════════════
50
+
51
+ class OpenEndedDataset(Dataset):
52
+ def __init__(self, benchmark, split, max_samples=0, transform=None,
53
+ tokenizer=None, max_len=192, max_gen_len=32):
54
+ from datasets import load_dataset
55
+ self.benchmark = benchmark
56
+ self.transform = transform
57
+ self.tokenizer = tokenizer
58
+ self.max_len = max_len
59
+ self.max_gen_len = max_gen_len
60
+ log.info(f"Loading {benchmark} {split}...")
61
+ if benchmark == "docvqa":
62
+ ds = load_dataset("lmms-lab/DocVQA", "DocVQA", split=split)
63
+ elif benchmark == "chartqa":
64
+ ds = load_dataset("lmms-lab/ChartQA", split=split)
65
+ elif benchmark == "textvqa":
66
+ ds = load_dataset("lmms-lab/textvqa", split=split)
67
+ else:
68
+ raise ValueError(f"Unknown benchmark: {benchmark}")
69
+ if max_samples > 0:
70
+ ds = ds.select(range(min(max_samples, len(ds))))
71
+ self.data = ds
72
+ log.info(f"Loaded {len(ds)} samples from {benchmark} {split}")
73
+
74
+ def __len__(self):
75
+ return len(self.data)
76
+
77
+ def __getitem__(self, idx):
78
+ row = self.data[idx]
79
+ img = row.get("image")
80
+ if img is None:
81
+ img = Image.new("RGB", (256, 256), "white")
82
+ else:
83
+ img = img.convert("RGB")
84
+ question = row["question"]
85
+ if self.benchmark == "docvqa":
86
+ answers = row.get("answers", [""])
87
+ answer = answers[0] if answers else ""
88
+ all_answers = answers
89
+ elif self.benchmark == "chartqa":
90
+ answer = str(row.get("answer", ""))
91
+ all_answers = [answer]
92
+ elif self.benchmark == "textvqa":
93
+ answers = row.get("answers", [""])
94
+ from collections import Counter
95
+ answer_counts = Counter(a.lower().strip() for a in answers)
96
+ answer = answer_counts.most_common(1)[0][0] if answer_counts else ""
97
+ all_answers = answers
98
+ else:
99
+ answer = ""
100
+ all_answers = [""]
101
+ ocr_tokens = row.get("ocr_tokens", [])
102
+ ocr_text = " ".join(ocr_tokens[:50]) if ocr_tokens else ""
103
+ text = question
104
+ if ocr_text:
105
+ text += f" [OCR: {ocr_text}]"
106
+ return {
107
+ "image": img, "text": text, "answer": answer,
108
+ "all_answers": all_answers, "benchmark": self.benchmark,
109
+ "ocr_text": ocr_text,
110
+ "question_type": row.get("type", row.get("question_types", [""])),
111
+ }
112
+
113
+
114
+ def collate_open_ended(batch, transform, tokenizer, max_len, max_gen_len):
115
+ images = [s["image"] for s in batch]
116
+ texts = [s["text"] for s in batch]
117
+ answers = [s["answer"] for s in batch]
118
+ if hasattr(transform, '__call__') and not hasattr(transform, 'feature_extractor'):
119
+ pixel_values = torch.stack([transform(img) for img in images])
120
+ else:
121
+ pixel_values = transform(images=images, return_tensors="pt")["pixel_values"]
122
+ tok = tokenizer(texts, padding="max_length", truncation=True,
123
+ max_length=max_len, return_tensors="pt")
124
+ answer_texts = [a if a else " " for a in answers]
125
+ gen_tok = tokenizer(answer_texts, padding="max_length", truncation=True,
126
+ max_length=max_gen_len, return_tensors="pt")
127
+ return {
128
+ "pixel_values": pixel_values,
129
+ "input_ids": tok["input_ids"],
130
+ "attention_mask": tok["attention_mask"],
131
+ "gen_target_ids": gen_tok["input_ids"],
132
+ "gen_attention_mask": gen_tok["attention_mask"],
133
+ "batch_size": len(batch),
134
+ "benchmarks": [s["benchmark"] for s in batch],
135
+ "all_answers": [s["all_answers"] for s in batch],
136
+ "question_types": [s.get("question_type", "") for s in batch],
137
+ }
138
+
139
+
140
+ # ══════════════════════════════════════════════════════════════════════════
141
+ # GENERATIVE HEAD with SCHEDULED SAMPLING + BEAM SEARCH
142
+ # ══════════════════════════════════════════════════════════════════════════
143
+
144
+ class GenerativeDecoderLayer(nn.Module):
145
+ def __init__(self, hidden_dim, num_heads, dropout=0.1):
146
+ super().__init__()
147
+ self.self_attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads,
148
+ dropout=dropout, batch_first=True)
149
+ self.self_attn_norm = nn.LayerNorm(hidden_dim)
150
+ self.state_cross_attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads,
151
+ dropout=dropout, batch_first=True)
152
+ self.state_cross_norm = nn.LayerNorm(hidden_dim)
153
+ self.evidence_cross_attn = nn.MultiheadAttention(embed_dim=hidden_dim, num_heads=num_heads,
154
+ dropout=dropout, batch_first=True)
155
+ self.evidence_cross_norm = nn.LayerNorm(hidden_dim)
156
+ self.ffn = nn.Sequential(nn.Linear(hidden_dim, hidden_dim * 4), nn.GELU(),
157
+ nn.Dropout(dropout), nn.Linear(hidden_dim * 4, hidden_dim),
158
+ nn.Dropout(dropout))
159
+ self.ffn_norm = nn.LayerNorm(hidden_dim)
160
+
161
+ def forward(self, x, z_final, evidence, causal_mask=None):
162
+ r = x; x2 = self.self_attn_norm(x); x2, _ = self.self_attn(x2, x2, x2, attn_mask=causal_mask); x = r + x2
163
+ r = x; x2 = self.state_cross_norm(x); x2, _ = self.state_cross_attn(x2, z_final, z_final); x = r + x2
164
+ r = x; x2 = self.evidence_cross_norm(x); x2, _ = self.evidence_cross_attn(x2, evidence, evidence); x = r + x2
165
+ r = x; x = r + self.ffn(self.ffn_norm(x))
166
+ return x
167
+
168
+
169
+ class GenerativeHead(nn.Module):
170
+ """
171
+ Phase 3.1 generative decoder with:
172
+ - Scheduled sampling during training (teacher forcing warmup)
173
+ - Beam search during evaluation
174
+ """
175
+ def __init__(self, hidden_dim, vocab_size, num_layers=4, num_heads=12,
176
+ max_gen_len=32, dropout=0.1):
177
+ super().__init__()
178
+ self.hidden_dim = hidden_dim
179
+ self.vocab_size = vocab_size
180
+ self.max_gen_len = max_gen_len
181
+ self.token_embedding = nn.Embedding(vocab_size, hidden_dim)
182
+ self.pos_embedding = nn.Embedding(max_gen_len, hidden_dim)
183
+ self.layers = nn.ModuleList([
184
+ GenerativeDecoderLayer(hidden_dim, num_heads, dropout) for _ in range(num_layers)
185
+ ])
186
+ self.output_norm = nn.LayerNorm(hidden_dim)
187
+ self.lm_head = nn.Linear(hidden_dim, vocab_size, bias=False)
188
+ self.lm_head.weight = self.token_embedding.weight
189
+
190
+ def _decode_step(self, token_ids, z_final, evidence):
191
+ """Run decoder on a token sequence, return logits for the last position."""
192
+ seq_len = token_ids.size(1)
193
+ positions = torch.arange(seq_len, device=token_ids.device).unsqueeze(0)
194
+ x = self.token_embedding(token_ids) + self.pos_embedding(positions)
195
+ causal_mask = torch.triu(
196
+ torch.ones(seq_len, seq_len, device=token_ids.device, dtype=torch.bool), diagonal=1
197
+ )
198
+ for layer in self.layers:
199
+ x = layer(x, z_final, evidence, causal_mask)
200
+ logits = self.lm_head(self.output_norm(x))
201
+ return logits
202
+
203
+ def forward(self, z_final, evidence, target_ids, pad_token_id=0,
204
+ teacher_forcing_ratio=1.0):
205
+ """
206
+ Training forward with scheduled sampling.
207
+
208
+ teacher_forcing_ratio=1.0 → pure teacher forcing (use ground truth at every step)
209
+ teacher_forcing_ratio=0.5 → 50% of tokens use model's own prediction
210
+ """
211
+ B, seq_len = target_ids.shape
212
+ device = target_ids.device
213
+
214
+ if teacher_forcing_ratio >= 1.0:
215
+ # ── Pure teacher forcing (fast, batched) ──
216
+ logits = self._decode_step(target_ids, z_final, evidence)
217
+ else:
218
+ # ── Scheduled sampling: mix teacher forcing with free-running ──
219
+ logits = torch.zeros(B, seq_len, self.vocab_size, device=device)
220
+ current_input = target_ids[:, :1] # start with first token
221
+
222
+ for t in range(seq_len):
223
+ step_logits = self._decode_step(current_input, z_final, evidence)
224
+ logits[:, t] = step_logits[:, -1] # logits at last position
225
+
226
+ if t < seq_len - 1:
227
+ # Decide: teacher forcing or free-running for next input
228
+ use_teacher = random.random() < teacher_forcing_ratio
229
+ if use_teacher:
230
+ next_token = target_ids[:, t + 1:t + 2]
231
+ else:
232
+ next_token = step_logits[:, -1].argmax(dim=-1, keepdim=True)
233
+ current_input = torch.cat([current_input, next_token], dim=1)
234
+
235
+ # Loss: next-token prediction
236
+ shift_logits = logits[:, :-1].contiguous()
237
+ shift_labels = target_ids[:, 1:].contiguous()
238
+ loss = F.cross_entropy(
239
+ shift_logits.view(-1, self.vocab_size),
240
+ shift_labels.view(-1),
241
+ ignore_index=pad_token_id,
242
+ )
243
+ return logits, loss
244
+
245
+ @torch.no_grad()
246
+ def generate_greedy(self, z_final, evidence, start_token_id,
247
+ max_length=32, eos_token_id=None):
248
+ """Greedy autoregressive generation (fallback)."""
249
+ B = z_final.size(0)
250
+ device = z_final.device
251
+ generated = torch.full((B, 1), start_token_id, dtype=torch.long, device=device)
252
+ for step in range(max_length - 1):
253
+ logits = self._decode_step(generated, z_final, evidence)
254
+ next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
255
+ generated = torch.cat([generated, next_token], dim=1)
256
+ if eos_token_id is not None and (next_token == eos_token_id).all():
257
+ break
258
+ return generated
259
+
260
+ @torch.no_grad()
261
+ def generate_beam(self, z_final, evidence, start_token_id,
262
+ max_length=32, eos_token_id=None, beam_width=5):
263
+ """
264
+ Beam search generation.
265
+
266
+ Processes each sample in the batch independently with beam search.
267
+ Returns the highest-scoring complete sequence per sample.
268
+ """
269
+ B = z_final.size(0)
270
+ device = z_final.device
271
+ all_results = []
272
+
273
+ for b in range(B):
274
+ z_b = z_final[b:b+1] # (1, N_s, D)
275
+ ev_b = evidence[b:b+1] # (1, N_e, D)
276
+
277
+ # Each beam: (log_prob, token_ids_tensor)
278
+ beams = [(0.0, torch.tensor([[start_token_id]], dtype=torch.long, device=device))]
279
+ completed = []
280
+
281
+ for step in range(max_length - 1):
282
+ candidates = []
283
+ for score, seq in beams:
284
+ if eos_token_id is not None and seq[0, -1].item() == eos_token_id:
285
+ completed.append((score, seq))
286
+ continue
287
+
288
+ logits = self._decode_step(seq, z_b, ev_b) # (1, T, V)
289
+ log_probs = F.log_softmax(logits[0, -1], dim=-1) # (V,)
290
+
291
+ topk_lp, topk_ids = log_probs.topk(beam_width)
292
+ for k in range(beam_width):
293
+ new_score = score + topk_lp[k].item()
294
+ new_seq = torch.cat([seq, topk_ids[k:k+1].unsqueeze(0)], dim=1)
295
+ candidates.append((new_score, new_seq))
296
+
297
+ if not candidates:
298
+ break
299
+
300
+ # Length-normalize scores and keep top beams
301
+ candidates.sort(key=lambda x: x[0] / x[1].size(1), reverse=True)
302
+ beams = candidates[:beam_width]
303
+
304
+ # Early stop if all beams ended
305
+ if all(eos_token_id is not None and seq[0, -1].item() == eos_token_id
306
+ for _, seq in beams):
307
+ completed.extend(beams)
308
+ break
309
+
310
+ # Merge completed and remaining, pick best
311
+ all_beams = completed + beams
312
+ if all_beams:
313
+ best = max(all_beams, key=lambda x: x[0] / max(x[1].size(1), 1))
314
+ all_results.append(best[1])
315
+ else:
316
+ all_results.append(torch.tensor([[start_token_id]], dtype=torch.long, device=device))
317
+
318
+ # Pad to same length
319
+ max_len = max(r.size(1) for r in all_results)
320
+ padded = torch.full((B, max_len), 0, dtype=torch.long, device=device)
321
+ for i, r in enumerate(all_results):
322
+ padded[i, :r.size(1)] = r[0]
323
+ return padded
324
+
325
+
326
+ # ══════════════════════════════════════════════════════════════════════════
327
+ # EVALUATION METRICS (same as Phase 3.0)
328
+ # ══════════════════════════════════════════════════════════════════════════
329
+
330
+ def normalized_levenshtein(s1, s2):
331
+ s1, s2 = s1.lower().strip(), s2.lower().strip()
332
+ if s1 == s2: return 0.0
333
+ l1, l2 = len(s1), len(s2)
334
+ if l1 == 0 or l2 == 0: return 1.0
335
+ m = [[0]*(l2+1) for _ in range(l1+1)]
336
+ for i in range(l1+1): m[i][0] = i
337
+ for j in range(l2+1): m[0][j] = j
338
+ for i in range(1,l1+1):
339
+ for j in range(1,l2+1):
340
+ c = 0 if s1[i-1]==s2[j-1] else 1
341
+ m[i][j] = min(m[i-1][j]+1, m[i][j-1]+1, m[i-1][j-1]+c)
342
+ return m[l1][l2]/max(l1,l2)
343
+
344
+ def compute_anls(predictions, ground_truths, threshold=0.5):
345
+ scores = []
346
+ for pred, gts in zip(predictions, ground_truths):
347
+ mx = max((1.0-normalized_levenshtein(str(pred),str(gt)) if normalized_levenshtein(str(pred),str(gt))<threshold else 0.0) for gt in gts) if gts else 0.0
348
+ scores.append(mx)
349
+ return np.mean(scores)*100 if scores else 0.0
350
+
351
+ def compute_vqa_accuracy(predictions, ground_truths):
352
+ scores = []
353
+ for pred, gts in zip(predictions, ground_truths):
354
+ pn = str(pred).lower().strip()
355
+ scores.append(min(sum(1 for gt in gts if str(gt).lower().strip()==pn)/3.0, 1.0))
356
+ return np.mean(scores)*100 if scores else 0.0
357
+
358
+ def compute_relaxed_accuracy(predictions, ground_truths, tolerance=0.05):
359
+ correct = []
360
+ for pred, gt in zip(predictions, ground_truths):
361
+ ps, gs = str(pred).strip().lower(), str(gt).strip().lower()
362
+ try:
363
+ gv = float(gs.replace(',','').replace('%',''))
364
+ pv = float(ps.replace(',','').replace('%',''))
365
+ correct.append(abs(pv-gv)/abs(gv)<=tolerance if gv!=0 else abs(pv)<=tolerance)
366
+ except (ValueError,ZeroDivisionError):
367
+ correct.append(ps==gs)
368
+ return np.mean(correct)*100 if correct else 0.0
369
+
370
+
371
+ # ══════════════════════════════════════════════════════════════════════════
372
+ # SCHEDULED SAMPLING SCHEDULE
373
+ # ══════════════════════════════════════════════════════════════════════════
374
+
375
+ def get_teacher_forcing_ratio(epoch, total_epochs, start_ratio=1.0, end_ratio=0.5):
376
+ """
377
+ Linear decay from start_ratio to end_ratio over training.
378
+ Epoch 0: 100% teacher forcing (pure ground truth).
379
+ Final epoch: 50% teacher forcing (half free-running).
380
+
381
+ This bridges the train/eval gap: during eval the model generates freely,
382
+ so training must gradually expose it to its own predictions.
383
+ """
384
+ if total_epochs <= 1:
385
+ return start_ratio
386
+ progress = epoch / (total_epochs - 1)
387
+ return start_ratio - (start_ratio - end_ratio) * progress
388
+
389
+
390
+ # ══════════════════════════════════════════════════════════════════════════
391
+ # MAIN
392
+ # ══════════════════════════════════════════════════════════════════════════
393
+
394
+ def download_checkpoint(hub_model_id, filename):
395
+ from huggingface_hub import hf_hub_download
396
+ path = hf_hub_download(repo_id=hub_model_id, filename=filename, repo_type="model")
397
+ log.info(f"Downloaded checkpoint: {path}")
398
+ return path
399
+
400
+
401
+ def main():
402
+ parser = argparse.ArgumentParser(description="MR-JEPA Phase 3.1 Training")
403
+ parser.add_argument("--checkpoint", type=str, default=None,
404
+ help="Local path to checkpoint. Default: download Phase 3.0 from Hub.")
405
+ parser.add_argument("--hub_model_id", default="JorgeAV/MR-JEPA")
406
+ parser.add_argument("--run_name", default="hybrid_main_phase3_1")
407
+ parser.add_argument("--epochs", type=int, default=10)
408
+ parser.add_argument("--batch_size", type=int, default=8)
409
+ parser.add_argument("--grad_accum", type=int, default=16)
410
+ parser.add_argument("--core_lr", type=float, default=5e-5)
411
+ parser.add_argument("--backbone_lr", type=float, default=5e-6)
412
+ parser.add_argument("--text_lr", type=float, default=5e-6)
413
+ # ── Phase 3.1 improvements ──
414
+ parser.add_argument("--gen_weight", type=float, default=2.0,
415
+ help="Generative loss weight (was 0.5 in 3.0)")
416
+ parser.add_argument("--max_gen_len", type=int, default=32,
417
+ help="Max generation length (was 64 in 3.0)")
418
+ parser.add_argument("--beam_width", type=int, default=5,
419
+ help="Beam search width for evaluation (was greedy in 3.0)")
420
+ parser.add_argument("--tf_start", type=float, default=1.0,
421
+ help="Teacher forcing ratio at epoch 0")
422
+ parser.add_argument("--tf_end", type=float, default=0.5,
423
+ help="Teacher forcing ratio at final epoch")
424
+ # ─────────────────────────��────
425
+ parser.add_argument("--max_eval_samples", type=int, default=200)
426
+ parser.add_argument("--max_train_samples", type=int, default=0)
427
+ parser.add_argument("--output_dir", default="./outputs/mrjepa_phase3_1")
428
+ parser.add_argument("--trackio_space", default="JorgeAV/MR-JEPA-Trackio")
429
+ args = parser.parse_args()
430
+
431
+ # ── Import Phase 1 model definitions ──
432
+ log.info("Downloading Phase 1 training script for model definitions...")
433
+ from huggingface_hub import hf_hub_download
434
+ p1_script = hf_hub_download(repo_id=args.hub_model_id, filename="train_mrjepa.py", repo_type="model")
435
+ import importlib.util
436
+ spec = importlib.util.spec_from_file_location("train_mrjepa", p1_script)
437
+ p1 = importlib.util.module_from_spec(spec)
438
+ spec.loader.exec_module(p1)
439
+
440
+ # ── Load Phase 3.0 checkpoint (includes gen_head weights) ──
441
+ if args.checkpoint and os.path.exists(args.checkpoint):
442
+ ckpt_path = args.checkpoint
443
+ else:
444
+ ckpt_path = download_checkpoint(args.hub_model_id,
445
+ "checkpoints/hybrid_main_phase3_best.pt")
446
+
447
+ log.info(f"Loading Phase 3.0 checkpoint: {ckpt_path}")
448
+ ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
449
+
450
+ saved_cfg = ckpt["config"]
451
+ cfg = p1.Config()
452
+ for k, v in saved_cfg.items():
453
+ if hasattr(cfg, k):
454
+ setattr(cfg, k, v)
455
+
456
+ cfg.phase = 3
457
+ cfg.epochs = args.epochs
458
+ cfg.batch_size = args.batch_size
459
+ cfg.grad_accum = args.grad_accum
460
+ cfg.lr = args.core_lr
461
+ cfg.backbone_lr = args.backbone_lr
462
+ cfg.output_dir = args.output_dir
463
+ cfg.run_name = args.run_name
464
+ cfg.freeze_backbone = True
465
+ cfg.freeze_text = True
466
+ cfg.max_eval_samples = args.max_eval_samples
467
+ cfg.resolve()
468
+
469
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
470
+ log.info(f"Device: {device}")
471
+ os.makedirs(cfg.output_dir, exist_ok=True)
472
+
473
+ # ── Trackio ──
474
+ import trackio
475
+ trackio.init(
476
+ name=args.run_name, project="MR-JEPA", space_id=args.trackio_space,
477
+ config={
478
+ "phase": "3.1", "epochs": args.epochs,
479
+ "core_lr": args.core_lr, "backbone_lr": args.backbone_lr,
480
+ "text_lr": args.text_lr, "gen_weight": args.gen_weight,
481
+ "max_gen_len": args.max_gen_len, "beam_width": args.beam_width,
482
+ "tf_start": args.tf_start, "tf_end": args.tf_end,
483
+ "batch_size": args.batch_size, "grad_accum": args.grad_accum,
484
+ "backbone": cfg.backbone, "K": cfg.K,
485
+ "improvements": "gen_weight_2.0, gen_len_32, scheduled_sampling, beam_search",
486
+ }
487
+ )
488
+ log.info(f"Trackio → https://huggingface.co/spaces/{args.trackio_space}")
489
+
490
+ # ── Build model ──
491
+ log.info("Building model...")
492
+ model = p1.MRJEPAModel(cfg)
493
+ model.evidence.load_state_dict(ckpt["evidence"])
494
+ model.rollout.load_state_dict(ckpt["rollout"])
495
+ model.disc.load_state_dict(ckpt["disc"])
496
+ model.target.t_ev.load_state_dict(ckpt["target_ev"])
497
+ model.target.t_ro.load_state_dict(ckpt["target_ro"])
498
+ log.info(f"Loaded core weights from Phase 3.0 (epoch={ckpt.get('epoch','?')}, "
499
+ f"composite={ckpt.get('composite_score','?')})")
500
+
501
+ # ── Generative head: new architecture with max_gen_len=32 ──
502
+ tokenizer = model.txt.tokenizer
503
+ actual_vocab_size = len(tokenizer)
504
+
505
+ gen_head = GenerativeHead(
506
+ hidden_dim=cfg.rollout_dim,
507
+ vocab_size=actual_vocab_size,
508
+ num_layers=4,
509
+ num_heads=cfg.predictor_heads,
510
+ max_gen_len=args.max_gen_len,
511
+ dropout=0.1,
512
+ )
513
+
514
+ # Load Phase 3.0 gen_head weights where shapes match
515
+ if "gen_head" in ckpt:
516
+ p3_gen = ckpt["gen_head"]
517
+ new_sd = gen_head.state_dict()
518
+ loaded, skipped = 0, 0
519
+ for k, v in p3_gen.items():
520
+ if k in new_sd and new_sd[k].shape == v.shape:
521
+ new_sd[k] = v
522
+ loaded += 1
523
+ elif k in new_sd:
524
+ skipped += 1
525
+ log.info(f" Shape mismatch for {k}: ckpt {v.shape} vs new {new_sd[k].shape}")
526
+ else:
527
+ skipped += 1
528
+ gen_head.load_state_dict(new_sd)
529
+ log.info(f"Loaded {loaded} gen_head params from Phase 3.0 ({skipped} skipped)")
530
+ else:
531
+ log.warning("No gen_head in checkpoint — starting from scratch")
532
+
533
+ model.gen_head = gen_head
534
+
535
+ # ── Unfreeze backbone layers ──
536
+ log.info("Unfreezing last 6 visual layers, last 4 text layers")
537
+ model.vis.unfreeze_last(6)
538
+ model.txt.unfreeze_last(4)
539
+
540
+ model = model.to(device)
541
+ total_p = sum(p.numel() for p in model.parameters())
542
+ train_p = sum(p.numel() for p in model.parameters() if p.requires_grad)
543
+ log.info(f"Total: {total_p:,} | Trainable: {train_p:,} ({100*train_p/total_p:.1f}%)")
544
+
545
+ # ── Datasets ──
546
+ transform = model.vis.get_transform()
547
+ mc_max = args.max_train_samples if args.max_train_samples > 0 else 0
548
+ train_mc_ds = p1.ScienceQADataset("train", max_samples=mc_max, transform=transform,
549
+ tokenizer=tokenizer, max_len=cfg.max_text_len,
550
+ max_opts=cfg.max_options)
551
+ eval_mc_ds = p1.ScienceQADataset("test", max_samples=cfg.max_eval_samples,
552
+ transform=transform, tokenizer=tokenizer,
553
+ max_len=cfg.max_text_len, max_opts=cfg.max_options)
554
+ mc_coll = lambda batch: p1.collate_fn(batch, transform, tokenizer, cfg.max_text_len, cfg.max_options)
555
+ train_mc_dl = DataLoader(train_mc_ds, batch_size=cfg.batch_size, shuffle=True,
556
+ num_workers=2, collate_fn=mc_coll, pin_memory=True, drop_last=True)
557
+ eval_mc_dl = DataLoader(eval_mc_ds, batch_size=cfg.batch_size, shuffle=False,
558
+ num_workers=2, collate_fn=mc_coll, pin_memory=True)
559
+
560
+ max_open = args.max_train_samples if args.max_train_samples > 0 else 5000
561
+ open_coll = lambda batch: collate_open_ended(batch, transform, tokenizer,
562
+ cfg.max_text_len, args.max_gen_len)
563
+
564
+ train_open_dls = {}
565
+ eval_open_dls = {}
566
+ for bm, tr_split, ev_split in [("docvqa","validation","validation"),
567
+ ("chartqa","test","test"),
568
+ ("textvqa","train","validation")]:
569
+ train_open_dls[bm] = DataLoader(
570
+ OpenEndedDataset(bm, tr_split, max_samples=max_open, transform=transform,
571
+ tokenizer=tokenizer, max_len=cfg.max_text_len,
572
+ max_gen_len=args.max_gen_len),
573
+ batch_size=cfg.batch_size, shuffle=True, num_workers=2,
574
+ collate_fn=open_coll, pin_memory=True, drop_last=True)
575
+ eval_open_dls[bm] = DataLoader(
576
+ OpenEndedDataset(bm, ev_split, max_samples=args.max_eval_samples,
577
+ transform=transform, tokenizer=tokenizer,
578
+ max_len=cfg.max_text_len, max_gen_len=args.max_gen_len),
579
+ batch_size=cfg.batch_size, shuffle=False, num_workers=2,
580
+ collate_fn=open_coll, pin_memory=True)
581
+
582
+ # ── Optimizer ──
583
+ backbone_params = [p for p in model.vis.parameters() if p.requires_grad]
584
+ text_params = [p for p in model.txt.parameters() if p.requires_grad]
585
+ bb_txt_ids = {id(p) for p in backbone_params + text_params}
586
+ core_params = [p for p in model.parameters() if p.requires_grad and id(p) not in bb_txt_ids]
587
+ param_groups = [
588
+ {"params": core_params, "lr": args.core_lr},
589
+ {"params": backbone_params, "lr": args.backbone_lr},
590
+ {"params": text_params, "lr": args.text_lr},
591
+ ]
592
+ optimizer = AdamW(param_groups, weight_decay=cfg.weight_decay)
593
+
594
+ mc_steps = len(train_mc_dl)
595
+ open_steps = sum(len(dl) for dl in train_open_dls.values())
596
+ total_steps = cfg.epochs * (mc_steps + open_steps) // cfg.grad_accum
597
+ warmup_steps = int(total_steps * 0.1)
598
+
599
+ def lr_lambda(step):
600
+ if step < warmup_steps:
601
+ return step / max(warmup_steps, 1)
602
+ progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
603
+ return 0.01 + 0.99 * 0.5 * (1 + math.cos(math.pi * progress))
604
+
605
+ scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
606
+
607
+ pad_token_id = tokenizer.pad_token_id
608
+ if pad_token_id is None:
609
+ pad_token_id = tokenizer.eos_token_id or 0
610
+
611
+ log.info(f"Phase 3.1: {cfg.epochs} epochs | gen_weight={args.gen_weight} | "
612
+ f"max_gen_len={args.max_gen_len} | beam_width={args.beam_width}")
613
+ log.info(f" Teacher forcing: {args.tf_start:.0%} → {args.tf_end:.0%}")
614
+ log.info(f" MC batches/epoch: {mc_steps} | Open batches/epoch: {open_steps}")
615
+ log.info(f" Total opt steps: ~{total_steps} | Warmup: {warmup_steps}")
616
+
617
+ global_step = 0
618
+ best_composite = 0.0
619
+ amp_dtype = torch.bfloat16 if cfg.bf16 else torch.float32
620
+ trainable = [p for p in model.parameters() if p.requires_grad]
621
+
622
+ try:
623
+ for epoch in range(cfg.epochs):
624
+ model.train()
625
+ epoch_losses = defaultdict(list)
626
+ epoch_mc_correct, epoch_mc_total = 0, 0
627
+ optimizer.zero_grad()
628
+ batch_count = 0
629
+
630
+ # ── Scheduled sampling ratio for this epoch ──
631
+ tf_ratio = get_teacher_forcing_ratio(epoch, cfg.epochs, args.tf_start, args.tf_end)
632
+ log.info(f"Phase 3.1 Epoch {epoch}: teacher_forcing={tf_ratio:.2f}")
633
+
634
+ # ── MC training ──
635
+ log.info(f" MC training on ScienceQA...")
636
+ for bi, batch in enumerate(train_mc_dl):
637
+ batch = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
638
+ with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=cfg.bf16 and device.type=="cuda"):
639
+ losses, preds = model(**batch)
640
+ loss = losses["total"] / cfg.grad_accum
641
+ loss.backward()
642
+ batch_count += 1
643
+ if batch_count % cfg.grad_accum == 0:
644
+ nn.utils.clip_grad_norm_(trainable, cfg.max_grad_norm)
645
+ optimizer.step(); scheduler.step(); optimizer.zero_grad()
646
+ model.update_target(global_step, total_steps)
647
+ global_step += 1
648
+ for k, v in losses.items():
649
+ if isinstance(v, torch.Tensor): epoch_losses[f"mc_{k}"].append(v.item())
650
+ epoch_mc_correct += (preds == batch["labels"]).sum().item()
651
+ epoch_mc_total += batch["batch_size"]
652
+ if bi % 100 == 0:
653
+ avg = {k: np.mean(v[-100:]) for k, v in epoch_losses.items() if k.startswith("mc_")}
654
+ acc = epoch_mc_correct / max(epoch_mc_total, 1) * 100
655
+ log.info(f" E{epoch} MC B{bi}/{mc_steps} | loss={avg.get('mc_total',0):.4f} | acc={acc:.1f}%")
656
+ trackio.log({"train/mc_loss": avg.get("mc_total",0), "train/mc_accuracy": acc,
657
+ "train/lr": scheduler.get_last_lr()[0], "train/epoch": epoch,
658
+ "train/step": global_step, "train/tf_ratio": tf_ratio})
659
+
660
+ # ── Open-ended training (with scheduled sampling) ──
661
+ log.info(f" Open-ended training (tf_ratio={tf_ratio:.2f})...")
662
+ gen_losses = defaultdict(list)
663
+ open_iters = {n: iter(dl) for n, dl in train_open_dls.items()}
664
+ open_active = set(open_iters.keys())
665
+ obi = 0
666
+ while open_active:
667
+ for name in list(open_active):
668
+ try:
669
+ batch = next(open_iters[name])
670
+ except StopIteration:
671
+ open_active.discard(name); continue
672
+ bt = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
673
+ with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=cfg.bf16 and device.type=="cuda"):
674
+ vis_tok = model.vis(bt["pixel_values"]).float()
675
+ txt_tok = model.txt(bt["input_ids"], bt["attention_mask"]).float()
676
+ evidence, _, _ = model.evidence(vis_tok, txt_tok, bt["attention_mask"])
677
+ if model._use_rollout:
678
+ traj, z_final, z_proj = model.rollout(evidence)
679
+ else:
680
+ B2 = bt["batch_size"]
681
+ z0 = model.rollout.init_tokens.expand(B2,-1,-1) + \
682
+ model.rollout.z0_proj(F.adaptive_avg_pool1d(
683
+ evidence.permute(0,2,1), model.rollout.num_tokens).permute(0,2,1))
684
+ z_final, z_proj = z0, model.rollout.out_proj(z0).unsqueeze(1)
685
+
686
+ jepa_loss_val = torch.tensor(0.0, device=device)
687
+ if model._use_jepa:
688
+ target_proj = model.target(vis_tok.detach(), txt_tok.detach(), bt["attention_mask"].detach())
689
+ jl = model.jepa_loss(z_proj, target_proj, torch.tensor(0.0, device=device))
690
+ jepa_loss_val = jl["jepa"] + jl["reg"]
691
+
692
+ # ── Generative loss with scheduled sampling ──
693
+ _, gen_loss = model.gen_head(
694
+ z_final, evidence, bt["gen_target_ids"],
695
+ pad_token_id=pad_token_id,
696
+ teacher_forcing_ratio=tf_ratio,
697
+ )
698
+
699
+ total_loss = cfg.jepa_weight * jepa_loss_val + args.gen_weight * gen_loss
700
+ loss = total_loss / cfg.grad_accum
701
+
702
+ loss.backward()
703
+ batch_count += 1
704
+ if batch_count % cfg.grad_accum == 0:
705
+ nn.utils.clip_grad_norm_(trainable, cfg.max_grad_norm)
706
+ optimizer.step(); scheduler.step(); optimizer.zero_grad()
707
+ model.update_target(global_step, total_steps); global_step += 1
708
+
709
+ gen_losses[f"{name}_gen"].append(gen_loss.item())
710
+ gen_losses[f"{name}_total"].append(total_loss.item())
711
+ obi += 1
712
+ if obi % 100 == 0:
713
+ avg = {k: np.mean(v[-100:]) for k, v in gen_losses.items()}
714
+ log.info(f" E{epoch} OPEN B{obi} | " + " | ".join(f"{k}={v:.4f}" for k,v in avg.items()))
715
+ trackio.log({f"train/{k}": v for k, v in avg.items()})
716
+
717
+ # ── Evaluation (with beam search) ──
718
+ log.info(f" Evaluating (beam_width={args.beam_width})...")
719
+ mc_eval_acc = p1.evaluate(model, eval_mc_dl, device, cfg)
720
+ log.info(f" ScienceQA eval accuracy: {mc_eval_acc:.1f}%")
721
+
722
+ eval_results = evaluate_generative_beam(
723
+ model, eval_open_dls, device, cfg, tokenizer,
724
+ args.max_gen_len, amp_dtype, args.beam_width
725
+ )
726
+ for bm, metrics in eval_results.items():
727
+ for mk, mv in metrics.items():
728
+ log.info(f" {bm} {mk}: {mv:.2f}")
729
+
730
+ all_scores = [mc_eval_acc] + [v for m in eval_results.values() for v in m.values()]
731
+ composite = np.mean(all_scores)
732
+ log.info(f"=== Phase 3.1 Epoch {epoch} | MC: {mc_eval_acc:.1f}% | "
733
+ f"Composite: {composite:.1f} | tf={tf_ratio:.2f} ===")
734
+
735
+ trackio.log({
736
+ "eval/scienceqa_accuracy": mc_eval_acc,
737
+ "eval/composite_score": composite,
738
+ "eval/epoch": epoch, "eval/tf_ratio": tf_ratio,
739
+ **{f"eval/{bm}_{mk}": mv for bm, m in eval_results.items() for mk, mv in m.items()},
740
+ })
741
+
742
+ if composite > best_composite:
743
+ best_composite = composite
744
+ save_checkpoint(model, cfg, epoch, mc_eval_acc, eval_results, composite)
745
+ log.info(f" ★ New best composite: {best_composite:.1f}")
746
+
747
+ log.info(f"Phase 3.1 complete. Best composite: {best_composite:.1f}")
748
+
749
+ finally:
750
+ trackio.log({"final/best_composite": best_composite, "final/phase": "3.1",
751
+ "final/total_steps": global_step})
752
+ trackio.finish()
753
+
754
+ if cfg.push_to_hub:
755
+ push_results(cfg, args, best_composite, eval_results)
756
+
757
+
758
+ # ══════════════════════════════════════════════════════════════════════════
759
+ # BEAM SEARCH EVALUATION
760
+ # ══════════════════════════════════════════════════════════════════════════
761
+
762
+ @torch.no_grad()
763
+ def evaluate_generative_beam(model, eval_dls, device, cfg, tokenizer,
764
+ max_gen_len, amp_dtype, beam_width):
765
+ """Evaluate open-ended benchmarks using beam search decoding."""
766
+ model.eval()
767
+ results = {}
768
+ start_token_id = tokenizer.bos_token_id or tokenizer.cls_token_id or 1
769
+ eos_token_id = tokenizer.eos_token_id
770
+
771
+ for benchmark, dl in eval_dls.items():
772
+ predictions, ground_truths = [], []
773
+ for batch in dl:
774
+ bt = {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}
775
+ with torch.autocast(device_type="cuda", dtype=amp_dtype, enabled=cfg.bf16 and device.type=="cuda"):
776
+ vis_tok = model.vis(bt["pixel_values"]).float()
777
+ txt_tok = model.txt(bt["input_ids"], bt["attention_mask"]).float()
778
+ evidence, _, _ = model.evidence(vis_tok, txt_tok, bt["attention_mask"])
779
+ if model._use_rollout:
780
+ _, z_final, _ = model.rollout(evidence)
781
+ else:
782
+ B2 = bt["batch_size"]
783
+ z_final = model.rollout.init_tokens.expand(B2,-1,-1) + model.rollout.z0_proj(
784
+ F.adaptive_avg_pool1d(evidence.permute(0,2,1), model.rollout.num_tokens).permute(0,2,1))
785
+
786
+ gen_ids = model.gen_head.generate_beam(
787
+ z_final, evidence, start_token_id,
788
+ max_length=max_gen_len, eos_token_id=eos_token_id,
789
+ beam_width=beam_width,
790
+ )
791
+ for i in range(gen_ids.size(0)):
792
+ predictions.append(tokenizer.decode(gen_ids[i], skip_special_tokens=True).strip())
793
+ ground_truths.extend(batch["all_answers"])
794
+
795
+ # Log a few sample predictions for debugging
796
+ for j in range(min(3, len(predictions))):
797
+ gt_sample = ground_truths[j] if j < len(ground_truths) else "?"
798
+ log.info(f" [{benchmark}] pred: '{predictions[j]}' | gt: '{gt_sample}'")
799
+
800
+ if benchmark == "docvqa":
801
+ results[benchmark] = {"anls": compute_anls(predictions, ground_truths)}
802
+ elif benchmark == "chartqa":
803
+ gt_flat = [g[0] if isinstance(g, list) else g for g in ground_truths]
804
+ results[benchmark] = {"relaxed_accuracy": compute_relaxed_accuracy(predictions, gt_flat)}
805
+ elif benchmark == "textvqa":
806
+ results[benchmark] = {"vqa_accuracy": compute_vqa_accuracy(predictions, ground_truths)}
807
+
808
+ model.train()
809
+ return results
810
+
811
+
812
+ # ══════════════════════════════════════════════════════════════════════════
813
+ # CHECKPOINT & HUB
814
+ # ══════════════════════════════════════════════════════════════════════════
815
+
816
+ def save_checkpoint(model, cfg, epoch, mc_acc, open_results, composite):
817
+ path = os.path.join(cfg.output_dir, "checkpoint_best.pt")
818
+ torch.save({
819
+ "evidence": model.evidence.state_dict(),
820
+ "rollout": model.rollout.state_dict(),
821
+ "disc": model.disc.state_dict(),
822
+ "gen_head": model.gen_head.state_dict(),
823
+ "target_ev": model.target.t_ev.state_dict(),
824
+ "target_ro": model.target.t_ro.state_dict(),
825
+ "config": cfg.__dict__,
826
+ "epoch": epoch, "mc_eval_acc": mc_acc,
827
+ "open_results": open_results, "composite_score": composite,
828
+ "phase": "3.1",
829
+ }, path)
830
+ log.info(f"Saved checkpoint: {path} (composite={composite:.1f})")
831
+
832
+
833
+ def push_results(cfg, args, best_composite, eval_results):
834
+ try:
835
+ from huggingface_hub import HfApi
836
+ api = HfApi()
837
+ results = {
838
+ "run_name": cfg.run_name, "phase": "3.1",
839
+ "backbone": cfg.backbone, "K": cfg.K,
840
+ "best_composite_score": best_composite,
841
+ "gen_weight": args.gen_weight, "max_gen_len": args.max_gen_len,
842
+ "beam_width": args.beam_width,
843
+ "tf_start": args.tf_start, "tf_end": args.tf_end,
844
+ "epochs": cfg.epochs, "core_lr": args.core_lr,
845
+ "open_results": {k: v for k, v in (eval_results or {}).items()},
846
+ "improvements": ["gen_weight_2.0", "gen_len_32", "scheduled_sampling", "beam_search"],
847
+ }
848
+ rp = os.path.join(cfg.output_dir, f"results_{cfg.run_name}.json")
849
+ with open(rp, "w") as f:
850
+ json.dump(results, f, indent=2)
851
+ api.upload_file(path_or_fileobj=rp, path_in_repo=f"results/{cfg.run_name}.json",
852
+ repo_id=cfg.hub_model_id, repo_type="model")
853
+ best_ckpt = os.path.join(cfg.output_dir, "checkpoint_best.pt")
854
+ if os.path.exists(best_ckpt):
855
+ api.upload_file(path_or_fileobj=best_ckpt,
856
+ path_in_repo=f"checkpoints/{cfg.run_name}_best.pt",
857
+ repo_id=cfg.hub_model_id, repo_type="model")
858
+ log.info(f"Pushed Phase 3.1 results to {cfg.hub_model_id}")
859
+ except Exception as e:
860
+ log.error(f"Push failed: {e}")
861
+
862
+
863
+ if __name__ == "__main__":
864
+ main()