# Sky-OG-Model / indomain / build_frames_export.py
# Add frames export script (24624c7), 4.82 kB
#!/usr/bin/env python3
"""
Frame (zoom-out) export for the review tool.

Replays the EXACT best-view selection from build_dashboard_export.py (max screen-prob
view, sorted-frame tie-break) so that, for every flagged module, we know:
  - which original frame its shown crop came from
  - that module's 4-corner box in that frame's pixel space

Outputs into runs/indomain/dashboard_export/:
  frames/{frame_stem}.png   full 640x512 radiometric frame, 16-bit (degC*10), deduped
  frames_index.csv          module_id, frame_file, frame_w, frame_h, qx0..qy3
  frames.tar                tar of frames/ for upload
"""
import json, os, collections, csv, tarfile, time, warnings
warnings.filterwarnings("ignore")
import numpy as np, cv2, torch
import sys; sys.path.insert(0, "/workspace/antevolt/scripts")
from gate_h_measure import load_raster, core_mean, R_PX
from build_indomain_dataset import order_quad, PW, THERM, TEMP, LBL
from train_indomain import CNN, CLIP
# Filesystem layout: each path is nested inside the previous one, and the frames
# directory is created up front so later writes can assume it exists.
ROOT = "/workspace/antevolt"
OUT = f"{ROOT}/runs/indomain"
EXP = f"{OUT}/dashboard_export"
FRAMES = f"{EXP}/frames"
os.makedirs(FRAMES, exist_ok=True)

# Inference device: the first CUDA device when torch reports one, otherwise the CPU.
if torch.cuda.is_available():
    DEV = "cuda:0"
else:
    DEV = "cpu"

# NOTE(review): THR is not referenced anywhere in the visible part of this script.
THR = 0.5
# Ids of the modules selected for review, read from the "module_id" column of
# review_candidates.csv. Only these modules are scored and exported below.
# The file is opened in a `with` block so the handle is closed deterministically,
# and with newline="" as the csv module asks for files handed to a reader.
with open(f"{EXP}/review_candidates.csv", newline="") as fh:
    flagged = {int(row["module_id"]) for row in csv.DictReader(fh)}
print(f"flagged modules: {len(flagged)}")
# per-frame quads (same source as the export)
# frame_mods maps a frame name to a list of (module id, quad, centre) tuples, with the
# centre converted to a float numpy array. LBL is read as JSON lines: one record per
# line carrying a "frame" key and a "modules" list.
frame_mods = collections.defaultdict(list)
with open(LBL) as fh:  # `with` closes the handle; the bare open() leaked it
    for line in fh:
        if not line.strip():
            # A blank line would make json.loads raise; skip it instead.
            continue
        r = json.loads(line)
        fr = r["frame"]
        for m in r["modules"]:
            frame_mods[fr].append((m["id"], m["quad"], np.array(m["cen"], float)))
# Sorted frame names give the scoring pass a deterministic visiting order.
frames = sorted(frame_mods)
# Screening network: built as CNN(2), moved to DEV, weights restored from screen.pt
# (tensors mapped onto DEV while loading), then switched to eval mode for inference.
screen = CNN(2).to(DEV)
screen_weights = torch.load(f"{OUT}/screen.pt", map_location=DEV)
screen.load_state_dict(screen_weights)
screen.eval()
def warp_abs(temp, quad):
    """Rectify the region of ``temp`` bounded by ``quad`` into a PW x PW patch.

    ``quad`` is passed through ``order_quad`` and the result is mapped by a
    perspective transform onto the patch corners (0, 0), (PW-1, 0),
    (PW-1, PW-1), (0, PW-1). Sampling is bilinear; positions that fall outside
    ``temp`` replicate the nearest border pixel.

    NOTE(review): assumes ``order_quad`` returns a 4x2 float32 array whose corner
    order matches the destination order above -- confirm in build_indomain_dataset.
    """
    last = PW - 1
    src_corners = order_quad(quad)
    patch_corners = np.array(
        [[0, 0], [last, 0], [last, last], [0, last]], np.float32
    )
    homography = cv2.getPerspectiveTransform(src_corners, patch_corners)
    return cv2.warpPerspective(
        temp,
        homography,
        (PW, PW),
        flags=cv2.INTER_LINEAR,
        borderMode=cv2.BORDER_REPLICATE,
    )
