ChristophSchuhmann committed on
Commit 2cd248d · verified · Parent: db42eb1

Upload train+.py with huggingface_hub

Files changed (1): train+.py +1061 -0
train+.py ADDED
@@ -0,0 +1,1061 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Local single-process trainer (no torchrun/DDP).
- Uses all visible GPUs via torch.nn.DataParallel (if >1 GPU), else single GPU.
- Trains Whisper encoder (unfrozen) + small head; decoder is frozen (unused).
- Supports data in .tar/.tar.gz (audio+json pairs inside) OR loose files:
      <any>/<name>.wav|.mp3 + <same-dir>/<name>.json
- NEW: also supports GeminiProAudioSegments-style loose files:
      <any>/<name>.audio.mp3 + <same-dir>/<name>.audio.json
  with filtering on segment_duration + overlapping.

- Adaptive batch probe (optional), BF16 preferred when supported (auto), FP16 fallback.
- Periodic HTML evals (x-axis = seconds), ETA, checkpoint saving (weights + optimizer/scheduler states; see NOTE below on resume).
- HTML eval embeds <audio> player with base64 audio + plots per sample.

NOTE: "Resume" now means **weights-only resume** for a new phase on a (possibly) new dataset:
      - We load model weights from trainer_state.pt / trainer_state_best.pt,
        but reset optimizer, scheduler, and all counters for this run.
"""

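# Example invocation (illustrative only; the values shown are just the defaults
# from the CONFIG block below, and every knob is read from environment variables):
#
#   DATA_DIR=./audiodata-full GEMINI_DIR=/home/user/segdata-full \
#   BATCH_SIZE=16 EPOCHS=2 RESUME_MODE=latest python train+.py
#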
from __future__ import annotations
import os, io, json, time, random, tarfile, base64, traceback, math
from pathlib import Path
from typing import List, Tuple, Dict, Any, Optional

# =========================
# ========= CONFIG ========
# =========================
DATA_DIR     = Path(os.getenv("DATA_DIR", "./audiodata-full"))
RESUME_DIR   = Path(os.getenv("RESUME_DIR", "./resume"))
OUT_DIR      = Path(os.getenv("OUT_DIR", "./outs"))
EPOCHS       = int(os.getenv("EPOCHS", "2"))
BATCH_SIZE   = int(os.getenv("BATCH_SIZE", "16"))   # global batch (DataParallel will split)
ADAPTIVE_BSZ = int(os.getenv("ADAPTIVE_BSZ", "1"))  # 1=probe; 0=use BATCH_SIZE as-is
MAX_BSZ_CAP  = int(os.getenv("MAX_BSZ", "0")) or None
NUM_WORKERS  = int(os.getenv("NUM_WORKERS", "4"))
VAL_POOL     = int(os.getenv("EVAL_POOL", "1000"))  # kept for backward compat; not used in new mix
EVAL_FIRST   = int(os.getenv("EVAL_FIRST_SEEN", "2000"))
EVAL_EVERY   = int(os.getenv("EVAL_EVERY_SEEN", "10000"))
SEED         = int(os.getenv("SEED", "1337"))
HF_MODEL_ID  = os.getenv("HF_MODEL_ID", "openai/whisper-small")

# --- NEW: Gemini-specific config (hard-coded but override-able via env) ---
GEMINI_DIR = Path(os.getenv("GEMINI_DIR", "/home/user/segdata-full/"))
USE_GEMINI = int(os.getenv("USE_GEMINI", "1"))  # 1=use Gemini bucket, 0=ignore
GEMINI_SEGMENT_DURATION = os.getenv("GEMINI_SEGMENT_DURATION", "medium")  # filter on this
GEMINI_INCLUDE_OVERLAP_TRUE  = bool(int(os.getenv("GEMINI_INCLUDE_OVERLAP_TRUE", "1")))
GEMINI_INCLUDE_OVERLAP_FALSE = bool(int(os.getenv("GEMINI_INCLUDE_OVERLAP_FALSE", "1")))
GEMINI_OTHER_RATIO = float(os.getenv("GEMINI_OTHER_RATIO", "0.50"))  # other bucket size = round(ratio * N_gem)
VAL_FIXED_N = int(os.getenv("VAL_FIXED_N", "500"))  # fixed eval size from mixed pool

# Optional offline model snapshot
USE_LOCAL_MODELS    = bool(int(os.getenv("USE_LOCAL_MODELS", "0")))
MODELS_SNAPSHOT_DIR = Path(os.getenv("MODELS_SNAPSHOT_DIR", "")) if USE_LOCAL_MODELS else None
HF_HOME             = Path(os.getenv("HF_HOME", (OUT_DIR / ".hf")))
TRANSFORMERS_CACHE  = Path(os.getenv("TRANSFORMERS_CACHE", (OUT_DIR / ".hf" / "hub")))

# Mixed precision: "auto" -> bf16 if supported else fp16; or "bf16"/"fp16"/"fp32"
MIXED_PRECISION = os.getenv("MIXED_PRECISION", "auto").lower()

# Optim/schedule
LR             = 2e-4  # slightly higher LR for the new phase
WEIGHT_DECAY   = 1e-3
WARMUP_RATIO   = 0.05
SCHEDULER      = os.getenv("SCHEDULER", "cosine")  # cosine|linear
FREEZE_ENCODER = False
PIN_MEMORY     = True
GRAD_CLIP_NORM = 1.0
INCLUDE_BG_IN_ACC = False

# Resume / init behaviour
# RESUME_MODE: "latest" (default), "best", or "none"
# Now used only to choose which checkpoint to load **weights** from.
RESUME_MODE = os.getenv("RESUME_MODE", "latest").lower()
INIT_WEIGHTS_STR = os.getenv("INIT_WEIGHTS", "").strip()
INIT_WEIGHTS = Path(INIT_WEIGHTS_STR) if INIT_WEIGHTS_STR else None

# Data/model constants
SAMPLE_RATE  = 16000
CLIP_SECONDS = 30.0
NUM_FRAMES   = 1500
NUM_TRACKS   = 2
MAX_SEGMENTS = 20

LOG_EVERY  = 50
HTML_TOP_N = 12

# =========================
# ========= IMPORTS =======
# =========================
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.nn import DataParallel

# Headless plotting
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

from transformers import (
    WhisperFeatureExtractor,
    WhisperModel,
    get_cosine_schedule_with_warmup,
    get_linear_schedule_with_warmup,
)

# =========================
# ======= UTILITIES =======
# =========================
def setup_dirs():
    OUT_DIR.mkdir(parents=True, exist_ok=True)
    RESUME_DIR.mkdir(parents=True, exist_ok=True)
    (OUT_DIR / ".mplconfig").mkdir(parents=True, exist_ok=True)
    os.environ.setdefault("MPLCONFIGDIR", str((OUT_DIR / ".mplconfig").resolve()))
    HF_HOME.mkdir(parents=True, exist_ok=True)
    os.environ.setdefault("HF_HOME", str(HF_HOME.resolve()))
    os.environ.setdefault("TRANSFORMERS_CACHE", str(TRANSFORMERS_CACHE.resolve()))
    # allocator (PyTorch >=2.x)
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True,max_split_size_mb:128")

def set_seed(s: int):
    random.seed(s); np.random.seed(s)
    torch.manual_seed(s); torch.cuda.manual_seed_all(s)

def preferred_dtype():
    if MIXED_PRECISION == "bf16": return torch.bfloat16
    if MIXED_PRECISION == "fp16": return torch.float16
    if MIXED_PRECISION == "fp32": return torch.float32
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16 if torch.cuda.is_available() else torch.float32

def _model_resolved_name(model_id: str) -> Tuple[str, bool]:
    if USE_LOCAL_MODELS and MODELS_SNAPSHOT_DIR and MODELS_SNAPSHOT_DIR.is_dir():
        local_dirname = model_id.replace("/", "__")
        cand = MODELS_SNAPSHOT_DIR / local_dirname
        if cand.is_dir():
            return str(cand), True
    return model_id, False

# =========================
# ========= DATA ==========
# =========================
ACCEPT_EXT = {".mp3", ".wav"}

def index_tar_pairs_streaming(tar_path: Path) -> List[Tuple[str, str]]:
    """
    Returns list of (audio_member_name, json_member_name) inside a tar(ball).
    """
    pairs, mapping = [], {}
    try:
        with tarfile.open(tar_path, mode="r|*", ignore_zeros=True) as tf:
            for m in tf:
                if not m.isreg():
                    continue
                base, ext = os.path.splitext(m.name)
                ext = ext.lower()
                if ext in ACCEPT_EXT:
                    mapping.setdefault(base, {})["audio"] = m.name
                elif ext == ".json":
                    mapping.setdefault(base, {})["json"] = m.name
    except Exception:
        return []
    for base, d in mapping.items():
        if "audio" in d and "json" in d:
            pairs.append((d["audio"], d["json"]))
    return pairs

def index_loose_pairs(root: Path) -> List[Tuple[Path, Path]]:
    """
    Returns list of (audio_path, json_path) under root for loose files.
    Pattern: <name>.(wav|mp3) + <name>.json
    """
    results = []
    for audio in root.rglob("*"):
        if not audio.is_file():
            continue
        if audio.suffix.lower() in ACCEPT_EXT:
            j = audio.with_suffix(".json")
            if j.exists():
                results.append((audio, j))
    return results

# --- NEW: Gemini loose-file indexer ---
def index_gemini_pairs(root: Path) -> List[Tuple[Path, Path]]:
    """
    Returns list of (audio_path, json_path) for Gemini-style pairs under root:
        <anything>.audio.mp3 + <same>.audio.json
    """
    results: List[Tuple[Path, Path]] = []
    if not root.is_dir():
        return results
    for audio in root.rglob("*.audio.mp3"):
        if not audio.is_file():
            continue
        j = audio.with_suffix(".json")  # sample_0.audio.mp3 -> sample_0.audio.json
        if j.exists():
            results.append((audio, j))
    return results

def _safe_extract_bytes(tar_path: Path, member_name: str) -> Optional[bytes]:
    try:
        with tarfile.open(tar_path, mode="r:*", ignore_zeros=True) as tf:
            m = tf.getmember(member_name)
            f = tf.extractfile(m)
            return f.read() if f else None
    except Exception:
        pass
    try:
        with tarfile.open(tar_path, mode="r|*", ignore_zeros=True) as tf:
            for m in tf:
                if m.isreg() and m.name == member_name:
                    f = tf.extractfile(m)
                    return f.read() if f else None
    except Exception:
        pass
    return None

def read_json_bytes(b: Optional[bytes]) -> Dict[str, Any]:
    if not b:
        return {}
    try:
        return json.loads(b.decode("utf-8", errors="replace"))
    except Exception:
        return {}

def read_member_json(tar_path: Path, member_name: str) -> Dict[str, Any]:
    return read_json_bytes(_safe_extract_bytes(tar_path, member_name))

def read_member_audio_30s(tar_path: Path, member_name: str) -> np.ndarray:
    b = _safe_extract_bytes(tar_path, member_name)
    return decode_audio_30s_bytes(b)

def read_file_json(p: Path) -> Dict[str, Any]:
    try:
        return json.loads(p.read_text(encoding="utf-8"))
    except Exception:
        return {}

def read_file_audio_30s(p: Path) -> np.ndarray:
    try:
        with open(p, "rb") as f:
            b = f.read()
    except Exception:
        b = None
    return decode_audio_30s_bytes(b)

def decode_audio_30s_bytes(b: Optional[bytes]) -> np.ndarray:
    if not b:
        return np.zeros(int(CLIP_SECONDS * SAMPLE_RATE), dtype=np.float32)
    import soundfile as sf
    import librosa
    try:
        with io.BytesIO(b) as bio:
            wav, sr = sf.read(bio, dtype="float32", always_2d=False)
        if wav.ndim == 2:
            wav = wav.mean(axis=1)
        if sr != SAMPLE_RATE:
            wav = librosa.resample(wav, orig_sr=sr, target_sr=SAMPLE_RATE)
        clip_samples = int(CLIP_SECONDS * SAMPLE_RATE)
        if len(wav) < clip_samples:
            wav = np.pad(wav, (0, clip_samples - len(wav)))
        else:
            wav = wav[:clip_samples]
        return wav.astype(np.float32, copy=False)
    except Exception:
        return np.zeros(int(CLIP_SECONDS * SAMPLE_RATE), dtype=np.float32)

def time_to_frame(t: float) -> int:
    if t <= 0:
        return 0
    if t >= CLIP_SECONDS:
        return NUM_FRAMES - 1
    return max(0, min(NUM_FRAMES - 1, int(t / CLIP_SECONDS * NUM_FRAMES)))

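# Worked example of the mapping above: with CLIP_SECONDS=30.0 and
# NUM_FRAMES=1500 each frame spans 0.02 s, so time_to_frame(1.0) == 50 and
# time_to_frame(29.99) == 1499.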
def parse_events(obj: Dict[str, Any]) -> List[Tuple[float, float]]:
    seg = obj.get("segmentation", {})
    cand = seg.get("events") if isinstance(seg, dict) else None
    if not isinstance(cand, list):
        cand = obj.get("events", [])
    out = []
    for e in cand or []:
        st, et = e.get("start_time"), e.get("end_time")
        if isinstance(st, (int, float)) and isinstance(et, (int, float)) and et > st:
            s = max(0.0, float(st))
            e_ = min(CLIP_SECONDS, float(et))
            if e_ > s:
                out.append((s, e_))
    return out

def build_labels_parity(events_sec: List[Tuple[float, float]]) -> torch.LongTensor:
    ev = sorted(events_sec, key=lambda x: (x[0], x[1]))[:MAX_SEGMENTS]
    frames = [(time_to_frame(s), time_to_frame(e)) for (s, e) in ev]
    labels = torch.zeros((NUM_TRACKS, NUM_FRAMES), dtype=torch.long)
    for i, (s, e) in enumerate(frames, start=1):
        track = 0 if (i % 2 == 1) else 1
        sl = labels[track, s:e+1]
        bg = sl == 0
        if bg.any():
            sl[bg] = i
    return labels

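# Illustrative example (hypothetical events): given
# [(0.0, 2.0), (1.0, 3.0), (2.5, 4.0)], segments 1 and 3 land on track 0 and
# segment 2 on track 1 (odd/even parity), so directly adjacent, possibly
# overlapping segments never share a row; where two same-track segments do
# overlap in frames, the earlier one keeps its ID because only background
# (zero) frames are overwritten.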
class TarOrFileDataset(Dataset):
    """
    Each item is either:
      {"kind":"tar", "tar": Path, "a": "member.wav", "j": "member.json"}  or
      {"kind":"file","a_path": Path, "j_path": Path}
    """
    def __init__(self, items: List[Dict[str, Any]], fe):
        self.items = items
        self.fe = fe

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        it = self.items[idx]
        if it["kind"] == "tar":
            obj = read_member_json(it["tar"], it["j"])
            wav = read_member_audio_30s(it["tar"], it["a"])
            a_name = it["a"]
        else:
            obj = read_file_json(it["j_path"])
            wav = read_file_audio_30s(it["a_path"])
            a_name = str(it["a_path"].name)
        ev = parse_events(obj)
        labels = build_labels_parity(ev)
        feat = self.fe(wav, sampling_rate=SAMPLE_RATE, return_tensors="pt")
        input_features = feat.input_features[0]
        return {"x": input_features, "y": labels,
                "meta": {"a": a_name, "ev": len(ev)}}

def collate_fn(batch):
    x = torch.stack([b["x"] for b in batch], dim=0)
    y = torch.stack([b["y"] for b in batch], dim=0)
    meta = {k: [b["meta"][k] for b in batch] for k in batch[0]["meta"]}
    return {"x": x, "y": y, "meta": meta}

# =========================
# ========= MODEL =========
# =========================
class WhisperOddEven(nn.Module):
    def __init__(self, base_id: str, freeze_encoder: bool):
        super().__init__()
        resolved, is_local = _model_resolved_name(base_id)
        self.whisper = WhisperModel.from_pretrained(resolved, local_files_only=is_local)

        # Freeze decoder (unused)
        for p in self.whisper.decoder.parameters():
            p.requires_grad = False

        # Train encoder
        for p in self.whisper.encoder.parameters():
            p.requires_grad = not freeze_encoder

        d_model = self.whisper.config.d_model
        hidden = max(256, d_model // 2)
        self.head = nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, NUM_TRACKS * (MAX_SEGMENTS + 1)),
        )

    def forward(self, input_features: torch.FloatTensor):
        enc = self.whisper.encoder(input_features=input_features).last_hidden_state  # [B,1500,D]
        B, T, D = enc.shape
        logits = self.head(enc)  # [B,T,NUM_TRACKS*(C)]
        C = MAX_SEGMENTS + 1
        logits = logits.view(B, T, NUM_TRACKS, C).permute(0, 2, 1, 3).contiguous()
        return logits  # [B,2,1500,C]

def compute_loss(logits, labels):
    B, TR, T, C = logits.shape
    return F.cross_entropy(
        logits.view(B * TR * T, C),
        labels.view(B * TR * T),
        reduction="mean",
    )

@torch.no_grad()
def frame_accuracy(logits, labels, include_bg=False):
    pred = logits.argmax(dim=-1)
    if include_bg:
        correct = (pred == labels).sum().item()
        total = labels.numel()
    else:
        mask = labels != 0
        correct = (pred[mask] == labels[mask]).sum().item()
        total = mask.sum().item() if mask.any() else 1
    return correct / max(1, total)

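# Shape sketch (assuming the default openai/whisper-small, d_model=768):
# input_features [B, 80, 3000] -> encoder hidden states [B, 1500, 768]
# -> head -> logits [B, 2, 1500, 21], i.e. one (MAX_SEGMENTS + 1)-way
# classification per track per encoder frame.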
# =========================
# ======= REPORTING =======
# =========================
def _plot_tracks_seconds(pred_ids: torch.Tensor, title: str) -> bytes:
    secs = np.linspace(0.0, CLIP_SECONDS, NUM_FRAMES)
    fig = plt.figure(figsize=(10, 2.8))
    ax = plt.gca()
    im = ax.imshow(
        pred_ids.numpy(),
        aspect="auto",
        interpolation="nearest",
        origin="upper",
        extent=[secs[0], secs[-1], -0.5, 1.5],
    )
    ax.set_title(title)
    ax.set_xlabel("Time (s)")
    ax.set_yticks([0, 1])
    ax.set_yticklabels(["odd", "even"])
    cb = plt.colorbar(im, fraction=0.046, pad=0.04)
    cb.set_label("Segment ID")
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
    plt.close(fig)
    buf.seek(0)
    return buf.read()

def _mime_for_ext(fn: str) -> str:
    ext = Path(fn).suffix.lower()
    if ext == ".mp3":
        return "audio/mpeg"
    if ext == ".wav":
        return "audio/wav"
    # Fallback for unknown extensions; browsers may still manage to play it
    return "audio/wav"

def write_eval_html(out_dir: Path, eval_id: str, rows: List[Dict[str, Any]]):
    html = [f"""<!doctype html><html><head><meta charset="utf-8">
<style>
body{{font-family:system-ui,Segoe UI,Roboto,Arial,sans-serif;margin:20px}}
.card{{border:1px solid #ddd;border-radius:10px;padding:16px;margin:16px 0;box-shadow:0 2px 6px rgba(0,0,0,.05)}}
.grid{{display:grid;grid-template-columns:1fr 1fr;gap:12px}}
figure{{margin:0}}
figcaption{{font-size:13px;color:#555;margin-top:6px}}
audio{{width:100%;margin-top:8px}}
</style>
<title>Odd/Even Segmentation — Eval {eval_id}</title></head><body>"""]
    for r in rows:
        audio_html = ""
        if r.get("audio_b64") and r.get("audio_mime"):
            audio_html = (
                '<audio controls preload="none">'
                f'<source src="data:{r["audio_mime"]};base64,{r["audio_b64"]}" type="{r["audio_mime"]}">'
                'Your browser does not support the audio element.'
                '</audio>'
                '<small style="color:#555">Listen: original eval audio</small>'
            )
        html.append(f"""
<section class="card">
  <h3>{r['a_name']}</h3>
  {audio_html}
  <div class="grid">
    <figure>
      <img src="data:image/png;base64,{r['png_raw']}" alt="raw">
      <figcaption>RAW (avg acc vs GT: {r['acc_raw']:.3f})</figcaption>
    </figure>
    <figure>
      <img src="data:image/png;base64,{r['png_sm']}" alt="smoothed">
      <figcaption>SMOOTHED (avg acc vs GT: {r['acc_sm']:.3f})</figcaption>
    </figure>
  </div>
</section>""")
    html.append("</body></html>")
    p = out_dir / f"eval_{eval_id}.html"
    try:
        p.write_text("\n".join(html), encoding="utf-8")
    except Exception as e:
        print(f"[eval-html] failed to write {p}: {e}", flush=True)
    return p

# =========================
# ========= TRAIN =========
# =========================
def unwrap(model: nn.Module) -> nn.Module:
    return model.module if isinstance(model, DataParallel) else model

def main():
    setup_dirs()
    set_seed(SEED)

    # logging
    log_path = OUT_DIR / "train.log"
    log_f = open(log_path, "a", buffering=1)

    def log(*a):
        s = " ".join(str(x) for x in a)
        print(s, flush=True)
        print(s, file=log_f, flush=True)

    # index "other" data (original DATA_DIR)
    tar_files = sorted(
        set(
            [p for p in DATA_DIR.rglob("*.tar") if p.is_file()]
            + [p for p in DATA_DIR.rglob("*.tar.gz") if p.is_file()]
        )
    )
    loose_pairs = index_loose_pairs(DATA_DIR)
    log(f"==> Found {len(tar_files)} tarballs and {len(loose_pairs)} loose audio+json pairs in {DATA_DIR}")

    other_items: List[Dict[str, Any]] = []
    for tp in tar_files:
        pairs = index_tar_pairs_streaming(tp)
        log(f"[index] {tp.name}: {len(pairs)} pairs")
        for a_m, j_m in pairs:
            other_items.append({"kind": "tar", "tar": tp, "a": a_m, "j": j_m})
    for a_p, j_p in loose_pairs:
        other_items.append({"kind": "file", "a_path": a_p, "j_path": j_p})

    log(f"[other] Total base items from DATA_DIR: {len(other_items)}")

    # index Gemini data
    gem_items: List[Dict[str, Any]] = []
    if USE_GEMINI and GEMINI_DIR.is_dir():
        raw_pairs = index_gemini_pairs(GEMINI_DIR)
        log(f"[gemini] Scanning {GEMINI_DIR} -> {len(raw_pairs)} *.audio.mp3+json pairs (candidates)")
        n_med = 0
        n_ov_true = 0
        n_ov_false = 0
        n_bad_no_seg = 0

        allowed_overlaps = []
        if GEMINI_INCLUDE_OVERLAP_TRUE:
            allowed_overlaps.append(True)
        if GEMINI_INCLUDE_OVERLAP_FALSE:
            allowed_overlaps.append(False)

        for a_p, j_p in raw_pairs:
            obj = read_file_json(j_p)
            seg_dur = obj.get("segment_duration")
            overlapping = obj.get("overlapping")

            # Keep only the configured segment_duration (default: "medium")
            if seg_dur != GEMINI_SEGMENT_DURATION:
                continue
            n_med += 1

            if overlapping is True:
                n_ov_true += 1
                if not GEMINI_INCLUDE_OVERLAP_TRUE:
                    continue
            elif overlapping is False:
                n_ov_false += 1
                if not GEMINI_INCLUDE_OVERLAP_FALSE:
                    continue
            else:
                # overlapping field missing or malformed -> skip
                continue

            # make sure segmentation/events exist
            seg_block = obj.get("segmentation", {})
            if not isinstance(seg_block, dict) or not isinstance(seg_block.get("events"), list):
                n_bad_no_seg += 1
                continue

            gem_items.append({"kind": "file", "a_path": a_p, "j_path": j_p})

        log(
            f"[gemini] medium-candidates={n_med} "
            f"(overlap true={n_ov_true}, overlap false={n_ov_false}, "
            f"discarded missing/invalid seg={n_bad_no_seg})"
        )
        log(
            f"[gemini] after filters "
            f"(segment_duration='{GEMINI_SEGMENT_DURATION}', overlapping in {allowed_overlaps}) "
            f"-> {len(gem_items)} items"
        )
    else:
        if not USE_GEMINI:
            log("[gemini] USE_GEMINI=0 -> Gemini bucket disabled")
        else:
            log(f"[gemini] Directory not found: {GEMINI_DIR} -> Gemini bucket disabled")

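    # The filter above assumes each *.audio.json looks roughly like this sketch
    # (the field names are the ones read above; the values are hypothetical):
    #   {"segment_duration": "medium", "overlapping": false,
    #    "segmentation": {"events": [{"start_time": 0.4, "end_time": 3.2}]}}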
    # build combined item list according to mixing rule
    if gem_items:
        N_gem = len(gem_items)
        target_other = int(math.floor(N_gem * GEMINI_OTHER_RATIO + 0.5))
        if target_other > len(other_items):
            target_other = len(other_items)
        random.shuffle(other_items)
        sampled_other = other_items[:target_other]
        combined_items = gem_items + sampled_other
        random.shuffle(combined_items)

        log(f"[mix] Gemini items: {N_gem}")
        log(f"[mix] Sampling {target_other} items from other-data (ratio={GEMINI_OTHER_RATIO})")
        log(f"[mix] Combined pool size (before train/val split): {len(combined_items)}")
    else:
        # Fallback: original behaviour (whole dataset from DATA_DIR)
        combined_items = other_items
        random.shuffle(combined_items)
        log("[mix] WARNING: no Gemini items found -> training only on DATA_DIR mix")
        log(f"[mix] Combined pool size: {len(combined_items)}")

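    # Example of the mixing rule (hypothetical counts): 10,000 Gemini items
    # with GEMINI_OTHER_RATIO=0.50 pull round(0.50 * 10,000) = 5,000 sampled
    # DATA_DIR items, i.e. a combined pool of 15,000 before the train/val split.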
    if not combined_items:
        log("[ERR] No usable audio+json pairs in final mix. Aborting.")
        log_f.close()
        return

    # fixed-size validation set (up to VAL_FIXED_N)
    val_n = min(VAL_FIXED_N, len(combined_items))
    val_items = combined_items[:val_n]
    train_items = combined_items[val_n:]

    log(
        f"[split] Final split -> Train={len(train_items)} | Val={len(val_items)} "
        f"(VAL_FIXED_N={VAL_FIXED_N})"
    )

    # features
    resolved, is_local = _model_resolved_name(HF_MODEL_ID)
    fe = WhisperFeatureExtractor.from_pretrained(resolved, local_files_only=is_local)

    # datasets & loaders
    train_ds = TarOrFileDataset(train_items, fe)
    val_ds = TarOrFileDataset(val_items, fe)

    # provisional loader for batch probe
    train_loader = DataLoader(
        train_ds,
        batch_size=1,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        collate_fn=collate_fn,
        persistent_workers=NUM_WORKERS > 0,
        prefetch_factor=2 if NUM_WORKERS > 0 else None,  # prefetch_factor must be None when num_workers=0
    )

    # model
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    base_model = WhisperOddEven(HF_MODEL_ID, freeze_encoder=FREEZE_ENCODER).to(device)

    # DataParallel if multi-GPU
    n_gpu = torch.cuda.device_count()
    if n_gpu > 1:
        log(f"[gpu] Using DataParallel across {n_gpu} GPUs")
        model = DataParallel(base_model, device_ids=list(range(n_gpu)))
    else:
        model = base_model

    # optim, amp
    optim = torch.optim.AdamW(
        (p for p in model.parameters() if p.requires_grad),
        lr=LR,
        weight_decay=WEIGHT_DECAY,
    )
    use_dtype = preferred_dtype()
    amp_enabled = use_dtype in (torch.float16, torch.bfloat16)
    scaler = torch.cuda.amp.GradScaler(enabled=(amp_enabled and use_dtype == torch.float16))

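    # Note: the GradScaler is only enabled on the fp16 path; bf16 shares
    # fp32's exponent range, so loss scaling is unnecessary there and the
    # bf16/fp32 paths use plain backward()/step().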
    # ---- weights-only resume / init ----
    state_path = RESUME_DIR / "trainer_state.pt"
    best_state_path = RESUME_DIR / "trainer_state_best.pt"
    start_epoch = 1
    global_step = 0
    seen_samples = 0
    state_loaded = False

    if RESUME_MODE != "none":
        state_to_load: Optional[Path] = None
        if RESUME_MODE == "best" and best_state_path.exists():
            state_to_load = best_state_path
        elif state_path.exists():
            state_to_load = state_path

        if state_to_load is not None:
            try:
                state = torch.load(state_to_load, map_location="cpu")
                unwrap(model).load_state_dict(state["model"])
                state_loaded = True
                log(
                    f"[resume-weights] loaded model weights from {state_to_load}; "
                    f"optimizer/scheduler/counters RESET for new dataset"
                )
            except Exception as e:
                log(f"[resume] failed to load {state_to_load}: {e}")

    # optional weights-only init from separate file (only if no trainer_state used)
    if (not state_loaded) and INIT_WEIGHTS is not None and INIT_WEIGHTS.is_file():
        try:
            ckpt = torch.load(INIT_WEIGHTS, map_location="cpu")
            unwrap(model).load_state_dict(ckpt)
            log(f"[init] loaded weights from {INIT_WEIGHTS}")
        except Exception as e:
            log(f"[init] failed to load INIT_WEIGHTS {INIT_WEIGHTS}: {e}")

    # batch probe (single-process; DP handles scattering)
    def try_batch_size(bsz_try: int) -> bool:
        try:
            it = iter(train_loader)
            batch = next(it)
            x = batch["x"].to(device, non_blocking=True).repeat(bsz_try, 1, 1)
            y = batch["y"].to(device, non_blocking=True).repeat(bsz_try, 1, 1)
            with torch.autocast(
                device_type="cuda" if torch.cuda.is_available() else "cpu",
                enabled=amp_enabled,
                dtype=use_dtype,
            ):
                logits = model(x)
                loss = compute_loss(logits, y)
            if scaler.is_enabled():
                scaler.scale(loss).backward()
            else:
                loss.backward()
            optim.zero_grad(set_to_none=True)
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            return True
        except RuntimeError as e:
            if "out of memory" in str(e).lower():
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                return False
            raise
        finally:
            optim.zero_grad(set_to_none=True)

    bsz = max(1, BATCH_SIZE)
    if ADAPTIVE_BSZ:
        log(f"[bsz] probing starting at {bsz} (cap={MAX_BSZ_CAP})")
        if try_batch_size(bsz):
            step = max(4, bsz // 2)
            while True:
                nxt = bsz + step
                if MAX_BSZ_CAP and nxt > MAX_BSZ_CAP:
                    break
                ok = try_batch_size(nxt)
                if not ok:
                    break
                bsz = nxt
                step = max(4, step)
                log(f"[bsz] increased to {bsz}")
        else:
            while bsz > 1 and not try_batch_size(bsz):
                bsz = max(1, bsz // 2)
            if bsz == 1:
                log("[bsz] fell back to 1")
    log(f"[bsz] final batch size = {bsz}")

    # rebuild loader with final batch size (shuffles each epoch)
    train_loader = DataLoader(
        train_ds,
        batch_size=bsz,
        shuffle=True,
        num_workers=NUM_WORKERS,
        pin_memory=PIN_MEMORY,
        collate_fn=collate_fn,
        persistent_workers=NUM_WORKERS > 0,
        prefetch_factor=2 if NUM_WORKERS > 0 else None,
    )

    # scheduler, now with correct steps_per_epoch for this loader
    steps_per_epoch = max(1, len(train_loader))
    total_steps = max(1, EPOCHS * steps_per_epoch)
    warmup = max(1, int(WARMUP_RATIO * total_steps))
    sched = (
        get_cosine_schedule_with_warmup(optim, warmup, total_steps)
        if SCHEDULER == "cosine"
        else get_linear_schedule_with_warmup(optim, warmup, total_steps)
    )

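    # Example (hypothetical sizes): 20,000 train items at batch size 16 give
    # 1,250 steps/epoch; with EPOCHS=2 that is 2,500 total steps, of which
    # WARMUP_RATIO=0.05 -> 125 are warmup.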
    # ETA helpers
    ema_rate = None
    total_samples = len(train_ds) * EPOCHS

    def format_eta(secs: float) -> str:
        secs = max(0.0, secs)
        h = int(secs // 3600)
        m = int((secs % 3600) // 60)
        s = int(secs % 60)
        return f"{h:02d}:{m:02d}:{s:02d}"

    # helper to fetch original audio bytes for HTML embedding
    def _audio_bytes_for_eval_item(ds: TarOrFileDataset, idx: int) -> Tuple[Optional[bytes], str, str]:
        it = ds.items[idx]
        if it["kind"] == "tar":
            a_name = it["a"]
            b = _safe_extract_bytes(it["tar"], a_name)
            mime = _mime_for_ext(a_name)
            disp = f"{Path(it['tar']).name} :: {a_name}"
            return b, disp, mime
        else:
            p = it["a_path"]
            try:
                b = p.read_bytes()
            except Exception:
                b = None
            mime = _mime_for_ext(str(p))
            return b, p.name, mime

    @torch.no_grad()
    def evaluate(tag: str, n=400):
        nonlocal ema_rate
        unwrap(model).eval()
        idx = list(range(len(val_ds)))
        random.shuffle(idx)
        sub = idx[: min(n, len(val_ds))]
        tot_loss = 0.0
        tot_acc = 0.0
        rows: List[Dict[str, Any]] = []

        for i in sub:
            item = val_ds[i]
            x_cpu = item["x"].unsqueeze(0)
            y_cpu = item["y"].unsqueeze(0)
            x = x_cpu.to(device, non_blocking=True)
            y = y_cpu.to(device, non_blocking=True)
            with torch.autocast(
                device_type="cuda" if torch.cuda.is_available() else "cpu",
                enabled=amp_enabled,
                dtype=use_dtype,
            ):
                logits = unwrap(model)(x)
                loss = compute_loss(logits, y).item()
            acc = frame_accuracy(logits, y, include_bg=INCLUDE_BG_IN_ACC)
            tot_loss += loss
            tot_acc += acc

            if len(rows) < HTML_TOP_N:
                raw_ids = logits.argmax(dim=-1).squeeze(0).cpu()
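                # The pooling below averages each class's logits over a
                # 9-frame window (stride 1, padding 4 preserves the 1500-frame
                # length); at 50 frames/s that is a ~0.18 s moving average
                # applied before the argmax, suppressing single-frame flicker.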
                sm_logits = F.avg_pool1d(
                    logits.permute(0, 1, 3, 2)
                    .contiguous()
                    .view(1, NUM_TRACKS * (MAX_SEGMENTS + 1), NUM_FRAMES),
                    kernel_size=9,
                    stride=1,
                    padding=4,
                ).view(1, NUM_TRACKS, MAX_SEGMENTS + 1, NUM_FRAMES)
                sm_logits = sm_logits.permute(0, 1, 3, 2).contiguous()
                sm_ids = sm_logits.argmax(dim=-1).squeeze(0).cpu()

                def acc_ignore_bg(pcpu: torch.Tensor, gtcpu: torch.Tensor) -> float:
                    m = gtcpu != 0
                    tot_ = int(m.sum().item())
                    if tot_ == 0:
                        return 0.0
                    return float((pcpu[m] == gtcpu[m]).sum().item()) / tot_

                acc_raw = (
                    acc_ignore_bg(raw_ids[0], y_cpu[0, 0])
                    + acc_ignore_bg(raw_ids[1], y_cpu[0, 1])
                ) / 2.0
                acc_sm = (
                    acc_ignore_bg(sm_ids[0], y_cpu[0, 0])
                    + acc_ignore_bg(sm_ids[1], y_cpu[0, 1])
                ) / 2.0

                png_raw = base64.b64encode(
                    _plot_tracks_seconds(raw_ids, f"RAW — {i}")
                ).decode("ascii")
                png_sm = base64.b64encode(
                    _plot_tracks_seconds(sm_ids, f"SMOOTHED — {i}")
                ).decode("ascii")

                a_bytes, disp_name, mime = _audio_bytes_for_eval_item(val_ds, i)
                audio_b64 = base64.b64encode(a_bytes).decode("ascii") if a_bytes else None

                rows.append(
                    {
                        "a_name": disp_name,
                        "png_raw": png_raw,
                        "png_sm": png_sm,
                        "acc_raw": acc_raw,
                        "acc_sm": acc_sm,
                        "audio_b64": audio_b64,
                        "audio_mime": mime,
                    }
                )

        loss_avg = tot_loss / max(1, len(sub))
        acc_avg = tot_acc / max(1, len(sub))

        remaining = max(0.0, (total_samples - seen_samples) / max(1e-6, (ema_rate or 1.0)))
        eta_str = format_eta(remaining)
        html_path = write_eval_html(
            OUT_DIR, f"{tag}_eta{eta_str.replace(':', '-')}", rows
        )
        log(f"[eval:{tag}] loss {loss_avg:.4f} acc {acc_avg:.4f} on {len(sub)} samples -> {html_path}")
        unwrap(model).train()
        return loss_avg

    best_val = float("inf")
    first_eval_threshold = EVAL_FIRST
    if seen_samples >= first_eval_threshold:
        first_eval_threshold = -1
    periodic_eval_every = EVAL_EVERY if EVAL_EVERY > 0 else 0

    unwrap(model).train()
    for ep in range(start_epoch, EPOCHS + 1):
        ep_t0 = time.time()
        for step, batch in enumerate(train_loader, start=1):
            t0 = time.time()
            x = batch["x"].to(device, non_blocking=True)
            y = batch["y"].to(device, non_blocking=True)

            with torch.autocast(
                device_type="cuda" if torch.cuda.is_available() else "cpu",
                enabled=amp_enabled,
                dtype=use_dtype,
            ):
                logits = model(x)
                loss = compute_loss(logits, y)

            if scaler.is_enabled():
                scaler.scale(loss).backward()
                scaler.unscale_(optim)
            else:
                loss.backward()

            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP_NORM)
            if scaler.is_enabled():
                scaler.step(optim)
                scaler.update()
            else:
                optim.step()
            optim.zero_grad(set_to_none=True)
            sched.step()

            global_step += 1
            seen_samples += x.size(0)

            dt = max(1e-6, time.time() - t0)
            rate = x.size(0) / dt
            ema_rate = rate if (ema_rate is None) else (0.05 * rate + 0.95 * ema_rate)
            remaining = max(
                0.0,
                (len(train_ds) * EPOCHS - seen_samples) / max(1e-6, ema_rate or 1.0),
            )

            if global_step % LOG_EVERY == 0:
                acc_now = frame_accuracy(logits, y, include_bg=INCLUDE_BG_IN_ACC)
                log(
                    f"[train] ep {ep} step {global_step} loss {loss.item():.4f} acc {acc_now:.4f} "
                    f"lr {sched.get_last_lr()[0]:.2e} seen {seen_samples}/{len(train_ds)*EPOCHS} "
                    f"rate {ema_rate:.1f} samp/s ETA {format_eta(remaining)}"
                )

            # eval schedule
            do_eval = False
            prev_seen = seen_samples - x.size(0)
            if first_eval_threshold != -1 and first_eval_threshold > 0:
                if seen_samples >= first_eval_threshold and prev_seen < first_eval_threshold:
                    do_eval = True
                    first_eval_threshold = -1
            elif periodic_eval_every > 0:
                prev_bucket = prev_seen // periodic_eval_every
                now_bucket = seen_samples // periodic_eval_every
                if now_bucket != prev_bucket and now_bucket > 0:
                    do_eval = True

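            # Example: with EVAL_EVERY_SEEN=10000, a step that moves
            # seen_samples from 9,984 to 10,016 crosses a bucket boundary
            # (0 -> 1) and triggers an eval.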
            if do_eval:
                val_loss = evaluate(tag=f"gstep{global_step}_seen{seen_samples}")
                # save "latest" trainer state
                try:
                    torch.save(
                        {
                            "epoch": ep,
                            "global_step": global_step,
                            "seen_samples": seen_samples,
                            "model": unwrap(model).state_dict(),
                            "optim": optim.state_dict(),
                            "sched": sched.state_dict(),
                            "scaler": scaler.state_dict() if scaler.is_enabled() else {},
                        },
                        state_path,
                    )
                    log(f"[save] trainer state -> {state_path}")
                except Exception as e:
                    log(f"[save] failed to write trainer state {state_path}: {e}")
                # save best
                if val_loss is not None and val_loss < best_val:
                    best_val = val_loss
                    try:
                        torch.save(unwrap(model).state_dict(), OUT_DIR / "model_best.pt")
                        torch.save(
                            unwrap(model).whisper.encoder.state_dict(),
                            OUT_DIR / "encoder_best.bin",
                        )
                        torch.save(
                            {
                                "epoch": ep,
                                "global_step": global_step,
                                "seen_samples": seen_samples,
                                "model": unwrap(model).state_dict(),
                                "optim": optim.state_dict(),
                                "sched": sched.state_dict(),
                                "scaler": scaler.state_dict()
                                if scaler.is_enabled()
                                else {},
                            },
                            best_state_path,
                        )
                        log(
                            f"[save] new BEST (eval) -> {OUT_DIR/'model_best.pt'} "
                            f"(state: {best_state_path})"
                        )
                    except Exception as e:
                        log(f"[save] failed to write BEST checkpoint: {e}")

        # end epoch
        val_loss = evaluate(tag=f"epoch{ep}_end")
        try:
            torch.save(
                {
                    "epoch": ep + 1,
                    "global_step": global_step,
                    "seen_samples": seen_samples,
                    "model": unwrap(model).state_dict(),
                    "optim": optim.state_dict(),
                    "sched": sched.state_dict(),
                    "scaler": scaler.state_dict() if scaler.is_enabled() else {},
                },
                state_path,
            )
            log(f"[save] trainer state (epoch end) -> {state_path}")
        except Exception as e:
            log(f"[save] failed to write trainer state (epoch end) {state_path}: {e}")

        if val_loss is not None and val_loss < best_val:
            best_val = val_loss
            try:
                torch.save(unwrap(model).state_dict(), OUT_DIR / "model_best.pt")
                torch.save(
                    unwrap(model).whisper.encoder.state_dict(),
                    OUT_DIR / "encoder_best.bin",
                )
                torch.save(
                    {
                        "epoch": ep + 1,
                        "global_step": global_step,
                        "seen_samples": seen_samples,
                        "model": unwrap(model).state_dict(),
                        "optim": optim.state_dict(),
                        "sched": sched.state_dict(),
                        "scaler": scaler.state_dict() if scaler.is_enabled() else {},
                    },
                    best_state_path,
                )
                log(
                    f"[save] new BEST (epoch) -> {OUT_DIR/'model_best.pt'} "
                    f"(state: {best_state_path})"
                )
            except Exception as e:
                log(f"[save] failed to write BEST checkpoint (epoch end): {e}")

        log(f"[epoch] {ep}/{EPOCHS} done in {time.time() - ep_t0:.1f}s")

    log("\n[done] Training complete.")
    log_f.close()

if __name__ == "__main__":
    try:
        main()
    except Exception:
        print("[FATAL]\n", traceback.format_exc(), flush=True)
        raise