Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -87,196 +87,14 @@ Xs_te = scaler.transform(X_te).astype(np.float32)
|
|
| 87 |
import joblib
|
| 88 |
joblib.dump(scaler, "mtl_scaler.joblib")
|
| 89 |
|
| 90 |
-
#-- 6) tensors --#
|
| 91 |
-
|
| 92 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 93 |
-
def tt(a, dtype=torch.float32): return torch.from_numpy(a).to(device).to(dtype)
|
| 94 |
-
|
| 95 |
-
Xt_tr, Xt_va, Xt_te = tt(Xs_tr), tt(Xs_va), tt(Xs_te)
|
| 96 |
-
yI_tr, yI_va, yI_te = tt(I_tr), tt(I_va), tt(I_te)
|
| 97 |
-
yD_tr, yD_va, yD_te = tt(D_tr), tt(D_va), tt(D_te)
|
| 98 |
-
|
| 99 |
-
if H_tr_all is not None:
|
| 100 |
-
H_tr_all_t = tt(H_tr_all)
|
| 101 |
-
H_va_all_t = tt(H_va_all)
|
| 102 |
-
H_te_all_t = tt(H_te_all)
|
| 103 |
-
mH_tr_t = torch.from_numpy(mH_tr.astype(bool)).to(device)
|
| 104 |
-
mH_va_t = torch.from_numpy(mH_va.astype(bool)).to(device)
|
| 105 |
-
mH_te_t = torch.from_numpy(mH_te.astype(bool)).to(device)
|
| 106 |
-
else:
|
| 107 |
-
H_tr_all_t = H_va_all_t = H_te_all_t = None
|
| 108 |
-
mH_tr_t = mH_va_t = mH_te_t = None
|
| 109 |
-
|
| 110 |
-
# ----------7) Model: Multi-Task MLP (shared trunk + 3 heads)----------
|
| 111 |
-
class MTLNet(nn.Module):
|
| 112 |
-
def __init__(self, d_in=384, d_shared=384):
|
| 113 |
-
super().__init__()
|
| 114 |
-
self.trunk = nn.Sequential(
|
| 115 |
-
nn.Linear(d_in, d_shared), nn.ReLU(), nn.Dropout(0.2),
|
| 116 |
-
nn.Linear(d_shared, 192), nn.ReLU(), nn.Dropout(0.1)
|
| 117 |
-
)
|
| 118 |
-
self.head_imp = nn.Linear(192, 1) # importance raw (clamped at inference)
|
| 119 |
-
self.head_dur = nn.Linear(192, 1) # log-hours
|
| 120 |
-
self.head_hor = nn.Linear(192, 1) # log-days
|
| 121 |
-
|
| 122 |
-
# learnable uncertainty weights (auto-balance task losses)
|
| 123 |
-
self.log_sigma_imp = nn.Parameter(torch.tensor(0.0))
|
| 124 |
-
self.log_sigma_dur = nn.Parameter(torch.tensor(0.0))
|
| 125 |
-
self.log_sigma_hor = nn.Parameter(torch.tensor(0.0))
|
| 126 |
-
|
| 127 |
-
def forward(self, x):
|
| 128 |
-
h = self.trunk(x)
|
| 129 |
-
return (
|
| 130 |
-
self.head_imp(h).squeeze(-1),
|
| 131 |
-
self.head_dur(h).squeeze(-1),
|
| 132 |
-
self.head_hor(h).squeeze(-1),
|
| 133 |
-
)
|
| 134 |
-
|
| 135 |
-
def multitask_loss(self, xb, yI, yD, yH=None, mH=None):
|
| 136 |
-
rI, rD, rH = self(xb)
|
| 137 |
-
|
| 138 |
-
# importance: SmoothL1 on raw scale
|
| 139 |
-
l_imp = nn.SmoothL1Loss()(rI, yI)
|
| 140 |
-
# duration: MSE on log1p(hours)
|
| 141 |
-
l_dur = nn.MSELoss()(rD, torch.log1p(yD))
|
| 142 |
-
|
| 143 |
-
loss = torch.exp(-self.log_sigma_imp)*l_imp + self.log_sigma_imp \
|
| 144 |
-
+ torch.exp(-self.log_sigma_dur)*l_dur + self.log_sigma_dur
|
| 145 |
-
|
| 146 |
-
l_hor_val = None
|
| 147 |
-
if (yH is not None) and (mH is not None) and mH.any():
|
| 148 |
-
# horizon: only where label exists (mask True), MSE on log1p(days)
|
| 149 |
-
l_hor = nn.MSELoss()(rH[mH], torch.log1p(yH[mH]))
|
| 150 |
-
loss = loss + torch.exp(-self.log_sigma_hor)*l_hor + self.log_sigma_hor
|
| 151 |
-
l_hor_val = float(l_hor.item())
|
| 152 |
-
return loss, (float(l_imp.item()), float(l_dur.item()), l_hor_val)
|
| 153 |
-
|
| 154 |
-
net = MTLNet(d_in=Xt_tr.shape[1]).to(device)
|
| 155 |
-
opt = optim.AdamW(net.parameters(), lr=1e-3, weight_decay=1e-4)
|
| 156 |
-
|
| 157 |
-
# ----------8) Helper Functions for NN ----------
|
| 158 |
-
def predict_heads(Xt):
|
| 159 |
-
net.eval()
|
| 160 |
-
with torch.no_grad():
|
| 161 |
-
rI, rD, rH = net(Xt)
|
| 162 |
-
I = torch.clamp(rI, 1.0, 10.0) # importance 1..10
|
| 163 |
-
Hh = torch.expm1(rD).clamp(0.25, 12.0) # hours
|
| 164 |
-
Hd = torch.expm1(rH).clamp(0.0, 30.0) # days; 0 allowed (today)
|
| 165 |
-
return I, Hh, Hd
|
| 166 |
-
|
| 167 |
-
def eval_block(Xt, yI_true, yD_true, yH_true=None, mH=None):
|
| 168 |
-
I, Hh, Hd = predict_heads(Xt)
|
| 169 |
-
I_np, H_np, Hd_np = I.cpu().numpy(), Hh.cpu().numpy(), Hd.cpu().numpy()
|
| 170 |
-
maeI = mean_absolute_error(yI_true.cpu().numpy(), I_np)
|
| 171 |
-
maeD = mean_absolute_error(yD_true.cpu().numpy(), H_np)
|
| 172 |
-
rhoI = spearmanr(yI_true.cpu().numpy(), I_np).correlation if len(I_np) > 1 else float('nan')
|
| 173 |
-
rhoD = spearmanr(yD_true.cpu().numpy(), H_np).correlation if len(H_np) > 1 else float('nan')
|
| 174 |
-
out = {"maeI":maeI, "maeD":maeD, "rhoI":rhoI, "rhoD":rhoD}
|
| 175 |
-
if (yH_true is not None) and (mH is not None) and mH.any():
|
| 176 |
-
yH_np, mH_np = yH_true.cpu().numpy(), mH.cpu().numpy().astype(bool)
|
| 177 |
-
maeH = mean_absolute_error(yH_np[mH_np], Hd_np[mH_np])
|
| 178 |
-
rhoH = spearmanr(yH_np[mH_np], Hd_np[mH_np]).correlation if mH_np.sum()>1 else float('nan')
|
| 179 |
-
out.update({"maeH":maeH, "rhoH":rhoH})
|
| 180 |
-
return out
|
| 181 |
-
|
| 182 |
-
# ------------------------------------------------------------
|
| 183 |
-
# 6) Train (mini-batch) with LR schedule, AMP, clipping, early stop
|
| 184 |
-
# ------------------------------------------------------------
|
| 185 |
-
EPOCHS = 100
|
| 186 |
-
BATCH = 32
|
| 187 |
-
best_val = float("inf")
|
| 188 |
-
patience = 20
|
| 189 |
-
bad = 0
|
| 190 |
-
|
| 191 |
-
# (re)define optimizer if you like a lower LR for longer cosine cycles
|
| 192 |
-
opt = optim.AdamW(net.parameters(), lr=3e-4, weight_decay=2e-4)
|
| 193 |
-
|
| 194 |
-
# Cosine schedule with warm restarts (restart every ~40 epochs; doubles thereafter)
|
| 195 |
-
sched = optim.lr_scheduler.CosineAnnealingWarmRestarts(opt, T_0=40, T_mult=2, eta_min=1e-6)
|
| 196 |
-
|
| 197 |
-
# optional: also step down on plateaus (acts as a safety net)
|
| 198 |
-
plateau = optim.lr_scheduler.ReduceLROnPlateau(opt, mode="min", factor=0.5, patience=6, min_lr=1e-6)
|
| 199 |
-
|
| 200 |
-
scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
|
| 201 |
-
|
| 202 |
-
n_tr = Xt_tr.shape[0]
|
| 203 |
-
for ep in range(1, EPOCHS + 1):
|
| 204 |
-
net.train()
|
| 205 |
-
order = torch.randperm(n_tr, device=device)
|
| 206 |
-
tot_loss = 0.0
|
| 207 |
-
|
| 208 |
-
for s in range(0, n_tr, BATCH):
|
| 209 |
-
e = min(s + BATCH, n_tr)
|
| 210 |
-
idx = order[s:e]
|
| 211 |
-
xb, yi, yd = Xt_tr[idx], yI_tr[idx], yD_tr[idx]
|
| 212 |
-
|
| 213 |
-
if H_tr_all_t is not None:
|
| 214 |
-
yh, mh = H_tr_all_t[idx], mH_tr_t[idx]
|
| 215 |
-
else:
|
| 216 |
-
yh = mh = None
|
| 217 |
-
|
| 218 |
-
opt.zero_grad(set_to_none=True)
|
| 219 |
-
|
| 220 |
-
use_amp = torch.cuda.is_available()
|
| 221 |
-
amp_device = "cuda" if use_amp else "cpu" # amp on CPU is a no-op fallback
|
| 222 |
-
|
| 223 |
-
with torch.amp.autocast(device_type=amp_device, enabled=use_amp):
|
| 224 |
-
loss, _ = net.multitask_loss(xb, yi, yd, yh, mh)
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
scaler.scale(loss).backward()
|
| 228 |
-
# gradient clipping for stability
|
| 229 |
-
scaler.unscale_(opt)
|
| 230 |
-
torch.nn.utils.clip_grad_norm_(net.parameters(), max_norm=1.0)
|
| 231 |
-
scaler.step(opt)
|
| 232 |
-
scaler.update()
|
| 233 |
-
|
| 234 |
-
tot_loss += float(loss.item())
|
| 235 |
-
|
| 236 |
-
# ---- validation ----
|
| 237 |
-
stats_va = eval_block(
|
| 238 |
-
Xt_va, yI_va, yD_va,
|
| 239 |
-
(H_va_all_t if H_va_all_t is not None else None),
|
| 240 |
-
(mH_va_t if mH_va_t is not None else None)
|
| 241 |
-
)
|
| 242 |
-
total_val = stats_va["maeI"] + stats_va["maeD"] + (stats_va.get("maeH", 0.0))
|
| 243 |
-
|
| 244 |
-
# step cosine scheduler every epoch
|
| 245 |
-
sched.step(ep)
|
| 246 |
-
# also step plateau on the combined val metric
|
| 247 |
-
plateau.step(total_val)
|
| 248 |
-
|
| 249 |
-
if ep % 5 == 0:
|
| 250 |
-
lr_now = opt.param_groups[0]["lr"]
|
| 251 |
-
extra = f" hor={stats_va.get('maeH', float('nan')):.3f}" if "maeH" in stats_va else ""
|
| 252 |
-
|
| 253 |
-
# ---- early stopping on summed MAE ----
|
| 254 |
-
if total_val < best_val - 1e-4:
|
| 255 |
-
best_val = total_val
|
| 256 |
-
bad = 0
|
| 257 |
-
torch.save(net.state_dict(), "mtl_net.pt")
|
| 258 |
-
else:
|
| 259 |
-
bad += 1
|
| 260 |
-
if bad >= patience:
|
| 261 |
-
break
|
| 262 |
-
|
| 263 |
-
# ---- TEST with best checkpoint ----
|
| 264 |
-
net.load_state_dict(torch.load("mtl_net.pt", map_location=device))
|
| 265 |
-
stats_te = eval_block(
|
| 266 |
-
Xt_te, yI_te, yD_te,
|
| 267 |
-
(H_te_all_t if H_te_all_t is not None else None),
|
| 268 |
-
(mH_te_t if mH_te_t is not None else None)
|
| 269 |
-
)
|
| 270 |
-
|
| 271 |
-
# ============================================================
|
| 272 |
-
# 0) Setup: device, seeds, helper
|
| 273 |
-
# ============================================================
|
| 274 |
import numpy as np, random
|
| 275 |
import torch, torch.nn as nn, torch.optim as optim
|
| 276 |
from scipy.stats import spearmanr
|
| 277 |
from sklearn.preprocessing import StandardScaler
|
| 278 |
from sklearn.metrics import mean_absolute_error
|
| 279 |
|
|
|
|
|
|
|
| 280 |
torch.manual_seed(42); np.random.seed(42); random.seed(42)
|
| 281 |
if torch.cuda.is_available(): torch.cuda.manual_seed_all(42)
|
| 282 |
torch.backends.cudnn.benchmark = True
|
|
@@ -291,58 +109,19 @@ def safe_spearman(a, b):
|
|
| 291 |
r = spearmanr(a, b).correlation
|
| 292 |
return float('nan') if r is None else float(r)
|
| 293 |
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
mH_all = mH_tr if ('mH_tr' in globals() and mH_tr is not None) else None # mask for horizon label
|
| 304 |
-
|
| 305 |
-
N = X_all.shape[0]
|
| 306 |
-
rng = np.random.RandomState(42)
|
| 307 |
-
idx = np.arange(N)
|
| 308 |
-
rng.shuffle(idx)
|
| 309 |
-
|
| 310 |
-
# Adjust split sizes to ensure val and test sets are not empty
|
| 311 |
-
n_train = 280 # Reduced training set size
|
| 312 |
-
n_rem = N - n_train
|
| 313 |
-
n_val = n_rem // 2
|
| 314 |
-
n_test = n_rem - n_val
|
| 315 |
-
|
| 316 |
-
i_tr = idx[:n_train]
|
| 317 |
-
i_va = idx[n_train:n_train+n_val]
|
| 318 |
-
i_te = idx[n_train+n_val:]
|
| 319 |
-
|
| 320 |
-
# ============================================================
|
| 321 |
-
# 2) Standardize X using ONLY the train set, then tensorize
|
| 322 |
-
# ============================================================
|
| 323 |
-
scaler = StandardScaler().fit(X_all[i_tr])
|
| 324 |
-
Xn = scaler.transform(X_all)
|
| 325 |
-
|
| 326 |
-
Xt_tr, Xt_va, Xt_te = tt(Xn[i_tr]), tt(Xn[i_va]), tt(Xn[i_te])
|
| 327 |
-
yI_tr, yI_va, yI_te = tt(I_all[i_tr]), tt(I_all[i_va]), tt(I_all[i_te])
|
| 328 |
-
yD_tr, yD_va, yD_te = tt(D_all[i_tr]), tt(D_all[i_va]), tt(D_all[i_te])
|
| 329 |
-
|
| 330 |
-
if H_all is not None:
|
| 331 |
-
H_tr_all_t = tt(H_all[i_tr])
|
| 332 |
-
H_va_all_t = tt(H_all[i_va])
|
| 333 |
-
H_te_all_t = tt(H_all[i_te])
|
| 334 |
-
# masks as bool tensors on device
|
| 335 |
-
mH_tr_t = torch.from_numpy(mH_all[i_tr].astype(bool)).to(device)
|
| 336 |
-
mH_va_t = torch.from_numpy(mH_all[i_va].astype(bool)).to(device)
|
| 337 |
-
mH_te_t = torch.from_numpy(mH_all[i_te].astype(bool)).to(device)
|
| 338 |
-
else:
|
| 339 |
-
H_tr_all_t = H_va_all_t = H_te_all_t = None
|
| 340 |
-
mH_tr_t = mH_va_t = mH_te_t = None
|
| 341 |
|
| 342 |
-
# ============================================================
|
| 343 |
# 3) Model: Multi-Task MLP (shared trunk + 3 heads)
|
| 344 |
# Slightly wider trunk; textbook uncertainty weighting (0.5 factor)
|
| 345 |
-
|
| 346 |
class MTLNet(nn.Module):
|
| 347 |
def __init__(self, d_in, d_hid=512):
|
| 348 |
super().__init__()
|
|
@@ -390,9 +169,8 @@ class MTLNet(nn.Module):
|
|
| 390 |
|
| 391 |
net = MTLNet(d_in=Xt_tr.shape[1]).to(device)
|
| 392 |
|
| 393 |
-
# ============================================================
|
| 394 |
# 4) Prediction + Eval helpers
|
| 395 |
-
|
| 396 |
@torch.no_grad()
|
| 397 |
def predict_heads(Xt):
|
| 398 |
net.eval()
|
|
@@ -421,9 +199,8 @@ def eval_block(Xt, yI_true, yD_true, yH_true=None, mH=None):
|
|
| 421 |
out.update({"maeH": maeH, "rhoH": rhoH})
|
| 422 |
return out
|
| 423 |
|
| 424 |
-
# ============================================================
|
| 425 |
# 5) Train loop with per-batch cosine, AMP (new API), early stop
|
| 426 |
-
|
| 427 |
EPOCHS = 120
|
| 428 |
BATCH = 64
|
| 429 |
best_val = float("inf")
|
|
@@ -484,9 +261,7 @@ for ep in range(1, EPOCHS + 1):
|
|
| 484 |
if bad >= patience:
|
| 485 |
break
|
| 486 |
|
| 487 |
-
# ============================================================
|
| 488 |
# 6) TEST with best checkpoint + final confirmation
|
| 489 |
-
# ============================================================
|
| 490 |
|
| 491 |
net.load_state_dict(torch.load("mtl_net.pt", map_location=device))
|
| 492 |
stats_te = eval_block(
|
|
@@ -673,8 +448,6 @@ def reorder_tasks(tasks_string, user_due_iso=None):
|
|
| 673 |
|
| 674 |
return task_lines, due_lines_out, duration_lines, checkbox_update
|
| 675 |
|
| 676 |
-
import re
|
| 677 |
-
|
| 678 |
import gradio as gr # For building the interface
|
| 679 |
|
| 680 |
with gr.Blocks() as demo:
|
|
|
|
| 87 |
import joblib
|
| 88 |
joblib.dump(scaler, "mtl_scaler.joblib")
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
import numpy as np, random
|
| 91 |
import torch, torch.nn as nn, torch.optim as optim
|
| 92 |
from scipy.stats import spearmanr
|
| 93 |
from sklearn.preprocessing import StandardScaler
|
| 94 |
from sklearn.metrics import mean_absolute_error
|
| 95 |
|
| 96 |
+
# 0) Setup: device, seeds, helper
|
| 97 |
+
|
| 98 |
torch.manual_seed(42); np.random.seed(42); random.seed(42)
|
| 99 |
if torch.cuda.is_available(): torch.cuda.manual_seed_all(42)
|
| 100 |
torch.backends.cudnn.benchmark = True
|
|
|
|
| 109 |
r = spearmanr(a, b).correlation
|
| 110 |
return float('nan') if r is None else float(r)
|
| 111 |
|
| 112 |
+
Xt_tr, Xt_va, Xt_te = tt(Xs_tr), tt(Xs_va), tt(Xs_te)
|
| 113 |
+
yI_tr, yI_va, yI_te = tt(I_tr), tt(I_va), tt(I_te)
|
| 114 |
+
yD_tr, yD_va, yD_te = tt(D_tr), tt(D_va), tt(D_te)
|
| 115 |
+
|
| 116 |
+
if y_hor is not None:
|
| 117 |
+
H_tr_all_t, H_va_all_t, H_te_all_t = tt(H_tr_all), tt(H_va_all), tt(H_te_all)
|
| 118 |
+
mH_tr_t = torch.from_numpy(mH_tr.astype(bool)).to(device)
|
| 119 |
+
mH_va_t = torch.from_numpy(mH_va.astype(bool)).to(device)
|
| 120 |
+
mH_te_t = torch.from_numpy(mH_te.astype(bool)).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
|
|
|
|
| 122 |
# 3) Model: Multi-Task MLP (shared trunk + 3 heads)
|
| 123 |
# Slightly wider trunk; textbook uncertainty weighting (0.5 factor)
|
| 124 |
+
|
| 125 |
class MTLNet(nn.Module):
|
| 126 |
def __init__(self, d_in, d_hid=512):
|
| 127 |
super().__init__()
|
|
|
|
| 169 |
|
| 170 |
net = MTLNet(d_in=Xt_tr.shape[1]).to(device)
|
| 171 |
|
|
|
|
| 172 |
# 4) Prediction + Eval helpers
|
| 173 |
+
|
| 174 |
@torch.no_grad()
|
| 175 |
def predict_heads(Xt):
|
| 176 |
net.eval()
|
|
|
|
| 199 |
out.update({"maeH": maeH, "rhoH": rhoH})
|
| 200 |
return out
|
| 201 |
|
|
|
|
| 202 |
# 5) Train loop with per-batch cosine, AMP (new API), early stop
|
| 203 |
+
|
| 204 |
EPOCHS = 120
|
| 205 |
BATCH = 64
|
| 206 |
best_val = float("inf")
|
|
|
|
| 261 |
if bad >= patience:
|
| 262 |
break
|
| 263 |
|
|
|
|
| 264 |
# 6) TEST with best checkpoint + final confirmation
|
|
|
|
| 265 |
|
| 266 |
net.load_state_dict(torch.load("mtl_net.pt", map_location=device))
|
| 267 |
stats_te = eval_block(
|
|
|
|
| 448 |
|
| 449 |
return task_lines, due_lines_out, duration_lines, checkbox_update
|
| 450 |
|
|
|
|
|
|
|
| 451 |
import gradio as gr # For building the interface
|
| 452 |
|
| 453 |
with gr.Blocks() as demo:
|