# ---- replay the best-view selection over every labelled frame ----
# best_p[mid]      highest screen probability recorded so far for module `mid`
# best_frame[mid]  frame that produced that probability
# best_quad[mid]   the module's quad in that frame
best_p = collections.defaultdict(float)
best_frame, best_quad = {}, {}
unreadable = 0  # frames whose raster could not be loaded
t0 = time.time()
for fi, fr in enumerate(frames):
    try:
        temp = load_raster(fr)
    except Exception as exc:
        # Still best-effort (the frame is skipped), but reported on stderr instead of
        # being dropped silently so missing inputs are visible in the run log.
        unreadable += 1
        print(f"  skip unreadable frame {fr}: {exc!r}", file=sys.stderr, flush=True)
        continue
    mods = frame_mods[fr]
    cms = [core_mean(temp, q) for _, q, _ in mods]
    # Reference pool: (centre, core mean) of every module in the frame, flagged or
    # not, whose core_mean is not None.
    valid = [(cen, mu) for (_, _, cen), mu in zip(mods, cms) if mu is not None]
    if not valid:
        continue
    gref = float(np.median([mu for _, mu in valid]))  # frame-wide fallback reference
    dts, ids, quads = [], [], []
    for mid, q, cen in mods:
        if mid not in flagged:
            continue
        # Local reference: median core mean of the pool entries within R_PX of this
        # module's centre (the module itself included when it is in the pool); use
        # the frame-wide median when fewer than 3 entries qualify.
        loc = [m2 for c2, m2 in valid if np.hypot(*(c2 - cen)) <= R_PX]
        ref = float(np.median(loc)) if len(loc) >= 3 else gref
        # Note: a flagged module is warped and scored even if its own core mean is None.
        dts.append(warp_abs(temp, q) - ref)
        ids.append(mid)
        quads.append(q)
    if not ids:
        continue
    # Clip the reference-subtracted patches to [-CLIP, CLIP] and scale to [-1, 1].
    Xb = np.clip(np.stack(dts), -CLIP, CLIP) / CLIP
    with torch.no_grad():
        xb = torch.tensor(Xb, dtype=torch.float32).unsqueeze(1).to(DEV)
        sp = torch.softmax(screen(xb), 1)[:, 1].cpu().numpy()
    for k, mid in enumerate(ids):
        pr = float(sp[k])
        # `>=` lets a later frame (in sorted order) replace an earlier one on a tie.
        # NOTE(review): if the first probability stored for a module is NaN, no later
        # view can replace it (every comparison with NaN is False). Left unchanged on
        # the assumption that it mirrors build_dashboard_export.py -- confirm.
        if pr >= best_p[mid] or mid not in best_frame:
            best_p[mid] = pr
            best_frame[mid] = fr
            best_quad[mid] = quads[k]
    if fi % 500 == 0:
        print(f" frame {fi}/{len(frames)} {time.time()-t0:.0f}s", flush=True)
print(f"modules with a best frame: {len(best_frame)}")
if unreadable:
    print(f"unreadable frames skipped while scoring: {unreadable}",
          file=sys.stderr, flush=True)
# ---- render the deduped set of best-view frames ----
need = sorted(set(best_frame.values()))
print(f"unique best-view frames to render: {len(need)}")


def stem(fr):
    """Return frame name ``fr`` with its final extension removed."""
    return os.path.splitext(fr)[0]


# Height/width of the most recently loaded frame; both stay None if nothing loads.
FW = FH = None
not_rendered = []  # frames that could not be loaded or written
for j, fr in enumerate(need):
    try:
        temp = load_raster(fr)
    except Exception as exc:
        # Best-effort skip, but reported: the index still lists this frame's PNG.
        not_rendered.append(fr)
        print(f"  cannot load frame {fr}: {exc!r}", file=sys.stderr, flush=True)
        continue
    FH, FW = temp.shape
    # 16-bit encoding: value * 10, rounded, clamped to the uint16 range
    # (anything below 0 becomes 0).
    p16 = np.clip(np.round(temp * 10), 0, 65535).astype(np.uint16)
    png_path = f"{FRAMES}/{stem(fr)}.png"
    # cv2.imwrite reports a failed write by returning False, so check the result.
    if not cv2.imwrite(png_path, p16):
        not_rendered.append(fr)
        print(f"  cannot write {png_path}", file=sys.stderr, flush=True)
    if j % 500 == 0:
        print(f" render {j}/{len(need)}", flush=True)
if not_rendered:
    print(f"frames not rendered: {len(not_rendered)} of {len(need)}",
          file=sys.stderr, flush=True)
# ---- index ----
# One row per module that has a best frame, in ascending module-id order. The
# frame_w / frame_h columns carry the single FW / FH pair set by the render step,
# repeated on every row; the q* columns are up to four (x, y) corners rounded to 0.1.
index_header = ["module_id", "frame_file", "frame_w", "frame_h",
                "qx0", "qy0", "qx1", "qy1", "qx2", "qy2", "qx3", "qy3"]
with open(f"{EXP}/frames_index.csv", "w", newline="") as index_file:
    index_writer = csv.writer(index_file)
    index_writer.writerow(index_header)
    for mid in sorted(best_frame):
        corners = np.asarray(best_quad[mid], float).reshape(-1, 2)[:4]
        coords = [round(float(value), 1) for value in corners.ravel()]
        frame_file = f"frames/{stem(best_frame[mid])}.png"
        index_writer.writerow([mid, frame_file, FW, FH] + coords)
print("wrote frames_index.csv")
# ---- tar ----
# Pack the frames directory into frames.tar (mode "w": plain, uncompressed tar) under
# the archive name "frames", then report the on-disk sizes in MB.
tarp = f"{EXP}/frames.tar"
with tarfile.open(tarp, "w") as archive:
    archive.add(FRAMES, arcname="frames")

# Total size of the entries directly inside the frames directory.
dsz = 0
for entry in os.listdir(FRAMES):
    dsz += os.path.getsize(os.path.join(FRAMES, entry))

print(f"\nframes={len(need)} frames_dir={dsz/1e6:.1f}MB "
      f"frames.tar={os.path.getsize(tarp)/1e6:.1f}MB")