Add files using upload-large-folder tool
Browse files- analyze_d2_csv.py +239 -0
- cache_q_features.py +125 -0
- d2_basic.py +340 -0
- d2_llm_space.py +314 -0
- decoder_invariance_check.py +256 -0
- load_model.py +51 -672
- save_audio_feats.py +0 -1
- setup_simtoken.md +112 -45
- simtoken_experiment.md +369 -0
- target_frame_sweep.py +265 -0
- train.py +164 -7
- train_cached_gate.py +439 -0
analyze_d2_csv.py
ADDED
|
@@ -0,0 +1,239 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import csv
|
| 3 |
+
import math
|
| 4 |
+
from collections import defaultdict
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def parse_args():
|
| 10 |
+
parser = argparse.ArgumentParser(description="Analyze D2 frame-level CSV.")
|
| 11 |
+
parser.add_argument("--csv", required=True, help="Path to d2_llm_space.py or d2_basic.py CSV output.")
|
| 12 |
+
parser.add_argument("--beta", type=float, default=1.0)
|
| 13 |
+
parser.add_argument("--failure_iou", type=float, default=0.5)
|
| 14 |
+
parser.add_argument("--bottom_frac", type=float, default=0.2)
|
| 15 |
+
parser.add_argument("--pr_points", type=int, default=10)
|
| 16 |
+
return parser.parse_args()
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def read_rows(path, beta):
|
| 20 |
+
rows = []
|
| 21 |
+
with open(path, newline="") as f:
|
| 22 |
+
reader = csv.DictReader(f)
|
| 23 |
+
for row in reader:
|
| 24 |
+
row_beta = float(row["beta"])
|
| 25 |
+
if abs(row_beta - beta) > 1e-8:
|
| 26 |
+
continue
|
| 27 |
+
q_col = "h_type" if "h_type" in row else "q_type"
|
| 28 |
+
rows.append(
|
| 29 |
+
{
|
| 30 |
+
"sample_idx": int(row["sample_idx"]),
|
| 31 |
+
"frame": int(row["frame"]),
|
| 32 |
+
"anchor_type": row[q_col],
|
| 33 |
+
"s_pred": float(row["s_pred"]),
|
| 34 |
+
"s_gt": float(row["s_gt"]),
|
| 35 |
+
"frame_iou": float(row["frame_iou"]),
|
| 36 |
+
"iou_pred": float(row["iou_pred"]),
|
| 37 |
+
"pred_area": float(row["pred_area"]),
|
| 38 |
+
"gt_area": float(row["gt_area"]),
|
| 39 |
+
}
|
| 40 |
+
)
|
| 41 |
+
if not rows:
|
| 42 |
+
raise RuntimeError(f"No rows found for beta={beta} in {path}")
|
| 43 |
+
return rows
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def corr(x, y):
|
| 47 |
+
x = np.asarray(x, dtype=np.float64)
|
| 48 |
+
y = np.asarray(y, dtype=np.float64)
|
| 49 |
+
if len(x) < 2 or np.std(x) < 1e-12 or np.std(y) < 1e-12:
|
| 50 |
+
return float("nan")
|
| 51 |
+
return float(np.corrcoef(x, y)[0, 1])
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def residualize(y, controls):
|
| 55 |
+
y = np.asarray(y, dtype=np.float64)
|
| 56 |
+
cols = [np.ones(len(y), dtype=np.float64)]
|
| 57 |
+
for control in controls:
|
| 58 |
+
cols.append(np.asarray(control, dtype=np.float64))
|
| 59 |
+
x = np.stack(cols, axis=1)
|
| 60 |
+
coef, *_ = np.linalg.lstsq(x, y, rcond=None)
|
| 61 |
+
return y - x @ coef
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def r2_score(y, y_pred):
|
| 65 |
+
y = np.asarray(y, dtype=np.float64)
|
| 66 |
+
y_pred = np.asarray(y_pred, dtype=np.float64)
|
| 67 |
+
ss_res = np.sum((y - y_pred) ** 2)
|
| 68 |
+
ss_tot = np.sum((y - y.mean()) ** 2)
|
| 69 |
+
if ss_tot < 1e-12:
|
| 70 |
+
return float("nan")
|
| 71 |
+
return float(1.0 - ss_res / ss_tot)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def linear_r2(y, features):
|
| 75 |
+
y = np.asarray(y, dtype=np.float64)
|
| 76 |
+
cols = [np.ones(len(y), dtype=np.float64)]
|
| 77 |
+
for feature in features:
|
| 78 |
+
cols.append(np.asarray(feature, dtype=np.float64))
|
| 79 |
+
x = np.stack(cols, axis=1)
|
| 80 |
+
coef, *_ = np.linalg.lstsq(x, y, rcond=None)
|
| 81 |
+
return r2_score(y, x @ coef)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def real_rows(rows):
|
| 85 |
+
return [r for r in rows if r["anchor_type"] == "real"]
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def bottom_failure_enrichment(rows, failure_iou, bottom_frac):
|
| 89 |
+
rr = real_rows(rows)
|
| 90 |
+
n = len(rr)
|
| 91 |
+
k = max(1, int(round(n * bottom_frac)))
|
| 92 |
+
sorted_rows = sorted(rr, key=lambda r: r["s_pred"])
|
| 93 |
+
bottom = sorted_rows[:k]
|
| 94 |
+
baseline_rate = np.mean([r["frame_iou"] < failure_iou for r in rr])
|
| 95 |
+
bottom_rate = np.mean([r["frame_iou"] < failure_iou for r in bottom])
|
| 96 |
+
total_failures = sum(r["frame_iou"] < failure_iou for r in rr)
|
| 97 |
+
covered_failures = sum(r["frame_iou"] < failure_iou for r in bottom)
|
| 98 |
+
recall = covered_failures / max(total_failures, 1)
|
| 99 |
+
enrichment = bottom_rate / max(baseline_rate, 1e-12)
|
| 100 |
+
return {
|
| 101 |
+
"n": n,
|
| 102 |
+
"k": k,
|
| 103 |
+
"baseline_failure_rate": baseline_rate,
|
| 104 |
+
"bottom_failure_rate": bottom_rate,
|
| 105 |
+
"bottom_failure_recall": recall,
|
| 106 |
+
"enrichment": enrichment,
|
| 107 |
+
"total_failures": total_failures,
|
| 108 |
+
}
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
def pr_curve(rows, failure_iou, points):
|
| 112 |
+
rr = sorted(real_rows(rows), key=lambda r: r["s_pred"])
|
| 113 |
+
total_failures = sum(r["frame_iou"] < failure_iou for r in rr)
|
| 114 |
+
out = []
|
| 115 |
+
for frac in np.linspace(0.05, 1.0, points):
|
| 116 |
+
k = max(1, int(round(len(rr) * frac)))
|
| 117 |
+
selected = rr[:k]
|
| 118 |
+
failures = sum(r["frame_iou"] < failure_iou for r in selected)
|
| 119 |
+
precision = failures / k
|
| 120 |
+
recall = failures / max(total_failures, 1)
|
| 121 |
+
out.append((frac, precision, recall))
|
| 122 |
+
return out
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def margin_rows(rows):
|
| 126 |
+
grouped = defaultdict(dict)
|
| 127 |
+
for r in rows:
|
| 128 |
+
key = (r["sample_idx"], r["frame"])
|
| 129 |
+
grouped[key][r["anchor_type"]] = r
|
| 130 |
+
|
| 131 |
+
out = []
|
| 132 |
+
for key, group in grouped.items():
|
| 133 |
+
if "real" not in group:
|
| 134 |
+
continue
|
| 135 |
+
controls = [group[name]["s_pred"] for name in ("shuffled", "wrong_ref") if name in group]
|
| 136 |
+
if not controls:
|
| 137 |
+
continue
|
| 138 |
+
real = group["real"]
|
| 139 |
+
item = dict(real)
|
| 140 |
+
item["s_margin"] = real["s_pred"] - max(controls)
|
| 141 |
+
out.append(item)
|
| 142 |
+
return out
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def bottom_failure_enrichment_for_score(rows, score_key, failure_iou, bottom_frac):
|
| 146 |
+
n = len(rows)
|
| 147 |
+
k = max(1, int(round(n * bottom_frac)))
|
| 148 |
+
sorted_rows = sorted(rows, key=lambda r: r[score_key])
|
| 149 |
+
bottom = sorted_rows[:k]
|
| 150 |
+
baseline_rate = np.mean([r["frame_iou"] < failure_iou for r in rows])
|
| 151 |
+
bottom_rate = np.mean([r["frame_iou"] < failure_iou for r in bottom])
|
| 152 |
+
total_failures = sum(r["frame_iou"] < failure_iou for r in rows)
|
| 153 |
+
covered_failures = sum(r["frame_iou"] < failure_iou for r in bottom)
|
| 154 |
+
return {
|
| 155 |
+
"n": n,
|
| 156 |
+
"k": k,
|
| 157 |
+
"baseline_failure_rate": baseline_rate,
|
| 158 |
+
"bottom_failure_rate": bottom_rate,
|
| 159 |
+
"bottom_failure_recall": covered_failures / max(total_failures, 1),
|
| 160 |
+
"enrichment": bottom_rate / max(baseline_rate, 1e-12),
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def main():
|
| 165 |
+
args = parse_args()
|
| 166 |
+
rows = read_rows(args.csv, args.beta)
|
| 167 |
+
rr = real_rows(rows)
|
| 168 |
+
|
| 169 |
+
print(f"CSV: {args.csv}")
|
| 170 |
+
print(f"beta: {args.beta}")
|
| 171 |
+
print(f"real frames: {len(rr)}")
|
| 172 |
+
print(f"failure definition: frame_iou < {args.failure_iou}")
|
| 173 |
+
|
| 174 |
+
print("\nReal s_pred Correlations")
|
| 175 |
+
print(f"corr(s_pred, frame_iou): {corr([r['s_pred'] for r in rr], [r['frame_iou'] for r in rr]):+.4f}")
|
| 176 |
+
print(f"corr(s_pred, iou_pred): {corr([r['s_pred'] for r in rr], [r['iou_pred'] for r in rr]):+.4f}")
|
| 177 |
+
print(f"corr(s_pred, pred_area): {corr([r['s_pred'] for r in rr], [r['pred_area'] for r in rr]):+.4f}")
|
| 178 |
+
|
| 179 |
+
s_pred_values = [r["s_pred"] for r in rr]
|
| 180 |
+
frame_iou_values = [r["frame_iou"] for r in rr]
|
| 181 |
+
iou_pred_values = [r["iou_pred"] for r in rr]
|
| 182 |
+
pred_area_values = [r["pred_area"] for r in rr]
|
| 183 |
+
gt_area_values = [r["gt_area"] for r in rr]
|
| 184 |
+
partial_iou_pred = corr(
|
| 185 |
+
residualize(s_pred_values, [iou_pred_values]),
|
| 186 |
+
residualize(frame_iou_values, [iou_pred_values]),
|
| 187 |
+
)
|
| 188 |
+
partial_iou_area = corr(
|
| 189 |
+
residualize(s_pred_values, [iou_pred_values, pred_area_values]),
|
| 190 |
+
residualize(frame_iou_values, [iou_pred_values, pred_area_values]),
|
| 191 |
+
)
|
| 192 |
+
partial_iou_area_gt = corr(
|
| 193 |
+
residualize(s_pred_values, [iou_pred_values, pred_area_values, gt_area_values]),
|
| 194 |
+
residualize(frame_iou_values, [iou_pred_values, pred_area_values, gt_area_values]),
|
| 195 |
+
)
|
| 196 |
+
r2_iou_pred = linear_r2(frame_iou_values, [iou_pred_values])
|
| 197 |
+
r2_iou_pred_s = linear_r2(frame_iou_values, [iou_pred_values, s_pred_values])
|
| 198 |
+
r2_iou_pred_area = linear_r2(frame_iou_values, [iou_pred_values, pred_area_values])
|
| 199 |
+
r2_iou_pred_area_s = linear_r2(frame_iou_values, [iou_pred_values, pred_area_values, s_pred_values])
|
| 200 |
+
|
| 201 |
+
print("\nPartial Correlation / Residual Gain")
|
| 202 |
+
print(f"partial corr(s_pred, frame_iou | iou_pred): {partial_iou_pred:+.4f}")
|
| 203 |
+
print(f"partial corr(s_pred, frame_iou | iou_pred,pred_area): {partial_iou_area:+.4f}")
|
| 204 |
+
print(f"partial corr(s_pred, frame_iou | iou_pred,pred_area,gt_area): {partial_iou_area_gt:+.4f}")
|
| 205 |
+
print(f"R2 frame_iou ~ iou_pred: {r2_iou_pred:.4f}")
|
| 206 |
+
print(f"R2 frame_iou ~ iou_pred + s_pred: {r2_iou_pred_s:.4f} (gain {r2_iou_pred_s - r2_iou_pred:+.4f})")
|
| 207 |
+
print(f"R2 frame_iou ~ iou_pred + pred_area: {r2_iou_pred_area:.4f}")
|
| 208 |
+
print(f"R2 frame_iou ~ iou_pred + pred_area + s_pred: {r2_iou_pred_area_s:.4f} (gain {r2_iou_pred_area_s - r2_iou_pred_area:+.4f})")
|
| 209 |
+
|
| 210 |
+
stats = bottom_failure_enrichment(rows, args.failure_iou, args.bottom_frac)
|
| 211 |
+
print("\nBottom-k Failure Enrichment")
|
| 212 |
+
print(f"bottom_frac: {args.bottom_frac:.2f} ({stats['k']}/{stats['n']} frames)")
|
| 213 |
+
print(f"total failures: {stats['total_failures']}")
|
| 214 |
+
print(f"random/baseline failure rate: {stats['baseline_failure_rate']:.4f}")
|
| 215 |
+
print(f"bottom-s_pred failure rate: {stats['bottom_failure_rate']:.4f}")
|
| 216 |
+
print(f"bottom-s_pred failure recall: {stats['bottom_failure_recall']:.4f}")
|
| 217 |
+
print(f"enrichment: {stats['enrichment']:.2f}x")
|
| 218 |
+
|
| 219 |
+
print("\nPR Curve Summary")
|
| 220 |
+
print("selected_frac | precision | recall")
|
| 221 |
+
for frac, precision, recall in pr_curve(rows, args.failure_iou, args.pr_points):
|
| 222 |
+
print(f"{frac:.2f} | {precision:.4f} | {recall:.4f}")
|
| 223 |
+
|
| 224 |
+
mr = margin_rows(rows)
|
| 225 |
+
if mr:
|
| 226 |
+
print("\nOffline Margin-D2")
|
| 227 |
+
print(f"margin frames: {len(mr)}")
|
| 228 |
+
print(f"corr(s_margin, frame_iou): {corr([r['s_margin'] for r in mr], [r['frame_iou'] for r in mr]):+.4f}")
|
| 229 |
+
print(f"corr(s_margin, pred_area): {corr([r['s_margin'] for r in mr], [r['pred_area'] for r in mr]):+.4f}")
|
| 230 |
+
mstats = bottom_failure_enrichment_for_score(mr, "s_margin", args.failure_iou, args.bottom_frac)
|
| 231 |
+
print(f"bottom-s_margin failure rate: {mstats['bottom_failure_rate']:.4f}")
|
| 232 |
+
print(f"bottom-s_margin failure recall: {mstats['bottom_failure_recall']:.4f}")
|
| 233 |
+
print(f"margin enrichment: {mstats['enrichment']:.2f}x")
|
| 234 |
+
else:
|
| 235 |
+
print("\nOffline Margin-D2 skipped: shuffled/wrong_ref controls not available.")
|
| 236 |
+
|
| 237 |
+
|
| 238 |
+
if __name__ == "__main__":
|
| 239 |
+
main()
|
cache_q_features.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
from functools import partial
|
| 4 |
+
from itertools import islice
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import transformers
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
|
| 11 |
+
from configs import args
|
| 12 |
+
from datasets import REFAVS
|
| 13 |
+
from decoder_invariance_check import build_model, set_seed
|
| 14 |
+
from load_model import collate_fn, dict_to_cuda
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _jsonable_size(size):
|
| 18 |
+
if isinstance(size, torch.Tensor):
|
| 19 |
+
return [int(x) for x in size.detach().cpu().tolist()]
|
| 20 |
+
return [int(x) for x in size]
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def main():
|
| 24 |
+
set_seed(42)
|
| 25 |
+
torch.set_grad_enabled(False)
|
| 26 |
+
|
| 27 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 28 |
+
args.mllm,
|
| 29 |
+
cache_dir=None,
|
| 30 |
+
model_max_length=2048,
|
| 31 |
+
padding_side="right",
|
| 32 |
+
use_fast=False,
|
| 33 |
+
)
|
| 34 |
+
tokenizer.pad_token = tokenizer.unk_token
|
| 35 |
+
tokenizer.add_tokens("[SEG]")
|
| 36 |
+
seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
|
| 37 |
+
|
| 38 |
+
dataset = REFAVS(args.cache_split, args, tokenizer, input_type="refer")
|
| 39 |
+
loader = DataLoader(
|
| 40 |
+
dataset,
|
| 41 |
+
batch_size=1,
|
| 42 |
+
shuffle=False,
|
| 43 |
+
num_workers=0,
|
| 44 |
+
collate_fn=partial(collate_fn, tokenizer=tokenizer),
|
| 45 |
+
)
|
| 46 |
+
|
| 47 |
+
split_root = os.path.join(args.cache_root, args.cache_split)
|
| 48 |
+
os.makedirs(split_root, exist_ok=True)
|
| 49 |
+
index_path = os.path.join(split_root, "index.jsonl")
|
| 50 |
+
if os.path.exists(index_path) and not args.overwrite_cache:
|
| 51 |
+
raise FileExistsError(
|
| 52 |
+
f"{index_path} already exists. Pass --overwrite_cache to rebuild it."
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
limit = args.max_eval_rows if args.max_eval_rows > 0 else len(dataset)
|
| 56 |
+
print(f"cache split={args.cache_split} | samples={min(limit, len(dataset))}")
|
| 57 |
+
print(f"cache root: {split_root}")
|
| 58 |
+
|
| 59 |
+
model = build_model(tokenizer, seg_token_idx)
|
| 60 |
+
model.eval()
|
| 61 |
+
|
| 62 |
+
rows = []
|
| 63 |
+
for sample_idx, batch in enumerate(
|
| 64 |
+
tqdm(islice(loader, limit), total=min(limit, len(dataset)), desc=f"Caching {args.cache_split}")
|
| 65 |
+
):
|
| 66 |
+
batch = dict_to_cuda(batch)
|
| 67 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 68 |
+
output = model.forward(
|
| 69 |
+
images=batch["images"],
|
| 70 |
+
images_clip=batch["images_clip"],
|
| 71 |
+
audio_features=batch["audio_feats"],
|
| 72 |
+
image_features=batch["image_feats"],
|
| 73 |
+
input_ids=batch["input_ids"],
|
| 74 |
+
labels=batch["labels"],
|
| 75 |
+
attention_masks=batch["attention_masks"],
|
| 76 |
+
masks_list=batch["masks"],
|
| 77 |
+
resize_list=batch["resizes"],
|
| 78 |
+
orgsize_list=batch["orgsizes"],
|
| 79 |
+
conversation_list=batch["convs"],
|
| 80 |
+
refs_num=batch["refs_num"],
|
| 81 |
+
fids=batch["fids"],
|
| 82 |
+
vids=batch["vids"],
|
| 83 |
+
contrast=args.ct_weight,
|
| 84 |
+
ref_ids=batch["ref_ids"],
|
| 85 |
+
inference=True,
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
cache_name = f"{sample_idx:06d}.pt"
|
| 89 |
+
cache_path = os.path.join(split_root, cache_name)
|
| 90 |
+
item = {
|
| 91 |
+
"sample_idx": sample_idx,
|
| 92 |
+
"vid": batch["vids"][0],
|
| 93 |
+
"refs": batch["refs"][0],
|
| 94 |
+
"fids": [int(x) for x in batch["fids"][0]],
|
| 95 |
+
"resize": _jsonable_size(batch["resizes"][0]),
|
| 96 |
+
"orgsize": _jsonable_size(batch["orgsizes"][0]),
|
| 97 |
+
"q": output["seg_embeddings"][0].detach().cpu().float(),
|
| 98 |
+
}
|
| 99 |
+
torch.save(item, cache_path)
|
| 100 |
+
rows.append(
|
| 101 |
+
{
|
| 102 |
+
"sample_idx": sample_idx,
|
| 103 |
+
"path": cache_name,
|
| 104 |
+
"vid": item["vid"],
|
| 105 |
+
"refs": item["refs"],
|
| 106 |
+
"fids": item["fids"],
|
| 107 |
+
"resize": item["resize"],
|
| 108 |
+
"orgsize": item["orgsize"],
|
| 109 |
+
"num_seg": int(item["q"].shape[0]),
|
| 110 |
+
}
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
if not rows:
|
| 114 |
+
raise RuntimeError("No samples were cached.")
|
| 115 |
+
|
| 116 |
+
with open(index_path, "w") as f:
|
| 117 |
+
for row in rows:
|
| 118 |
+
f.write(json.dumps(row) + "\n")
|
| 119 |
+
|
| 120 |
+
print(f"cached samples: {len(rows)}")
|
| 121 |
+
print(f"saved index: {index_path}")
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
if __name__ == "__main__":
|
| 125 |
+
main()
|
d2_basic.py
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
from functools import partial
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import transformers
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
|
| 12 |
+
from configs import args
|
| 13 |
+
from datasets import REFAVS
|
| 14 |
+
from decoder_invariance_check import build_model, set_seed
|
| 15 |
+
from load_model import collate_fn, dict_to_cuda
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def make_loader(tokenizer):
|
| 19 |
+
dataset = REFAVS(args.eval_split, args, tokenizer, input_type="refer")
|
| 20 |
+
return DataLoader(
|
| 21 |
+
dataset,
|
| 22 |
+
batch_size=1,
|
| 23 |
+
shuffle=False,
|
| 24 |
+
num_workers=0,
|
| 25 |
+
collate_fn=partial(collate_fn, tokenizer=tokenizer),
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def build_tokenizer():
|
| 30 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 31 |
+
args.mllm,
|
| 32 |
+
cache_dir=None,
|
| 33 |
+
model_max_length=2048,
|
| 34 |
+
padding_side="right",
|
| 35 |
+
use_fast=False,
|
| 36 |
+
)
|
| 37 |
+
tokenizer.pad_token = tokenizer.unk_token
|
| 38 |
+
tokenizer.add_tokens("[SEG]")
|
| 39 |
+
seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
|
| 40 |
+
return tokenizer, seg_token_idx
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def get_q(model, batch):
|
| 44 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 45 |
+
output = model.forward(
|
| 46 |
+
images=batch["images"],
|
| 47 |
+
images_clip=batch["images_clip"],
|
| 48 |
+
audio_features=batch["audio_feats"],
|
| 49 |
+
image_features=batch["image_feats"],
|
| 50 |
+
input_ids=batch["input_ids"],
|
| 51 |
+
labels=batch["labels"],
|
| 52 |
+
attention_masks=batch["attention_masks"],
|
| 53 |
+
masks_list=batch["masks"],
|
| 54 |
+
resize_list=batch["resizes"],
|
| 55 |
+
orgsize_list=batch["orgsizes"],
|
| 56 |
+
conversation_list=batch["convs"],
|
| 57 |
+
refs_num=batch["refs_num"],
|
| 58 |
+
fids=batch["fids"],
|
| 59 |
+
vids=batch["vids"],
|
| 60 |
+
contrast=args.ct_weight,
|
| 61 |
+
ref_ids=batch["ref_ids"],
|
| 62 |
+
inference=True,
|
| 63 |
+
)
|
| 64 |
+
return output["seg_embeddings"][0][0].float()
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def decode_low_res(model, batch, q):
|
| 68 |
+
visual_model = model.get_model().visual_model
|
| 69 |
+
sparse, dense = visual_model.prompt_encoder(
|
| 70 |
+
points=None,
|
| 71 |
+
boxes=None,
|
| 72 |
+
masks=None,
|
| 73 |
+
text_embeds=q.view(1, 1, -1).to(next(visual_model.parameters()).dtype),
|
| 74 |
+
)
|
| 75 |
+
sparse = sparse.to(q.dtype)
|
| 76 |
+
dense = dense.to(q.dtype)
|
| 77 |
+
|
| 78 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 79 |
+
low_res_masks, iou_predictions = visual_model.mask_decoder(
|
| 80 |
+
image_embeddings=batch["image_feats"][0],
|
| 81 |
+
image_pe=visual_model.prompt_encoder.get_dense_pe(),
|
| 82 |
+
sparse_prompt_embeddings=sparse,
|
| 83 |
+
dense_prompt_embeddings=dense,
|
| 84 |
+
multimask_output=False,
|
| 85 |
+
)
|
| 86 |
+
return low_res_masks.float(), iou_predictions.float().squeeze(-1)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def masks_to_64(mask_logits_or_binary):
|
| 90 |
+
if mask_logits_or_binary.ndim == 3:
|
| 91 |
+
mask_logits_or_binary = mask_logits_or_binary.unsqueeze(1)
|
| 92 |
+
return F.interpolate(
|
| 93 |
+
mask_logits_or_binary.float(),
|
| 94 |
+
size=(64, 64),
|
| 95 |
+
mode="bilinear",
|
| 96 |
+
align_corners=False,
|
| 97 |
+
).clamp(0.0, 1.0)
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def d2_scores(image_embeddings, mask64, q, beta):
|
| 101 |
+
feats = image_embeddings.float()
|
| 102 |
+
if mask64.shape[0] != feats.shape[0]:
|
| 103 |
+
raise ValueError(f"Mask/frame mismatch: {mask64.shape} vs {feats.shape}")
|
| 104 |
+
|
| 105 |
+
q = F.normalize(q.float().view(1, -1), dim=-1)
|
| 106 |
+
mask = mask64.float()
|
| 107 |
+
comp = 1.0 - mask
|
| 108 |
+
|
| 109 |
+
z_in = (feats * mask).sum(dim=(2, 3)) / mask.sum(dim=(2, 3)).clamp_min(1e-6)
|
| 110 |
+
z_out = (feats * comp).sum(dim=(2, 3)) / comp.sum(dim=(2, 3)).clamp_min(1e-6)
|
| 111 |
+
|
| 112 |
+
z_in = F.normalize(z_in, dim=-1)
|
| 113 |
+
z_out = F.normalize(z_out, dim=-1)
|
| 114 |
+
return (z_in @ q.T).squeeze(-1) - beta * (z_out @ q.T).squeeze(-1)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def frame_iou(pred_logits, gt_masks):
|
| 118 |
+
pred = (torch.sigmoid(pred_logits.float()) > 0.4).float()
|
| 119 |
+
gt = gt_masks.float()
|
| 120 |
+
if pred.ndim == 4:
|
| 121 |
+
pred = pred.squeeze(1)
|
| 122 |
+
inter = (pred * gt).sum(dim=(1, 2))
|
| 123 |
+
union = torch.maximum(pred, gt).sum(dim=(1, 2))
|
| 124 |
+
num_pixels = pred.shape[-1] * pred.shape[-2]
|
| 125 |
+
no_obj = gt.sum(dim=(1, 2)) == 0
|
| 126 |
+
inter_no_obj = ((1.0 - pred) * (1.0 - gt)).sum(dim=(1, 2))
|
| 127 |
+
inter = torch.where(no_obj, inter_no_obj, inter)
|
| 128 |
+
union = torch.where(no_obj, torch.full_like(union, float(num_pixels)), union)
|
| 129 |
+
return inter / union.clamp_min(1e-7)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def frame_fscore_proxy(pred_logits, gt_masks):
|
| 133 |
+
pred = (torch.sigmoid(pred_logits.float()) > 0.4).float()
|
| 134 |
+
gt = gt_masks.float()
|
| 135 |
+
if pred.ndim == 4:
|
| 136 |
+
pred = pred.squeeze(1)
|
| 137 |
+
tp = (pred * gt).sum(dim=(1, 2))
|
| 138 |
+
precision = tp / pred.sum(dim=(1, 2)).clamp_min(1e-7)
|
| 139 |
+
recall = tp / gt.sum(dim=(1, 2)).clamp_min(1e-7)
|
| 140 |
+
beta2 = 0.3
|
| 141 |
+
fscore = (1 + beta2) * precision * recall / (beta2 * precision + recall).clamp_min(1e-7)
|
| 142 |
+
no_obj = gt.sum(dim=(1, 2)) == 0
|
| 143 |
+
return torch.where(no_obj, torch.zeros_like(fscore), fscore)
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def parse_betas():
|
| 147 |
+
raw = os.environ.get("D2_BETAS", "0.5")
|
| 148 |
+
return [float(x.strip()) for x in raw.split(",") if x.strip()]
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def collect_q_pool(model, tokenizer, limit):
|
| 152 |
+
q_pool = []
|
| 153 |
+
loader = make_loader(tokenizer)
|
| 154 |
+
for sample_idx, batch in enumerate(loader):
|
| 155 |
+
if sample_idx >= limit:
|
| 156 |
+
break
|
| 157 |
+
batch = dict_to_cuda(batch)
|
| 158 |
+
q = get_q(model, batch)
|
| 159 |
+
q_pool.append(
|
| 160 |
+
{
|
| 161 |
+
"sample_idx": sample_idx,
|
| 162 |
+
"vid": batch["vids"][0],
|
| 163 |
+
"ref": batch["refs"][0][0],
|
| 164 |
+
"fid": int(batch["fids"][0][0]),
|
| 165 |
+
"q": q.cpu(),
|
| 166 |
+
}
|
| 167 |
+
)
|
| 168 |
+
print(f"Collected q {sample_idx}: vid={q_pool[-1]['vid']} ref={q_pool[-1]['ref']}")
|
| 169 |
+
if not q_pool:
|
| 170 |
+
raise RuntimeError("No q vectors collected. Is the selected split empty?")
|
| 171 |
+
return q_pool
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def choose_shuffled_idx(sample_idx, q_pool):
|
| 175 |
+
if len(q_pool) <= 1:
|
| 176 |
+
return None
|
| 177 |
+
return (sample_idx + 1) % len(q_pool)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
def choose_wrong_ref_idx(sample_idx, q_pool):
|
| 181 |
+
current = q_pool[sample_idx]
|
| 182 |
+
for item in q_pool:
|
| 183 |
+
if item["sample_idx"] == sample_idx:
|
| 184 |
+
continue
|
| 185 |
+
if item["vid"] == current["vid"] and item["fid"] != current["fid"]:
|
| 186 |
+
return item["sample_idx"]
|
| 187 |
+
for item in q_pool:
|
| 188 |
+
if item["sample_idx"] == sample_idx:
|
| 189 |
+
continue
|
| 190 |
+
if item["vid"] == current["vid"] and item["ref"] != current["ref"]:
|
| 191 |
+
return item["sample_idx"]
|
| 192 |
+
return None
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def run_d2(model, tokenizer, q_pool, betas, limit):
|
| 196 |
+
rows = []
|
| 197 |
+
loader = make_loader(tokenizer)
|
| 198 |
+
q_lookup = {item["sample_idx"]: item for item in q_pool}
|
| 199 |
+
generator = torch.Generator(device="cuda")
|
| 200 |
+
generator.manual_seed(1234)
|
| 201 |
+
|
| 202 |
+
for sample_idx, batch in enumerate(loader):
|
| 203 |
+
if sample_idx >= limit:
|
| 204 |
+
break
|
| 205 |
+
batch = dict_to_cuda(batch)
|
| 206 |
+
item = q_lookup[sample_idx]
|
| 207 |
+
real_q = item["q"].cuda()
|
| 208 |
+
|
| 209 |
+
low_res_masks, iou_predictions = decode_low_res(model, batch, real_q)
|
| 210 |
+
pred_mask64 = masks_to_64(torch.sigmoid(low_res_masks))
|
| 211 |
+
gt_masks = batch["masks"][0][0].float()
|
| 212 |
+
gt_mask64 = masks_to_64(gt_masks)
|
| 213 |
+
image_embeddings = batch["image_feats"][0].float()
|
| 214 |
+
|
| 215 |
+
pred_logits_hr = model.get_model().visual_model.postprocess_masks(
|
| 216 |
+
low_res_masks.to(batch["image_feats"][0].dtype),
|
| 217 |
+
input_size=batch["resizes"][0],
|
| 218 |
+
original_size=batch["orgsizes"][0],
|
| 219 |
+
).squeeze(1)
|
| 220 |
+
|
| 221 |
+
frame_ious = frame_iou(pred_logits_hr, gt_masks)
|
| 222 |
+
frame_fscores = frame_fscore_proxy(pred_logits_hr, gt_masks)
|
| 223 |
+
pred_area = (torch.sigmoid(pred_logits_hr.float()) > 0.4).float().mean(dim=(1, 2))
|
| 224 |
+
gt_area = gt_masks.float().mean(dim=(1, 2))
|
| 225 |
+
|
| 226 |
+
shuffled_idx = choose_shuffled_idx(sample_idx, q_pool)
|
| 227 |
+
wrong_ref_idx = choose_wrong_ref_idx(sample_idx, q_pool)
|
| 228 |
+
q_controls = [
|
| 229 |
+
("real", real_q, sample_idx),
|
| 230 |
+
("random", torch.randn(real_q.shape, device=real_q.device, generator=generator), None),
|
| 231 |
+
]
|
| 232 |
+
if shuffled_idx is not None:
|
| 233 |
+
q_controls.append(("shuffled", q_lookup[shuffled_idx]["q"].cuda(), shuffled_idx))
|
| 234 |
+
if wrong_ref_idx is not None:
|
| 235 |
+
q_controls.append(("wrong_ref", q_lookup[wrong_ref_idx]["q"].cuda(), wrong_ref_idx))
|
| 236 |
+
|
| 237 |
+
for beta in betas:
|
| 238 |
+
for q_type, q, q_source_idx in q_controls:
|
| 239 |
+
pred_scores = d2_scores(image_embeddings, pred_mask64, q, beta)
|
| 240 |
+
gt_scores = d2_scores(image_embeddings, gt_mask64, q, beta)
|
| 241 |
+
base_info = {
|
| 242 |
+
"sample_idx": sample_idx,
|
| 243 |
+
"vid": item["vid"],
|
| 244 |
+
"ref": item["ref"],
|
| 245 |
+
"fid": item["fid"],
|
| 246 |
+
"split": args.eval_split,
|
| 247 |
+
"frame_iou": math.nan,
|
| 248 |
+
"frame_fscore_proxy": math.nan,
|
| 249 |
+
"iou_pred": math.nan,
|
| 250 |
+
"pred_area": math.nan,
|
| 251 |
+
"gt_area": math.nan,
|
| 252 |
+
}
|
| 253 |
+
for frame_idx in range(pred_scores.shape[0]):
|
| 254 |
+
base_info_frame = dict(base_info)
|
| 255 |
+
base_info_frame.update(
|
| 256 |
+
{
|
| 257 |
+
"frame_iou": frame_ious[frame_idx].item(),
|
| 258 |
+
"frame_fscore_proxy": frame_fscores[frame_idx].item(),
|
| 259 |
+
"iou_pred": iou_predictions[frame_idx].item(),
|
| 260 |
+
"pred_area": pred_area[frame_idx].item(),
|
| 261 |
+
"gt_area": gt_area[frame_idx].item(),
|
| 262 |
+
}
|
| 263 |
+
)
|
| 264 |
+
row = dict(base_info_frame)
|
| 265 |
+
row.update(
|
| 266 |
+
{
|
| 267 |
+
"frame": frame_idx,
|
| 268 |
+
"q_type": q_type,
|
| 269 |
+
"beta": beta,
|
| 270 |
+
"s_pred": pred_scores[frame_idx].item(),
|
| 271 |
+
"s_gt": gt_scores[frame_idx].item(),
|
| 272 |
+
"q_source_idx": q_source_idx if q_source_idx is not None else "",
|
| 273 |
+
}
|
| 274 |
+
)
|
| 275 |
+
rows.append(row)
|
| 276 |
+
|
| 277 |
+
real_rows = [
|
| 278 |
+
r for r in rows if r["sample_idx"] == sample_idx and r["q_type"] == "real" and r["beta"] == betas[0]
|
| 279 |
+
]
|
| 280 |
+
s_pred_values = [r["s_pred"] for r in real_rows]
|
| 281 |
+
print(
|
| 282 |
+
f"D2 {sample_idx}: vid={item['vid']} ref={item['ref']} "
|
| 283 |
+
f"mean_s_pred={np.mean(s_pred_values):.4f} min_s_pred={np.min(s_pred_values):.4f} "
|
| 284 |
+
f"mean_iou={frame_ious.mean().item():.4f}"
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
return rows
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def print_summary(rows):
|
| 291 |
+
real_rows = [r for r in rows if r["q_type"] == "real"]
|
| 292 |
+
if not real_rows:
|
| 293 |
+
return
|
| 294 |
+
by_beta = sorted(set(r["beta"] for r in real_rows))
|
| 295 |
+
print("\nSummary")
|
| 296 |
+
print(f"rows: {len(rows)}")
|
| 297 |
+
for beta in by_beta:
|
| 298 |
+
beta_rows = [r for r in rows if r["beta"] == beta]
|
| 299 |
+
print(f"\nbeta={beta}")
|
| 300 |
+
for q_type in sorted(set(r["q_type"] for r in beta_rows)):
|
| 301 |
+
qr = [r for r in beta_rows if r["q_type"] == q_type]
|
| 302 |
+
print(
|
| 303 |
+
f"{q_type:10s} "
|
| 304 |
+
f"mean_s_pred={np.mean([r['s_pred'] for r in qr]):+.4f} "
|
| 305 |
+
f"mean_s_gt={np.mean([r['s_gt'] for r in qr]):+.4f}"
|
| 306 |
+
)
|
| 307 |
+
real_beta = [r for r in beta_rows if r["q_type"] == "real"]
|
| 308 |
+
s_pred = np.array([r["s_pred"] for r in real_beta])
|
| 309 |
+
frame_iou_values = np.array([r["frame_iou"] for r in real_beta])
|
| 310 |
+
if len(s_pred) > 1 and np.std(s_pred) > 1e-8 and np.std(frame_iou_values) > 1e-8:
|
| 311 |
+
corr = np.corrcoef(s_pred, frame_iou_values)[0, 1]
|
| 312 |
+
print(f"corr(real s_pred, frame_iou)={corr:+.4f}")
|
| 313 |
+
else:
|
| 314 |
+
print("corr(real s_pred, frame_iou)=nan")
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
def main():
|
| 318 |
+
set_seed(42)
|
| 319 |
+
torch.set_grad_enabled(False)
|
| 320 |
+
betas = parse_betas()
|
| 321 |
+
tokenizer, seg_token_idx = build_tokenizer()
|
| 322 |
+
limit = args.max_eval_rows if args.max_eval_rows > 0 else 30
|
| 323 |
+
print(f"Split: {args.eval_split} | samples: {limit} | betas: {betas}")
|
| 324 |
+
|
| 325 |
+
model = build_model(tokenizer, seg_token_idx)
|
| 326 |
+
q_pool = collect_q_pool(model, tokenizer, limit)
|
| 327 |
+
rows = run_d2(model, tokenizer, q_pool, betas, limit)
|
| 328 |
+
print_summary(rows)
|
| 329 |
+
|
| 330 |
+
csv_path = os.environ.get("D2_BASIC_CSV", f"/workspace/SimToken/d2_basic_{args.eval_split}_{limit}.csv")
|
| 331 |
+
os.makedirs(os.path.dirname(os.path.abspath(csv_path)), exist_ok=True)
|
| 332 |
+
with open(csv_path, "w", newline="") as f:
|
| 333 |
+
writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
|
| 334 |
+
writer.writeheader()
|
| 335 |
+
writer.writerows(rows)
|
| 336 |
+
print(f"\nSaved CSV: {csv_path}")
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
if __name__ == "__main__":
|
| 340 |
+
main()
|
d2_llm_space.py
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
from functools import partial
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import transformers
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
|
| 12 |
+
from configs import args
|
| 13 |
+
from datasets import REFAVS
|
| 14 |
+
from decoder_invariance_check import build_model, set_seed
|
| 15 |
+
from d2_basic import frame_fscore_proxy, frame_iou
|
| 16 |
+
from load_model import collate_fn, dict_to_cuda
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def build_tokenizer():
|
| 20 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 21 |
+
args.mllm,
|
| 22 |
+
cache_dir=None,
|
| 23 |
+
model_max_length=2048,
|
| 24 |
+
padding_side="right",
|
| 25 |
+
use_fast=False,
|
| 26 |
+
)
|
| 27 |
+
tokenizer.pad_token = tokenizer.unk_token
|
| 28 |
+
tokenizer.add_tokens("[SEG]")
|
| 29 |
+
seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
|
| 30 |
+
return tokenizer, seg_token_idx
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def make_loader(tokenizer):
|
| 34 |
+
dataset = REFAVS(args.eval_split, args, tokenizer, input_type="refer")
|
| 35 |
+
return DataLoader(
|
| 36 |
+
dataset,
|
| 37 |
+
batch_size=1,
|
| 38 |
+
shuffle=False,
|
| 39 |
+
num_workers=0,
|
| 40 |
+
collate_fn=partial(collate_fn, tokenizer=tokenizer),
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def forward_for_hidden_and_q(model, batch):
|
| 45 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 46 |
+
output = model.forward(
|
| 47 |
+
images=batch["images"],
|
| 48 |
+
images_clip=batch["images_clip"],
|
| 49 |
+
audio_features=batch["audio_feats"],
|
| 50 |
+
image_features=batch["image_feats"],
|
| 51 |
+
input_ids=batch["input_ids"],
|
| 52 |
+
labels=batch["labels"],
|
| 53 |
+
attention_masks=batch["attention_masks"],
|
| 54 |
+
masks_list=batch["masks"],
|
| 55 |
+
resize_list=batch["resizes"],
|
| 56 |
+
orgsize_list=batch["orgsizes"],
|
| 57 |
+
conversation_list=batch["convs"],
|
| 58 |
+
refs_num=batch["refs_num"],
|
| 59 |
+
fids=batch["fids"],
|
| 60 |
+
vids=batch["vids"],
|
| 61 |
+
contrast=args.ct_weight,
|
| 62 |
+
ref_ids=batch["ref_ids"],
|
| 63 |
+
inference=True,
|
| 64 |
+
)
|
| 65 |
+
h_seg = output["seg_hidden_states"][0][0].float()
|
| 66 |
+
q = output["seg_embeddings"][0][0].float()
|
| 67 |
+
return h_seg, q
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def decode_low_res(model, batch, q):
|
| 71 |
+
visual_model = model.get_model().visual_model
|
| 72 |
+
sparse, dense = visual_model.prompt_encoder(
|
| 73 |
+
points=None,
|
| 74 |
+
boxes=None,
|
| 75 |
+
masks=None,
|
| 76 |
+
text_embeds=q.view(1, 1, -1).to(next(visual_model.parameters()).dtype),
|
| 77 |
+
)
|
| 78 |
+
sparse = sparse.to(q.dtype)
|
| 79 |
+
dense = dense.to(q.dtype)
|
| 80 |
+
|
| 81 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 82 |
+
low_res_masks, iou_predictions = visual_model.mask_decoder(
|
| 83 |
+
image_embeddings=batch["image_feats"][0],
|
| 84 |
+
image_pe=visual_model.prompt_encoder.get_dense_pe(),
|
| 85 |
+
sparse_prompt_embeddings=sparse,
|
| 86 |
+
dense_prompt_embeddings=dense,
|
| 87 |
+
multimask_output=False,
|
| 88 |
+
)
|
| 89 |
+
return low_res_masks.float(), iou_predictions.float().squeeze(-1)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def clip_projected_tokens(model, batch):
|
| 93 |
+
images = torch.cat(batch["images_clip"], dim=0)
|
| 94 |
+
with torch.no_grad():
|
| 95 |
+
clip_tokens = model.encode_images(images)
|
| 96 |
+
projector = model.get_model().mm_projector
|
| 97 |
+
clip_tokens = clip_tokens.to(projector.weight.dtype)
|
| 98 |
+
llm_tokens = projector(clip_tokens).float()
|
| 99 |
+
return llm_tokens
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def infer_square_grid(num_tokens):
|
| 103 |
+
grid = int(math.sqrt(num_tokens))
|
| 104 |
+
if grid * grid != num_tokens:
|
| 105 |
+
raise ValueError(f"Expected square patch-token grid, got {num_tokens} tokens")
|
| 106 |
+
return grid
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def masks_to_token_grid(mask_logits_or_binary, num_tokens):
|
| 110 |
+
if mask_logits_or_binary.ndim == 3:
|
| 111 |
+
mask_logits_or_binary = mask_logits_or_binary.unsqueeze(1)
|
| 112 |
+
grid = infer_square_grid(num_tokens)
|
| 113 |
+
return F.interpolate(
|
| 114 |
+
mask_logits_or_binary.float(),
|
| 115 |
+
size=(grid, grid),
|
| 116 |
+
mode="bilinear",
|
| 117 |
+
align_corners=False,
|
| 118 |
+
).flatten(2).transpose(1, 2).clamp(0.0, 1.0)
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def d2_scores_llm(llm_tokens, mask_tokens, h_seg, beta):
|
| 122 |
+
if llm_tokens.shape[:2] != mask_tokens.shape[:2]:
|
| 123 |
+
raise ValueError(f"Token/mask mismatch: {llm_tokens.shape} vs {mask_tokens.shape}")
|
| 124 |
+
h = F.normalize(h_seg.float().view(1, -1), dim=-1)
|
| 125 |
+
tokens = llm_tokens.float()
|
| 126 |
+
mask = mask_tokens.float()
|
| 127 |
+
comp = 1.0 - mask
|
| 128 |
+
|
| 129 |
+
z_in = (tokens * mask).sum(dim=1) / mask.sum(dim=1).clamp_min(1e-6)
|
| 130 |
+
z_out = (tokens * comp).sum(dim=1) / comp.sum(dim=1).clamp_min(1e-6)
|
| 131 |
+
|
| 132 |
+
z_in = F.normalize(z_in, dim=-1)
|
| 133 |
+
z_out = F.normalize(z_out, dim=-1)
|
| 134 |
+
return (z_in @ h.T).squeeze(-1) - beta * (z_out @ h.T).squeeze(-1)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def parse_betas():
|
| 138 |
+
raw = os.environ.get("D2_BETAS", "0.5")
|
| 139 |
+
return [float(x.strip()) for x in raw.split(",") if x.strip()]
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def collect_hidden_pool(model, tokenizer, limit):
|
| 143 |
+
pool = []
|
| 144 |
+
loader = make_loader(tokenizer)
|
| 145 |
+
for sample_idx, batch in enumerate(loader):
|
| 146 |
+
if sample_idx >= limit:
|
| 147 |
+
break
|
| 148 |
+
batch = dict_to_cuda(batch)
|
| 149 |
+
h_seg, q = forward_for_hidden_and_q(model, batch)
|
| 150 |
+
pool.append(
|
| 151 |
+
{
|
| 152 |
+
"sample_idx": sample_idx,
|
| 153 |
+
"vid": batch["vids"][0],
|
| 154 |
+
"ref": batch["refs"][0][0],
|
| 155 |
+
"fid": int(batch["fids"][0][0]),
|
| 156 |
+
"h": h_seg.cpu(),
|
| 157 |
+
"q": q.cpu(),
|
| 158 |
+
}
|
| 159 |
+
)
|
| 160 |
+
print(f"Collected h {sample_idx}: vid={pool[-1]['vid']} ref={pool[-1]['ref']}")
|
| 161 |
+
if not pool:
|
| 162 |
+
raise RuntimeError("No hidden states collected. Is the selected split empty?")
|
| 163 |
+
return pool
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def choose_shuffled_idx(sample_idx, pool):
|
| 167 |
+
if len(pool) <= 1:
|
| 168 |
+
return None
|
| 169 |
+
return (sample_idx + 1) % len(pool)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def choose_wrong_ref_idx(sample_idx, pool):
|
| 173 |
+
current = pool[sample_idx]
|
| 174 |
+
for item in pool:
|
| 175 |
+
if item["sample_idx"] == sample_idx:
|
| 176 |
+
continue
|
| 177 |
+
if item["vid"] == current["vid"] and item["fid"] != current["fid"]:
|
| 178 |
+
return item["sample_idx"]
|
| 179 |
+
for item in pool:
|
| 180 |
+
if item["sample_idx"] == sample_idx:
|
| 181 |
+
continue
|
| 182 |
+
if item["vid"] == current["vid"] and item["ref"] != current["ref"]:
|
| 183 |
+
return item["sample_idx"]
|
| 184 |
+
return None
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def run_d2_llm(model, tokenizer, pool, betas, limit):
|
| 188 |
+
rows = []
|
| 189 |
+
lookup = {item["sample_idx"]: item for item in pool}
|
| 190 |
+
generator = torch.Generator(device="cuda")
|
| 191 |
+
generator.manual_seed(1234)
|
| 192 |
+
loader = make_loader(tokenizer)
|
| 193 |
+
|
| 194 |
+
for sample_idx, batch in enumerate(loader):
|
| 195 |
+
if sample_idx >= limit:
|
| 196 |
+
break
|
| 197 |
+
batch = dict_to_cuda(batch)
|
| 198 |
+
item = lookup[sample_idx]
|
| 199 |
+
h_real = item["h"].cuda()
|
| 200 |
+
q_real = item["q"].cuda()
|
| 201 |
+
|
| 202 |
+
low_res_masks, iou_predictions = decode_low_res(model, batch, q_real)
|
| 203 |
+
llm_tokens = clip_projected_tokens(model, batch)
|
| 204 |
+
pred_mask_tokens = masks_to_token_grid(torch.sigmoid(low_res_masks), llm_tokens.shape[1])
|
| 205 |
+
gt_masks = batch["masks"][0][0].float()
|
| 206 |
+
gt_mask_tokens = masks_to_token_grid(gt_masks, llm_tokens.shape[1])
|
| 207 |
+
|
| 208 |
+
pred_logits_hr = model.get_model().visual_model.postprocess_masks(
|
| 209 |
+
low_res_masks.to(batch["image_feats"][0].dtype),
|
| 210 |
+
input_size=batch["resizes"][0],
|
| 211 |
+
original_size=batch["orgsizes"][0],
|
| 212 |
+
).squeeze(1)
|
| 213 |
+
frame_ious = frame_iou(pred_logits_hr, gt_masks)
|
| 214 |
+
frame_fscores = frame_fscore_proxy(pred_logits_hr, gt_masks)
|
| 215 |
+
pred_area = (torch.sigmoid(pred_logits_hr.float()) > 0.4).float().mean(dim=(1, 2))
|
| 216 |
+
gt_area = gt_masks.float().mean(dim=(1, 2))
|
| 217 |
+
|
| 218 |
+
shuffled_idx = choose_shuffled_idx(sample_idx, pool)
|
| 219 |
+
wrong_ref_idx = choose_wrong_ref_idx(sample_idx, pool)
|
| 220 |
+
controls = [
|
| 221 |
+
("real", h_real, sample_idx),
|
| 222 |
+
("random", torch.randn(h_real.shape, device=h_real.device, generator=generator), None),
|
| 223 |
+
]
|
| 224 |
+
if shuffled_idx is not None:
|
| 225 |
+
controls.append(("shuffled", lookup[shuffled_idx]["h"].cuda(), shuffled_idx))
|
| 226 |
+
if wrong_ref_idx is not None:
|
| 227 |
+
controls.append(("wrong_ref", lookup[wrong_ref_idx]["h"].cuda(), wrong_ref_idx))
|
| 228 |
+
|
| 229 |
+
for beta in betas:
|
| 230 |
+
for h_type, h, h_source_idx in controls:
|
| 231 |
+
pred_scores = d2_scores_llm(llm_tokens, pred_mask_tokens, h, beta)
|
| 232 |
+
gt_scores = d2_scores_llm(llm_tokens, gt_mask_tokens, h, beta)
|
| 233 |
+
for frame_idx in range(pred_scores.shape[0]):
|
| 234 |
+
rows.append(
|
| 235 |
+
{
|
| 236 |
+
"sample_idx": sample_idx,
|
| 237 |
+
"vid": item["vid"],
|
| 238 |
+
"ref": item["ref"],
|
| 239 |
+
"fid": item["fid"],
|
| 240 |
+
"split": args.eval_split,
|
| 241 |
+
"frame": frame_idx,
|
| 242 |
+
"h_type": h_type,
|
| 243 |
+
"beta": beta,
|
| 244 |
+
"s_pred": pred_scores[frame_idx].item(),
|
| 245 |
+
"s_gt": gt_scores[frame_idx].item(),
|
| 246 |
+
"h_source_idx": h_source_idx if h_source_idx is not None else "",
|
| 247 |
+
"frame_iou": frame_ious[frame_idx].item(),
|
| 248 |
+
"frame_fscore_proxy": frame_fscores[frame_idx].item(),
|
| 249 |
+
"iou_pred": iou_predictions[frame_idx].item(),
|
| 250 |
+
"pred_area": pred_area[frame_idx].item(),
|
| 251 |
+
"gt_area": gt_area[frame_idx].item(),
|
| 252 |
+
}
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
real_rows = [
|
| 256 |
+
r for r in rows if r["sample_idx"] == sample_idx and r["h_type"] == "real" and r["beta"] == betas[0]
|
| 257 |
+
]
|
| 258 |
+
s_pred_values = [r["s_pred"] for r in real_rows]
|
| 259 |
+
print(
|
| 260 |
+
f"D2-LLM {sample_idx}: vid={item['vid']} ref={item['ref']} "
|
| 261 |
+
f"mean_s_pred={np.mean(s_pred_values):.4f} min_s_pred={np.min(s_pred_values):.4f} "
|
| 262 |
+
f"mean_iou={frame_ious.mean().item():.4f}"
|
| 263 |
+
)
|
| 264 |
+
|
| 265 |
+
return rows
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
def print_summary(rows):
|
| 269 |
+
print("\nSummary")
|
| 270 |
+
print(f"rows: {len(rows)}")
|
| 271 |
+
for beta in sorted(set(r["beta"] for r in rows)):
|
| 272 |
+
beta_rows = [r for r in rows if r["beta"] == beta]
|
| 273 |
+
print(f"\nbeta={beta}")
|
| 274 |
+
for h_type in sorted(set(r["h_type"] for r in beta_rows)):
|
| 275 |
+
hr = [r for r in beta_rows if r["h_type"] == h_type]
|
| 276 |
+
print(
|
| 277 |
+
f"{h_type:10s} "
|
| 278 |
+
f"mean_s_pred={np.mean([r['s_pred'] for r in hr]):+.4f} "
|
| 279 |
+
f"mean_s_gt={np.mean([r['s_gt'] for r in hr]):+.4f}"
|
| 280 |
+
)
|
| 281 |
+
real_rows = [r for r in beta_rows if r["h_type"] == "real"]
|
| 282 |
+
s_pred = np.array([r["s_pred"] for r in real_rows])
|
| 283 |
+
frame_iou_values = np.array([r["frame_iou"] for r in real_rows])
|
| 284 |
+
if len(s_pred) > 1 and np.std(s_pred) > 1e-8 and np.std(frame_iou_values) > 1e-8:
|
| 285 |
+
corr = np.corrcoef(s_pred, frame_iou_values)[0, 1]
|
| 286 |
+
print(f"corr(real s_pred, frame_iou)={corr:+.4f}")
|
| 287 |
+
else:
|
| 288 |
+
print("corr(real s_pred, frame_iou)=nan")
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def main():
|
| 292 |
+
set_seed(42)
|
| 293 |
+
torch.set_grad_enabled(False)
|
| 294 |
+
betas = parse_betas()
|
| 295 |
+
tokenizer, seg_token_idx = build_tokenizer()
|
| 296 |
+
limit = args.max_eval_rows if args.max_eval_rows > 0 else 30
|
| 297 |
+
print(f"Split: {args.eval_split} | samples: {limit} | betas: {betas}")
|
| 298 |
+
|
| 299 |
+
model = build_model(tokenizer, seg_token_idx)
|
| 300 |
+
pool = collect_hidden_pool(model, tokenizer, limit)
|
| 301 |
+
rows = run_d2_llm(model, tokenizer, pool, betas, limit)
|
| 302 |
+
print_summary(rows)
|
| 303 |
+
|
| 304 |
+
csv_path = os.environ.get("D2_LLM_CSV", f"/workspace/SimToken/d2_llm_{args.eval_split}_{limit}.csv")
|
| 305 |
+
os.makedirs(os.path.dirname(os.path.abspath(csv_path)), exist_ok=True)
|
| 306 |
+
with open(csv_path, "w", newline="") as f:
|
| 307 |
+
writer = csv.DictWriter(f, fieldnames=list(rows[0].keys()))
|
| 308 |
+
writer.writeheader()
|
| 309 |
+
writer.writerows(rows)
|
| 310 |
+
print(f"\nSaved CSV: {csv_path}")
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
if __name__ == "__main__":
|
| 314 |
+
main()
|
decoder_invariance_check.py
ADDED
|
@@ -0,0 +1,256 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
from functools import partial
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import transformers
|
| 9 |
+
from peft import LoraConfig, get_peft_model
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
from transformers import AutoConfig
|
| 12 |
+
|
| 13 |
+
from configs import args
|
| 14 |
+
from datasets import REFAVS
|
| 15 |
+
from load_model import collate_fn, dict_to_cuda
|
| 16 |
+
from models.avs_model import Simtoken_ForCausalLM
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def set_seed(seed=42):
|
| 20 |
+
torch.manual_seed(seed)
|
| 21 |
+
np.random.seed(seed)
|
| 22 |
+
random.seed(seed)
|
| 23 |
+
torch.cuda.manual_seed_all(seed)
|
| 24 |
+
torch.backends.cudnn.deterministic = True
|
| 25 |
+
torch.backends.cudnn.benchmark = False
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def find_lora_target_modules(model, target_modules=("q_proj", "v_proj")):
|
| 29 |
+
modules = set()
|
| 30 |
+
excluded = [
|
| 31 |
+
"visual_model",
|
| 32 |
+
"vision_tower",
|
| 33 |
+
"mm_projector",
|
| 34 |
+
"text_hidden_fcs",
|
| 35 |
+
"audio_feature_layer",
|
| 36 |
+
]
|
| 37 |
+
for name, module in model.named_modules():
|
| 38 |
+
if not isinstance(module, torch.nn.Linear):
|
| 39 |
+
continue
|
| 40 |
+
if any(x in name for x in excluded):
|
| 41 |
+
continue
|
| 42 |
+
if any(x in name for x in target_modules):
|
| 43 |
+
modules.add(name)
|
| 44 |
+
return sorted(modules)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def build_model(tokenizer, seg_token_idx):
|
| 48 |
+
model_args = {
|
| 49 |
+
"train_mask_decoder": True,
|
| 50 |
+
"out_dim": 256,
|
| 51 |
+
"ce_loss_weight": 1.0,
|
| 52 |
+
"dice_loss_weight": 0.5,
|
| 53 |
+
"bce_loss_weight": 2.0,
|
| 54 |
+
"seg_token_idx": seg_token_idx,
|
| 55 |
+
"vision_pretrained": args.vision_pretrained,
|
| 56 |
+
"vision_tower": args.vision_tower,
|
| 57 |
+
"use_im_start_end": False,
|
| 58 |
+
"compress": args.compress,
|
| 59 |
+
"start": args.start,
|
| 60 |
+
}
|
| 61 |
+
|
| 62 |
+
model = Simtoken_ForCausalLM.from_pretrained(
|
| 63 |
+
args.mllm,
|
| 64 |
+
torch_dtype=torch.bfloat16,
|
| 65 |
+
low_cpu_mem_usage=True,
|
| 66 |
+
**model_args,
|
| 67 |
+
)
|
| 68 |
+
|
| 69 |
+
model.config.eos_token_id = tokenizer.eos_token_id
|
| 70 |
+
model.config.bos_token_id = tokenizer.bos_token_id
|
| 71 |
+
model.config.pad_token_id = tokenizer.pad_token_id
|
| 72 |
+
|
| 73 |
+
model.get_model().initialize_vision_modules(model.get_model().config)
|
| 74 |
+
vision_tower = model.get_model().get_vision_tower()
|
| 75 |
+
vision_tower.to(dtype=torch.float32, device="cuda")
|
| 76 |
+
|
| 77 |
+
model_args_from_pt = AutoConfig.from_pretrained(args.mllm)
|
| 78 |
+
model_args_from_pt.use_cluster = True
|
| 79 |
+
model_args_from_pt.freeze = False
|
| 80 |
+
model_args_from_pt.mm_tune = True
|
| 81 |
+
model_args_from_pt.spatial_cluster_rate0 = 64
|
| 82 |
+
model_args_from_pt.spatial_cluster_rate1 = 32
|
| 83 |
+
model_args_from_pt.spatial_cluster_rate2 = 16
|
| 84 |
+
model_args_from_pt.temporal_cluster_rate = 0.0625
|
| 85 |
+
model_args_from_pt.vision_tune = False
|
| 86 |
+
model.get_model().initialize_cluster_modules(model_args_from_pt)
|
| 87 |
+
model.get_model().initialize_lisa_modules(model.get_model().config)
|
| 88 |
+
|
| 89 |
+
lora_config = LoraConfig(
|
| 90 |
+
r=8,
|
| 91 |
+
lora_alpha=16,
|
| 92 |
+
target_modules=find_lora_target_modules(model),
|
| 93 |
+
lora_dropout=0.05,
|
| 94 |
+
bias="none",
|
| 95 |
+
task_type="CAUSAL_LM",
|
| 96 |
+
)
|
| 97 |
+
model = get_peft_model(model, lora_config)
|
| 98 |
+
model = model.to("cuda")
|
| 99 |
+
model.resize_token_embeddings(len(tokenizer))
|
| 100 |
+
|
| 101 |
+
state = torch.load(args.saved_model, map_location="cpu")
|
| 102 |
+
missing, unexpected = model.load_state_dict(state, strict=False)
|
| 103 |
+
print(f"Loaded checkpoint: {args.saved_model}")
|
| 104 |
+
print(f"Missing keys: {len(missing)} | Unexpected keys: {len(unexpected)}")
|
| 105 |
+
|
| 106 |
+
model.eval()
|
| 107 |
+
return model
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def get_seg_embedding(model, batch):
|
| 111 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 112 |
+
output = model.forward(
|
| 113 |
+
images=batch["images"],
|
| 114 |
+
images_clip=batch["images_clip"],
|
| 115 |
+
audio_features=batch["audio_feats"],
|
| 116 |
+
image_features=batch["image_feats"],
|
| 117 |
+
input_ids=batch["input_ids"],
|
| 118 |
+
labels=batch["labels"],
|
| 119 |
+
attention_masks=batch["attention_masks"],
|
| 120 |
+
masks_list=batch["masks"],
|
| 121 |
+
resize_list=batch["resizes"],
|
| 122 |
+
orgsize_list=batch["orgsizes"],
|
| 123 |
+
conversation_list=batch["convs"],
|
| 124 |
+
refs_num=batch["refs_num"],
|
| 125 |
+
fids=batch["fids"],
|
| 126 |
+
vids=batch["vids"],
|
| 127 |
+
contrast=args.ct_weight,
|
| 128 |
+
ref_ids=batch["ref_ids"],
|
| 129 |
+
inference=True,
|
| 130 |
+
)
|
| 131 |
+
return output["seg_embeddings"][0][0:1]
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def check_one_sample(model, batch):
|
| 135 |
+
q = get_seg_embedding(model, batch)
|
| 136 |
+
image_embeddings = batch["image_feats"][0]
|
| 137 |
+
|
| 138 |
+
visual_model = model.get_model().visual_model
|
| 139 |
+
sparse, dense = visual_model.prompt_encoder(
|
| 140 |
+
points=None,
|
| 141 |
+
boxes=None,
|
| 142 |
+
masks=None,
|
| 143 |
+
text_embeds=q.unsqueeze(1),
|
| 144 |
+
)
|
| 145 |
+
sparse = sparse.to(q.dtype)
|
| 146 |
+
dense = dense.to(q.dtype)
|
| 147 |
+
|
| 148 |
+
decoder = visual_model.mask_decoder
|
| 149 |
+
image_pe = visual_model.prompt_encoder.get_dense_pe()
|
| 150 |
+
|
| 151 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 152 |
+
full_masks, full_iou = decoder(
|
| 153 |
+
image_embeddings=image_embeddings,
|
| 154 |
+
image_pe=image_pe,
|
| 155 |
+
sparse_prompt_embeddings=sparse,
|
| 156 |
+
dense_prompt_embeddings=dense,
|
| 157 |
+
multimask_output=False,
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
rows = []
|
| 161 |
+
for t in range(image_embeddings.shape[0]):
|
| 162 |
+
single_masks, single_iou = decoder(
|
| 163 |
+
image_embeddings=image_embeddings[t : t + 1],
|
| 164 |
+
image_pe=image_pe,
|
| 165 |
+
sparse_prompt_embeddings=sparse,
|
| 166 |
+
dense_prompt_embeddings=dense,
|
| 167 |
+
multimask_output=False,
|
| 168 |
+
)
|
| 169 |
+
|
| 170 |
+
diff = (full_masks[t : t + 1] - single_masks).float().abs()
|
| 171 |
+
iou_diff = (full_iou[t : t + 1] - single_iou).float().abs()
|
| 172 |
+
rows.append(
|
| 173 |
+
{
|
| 174 |
+
"vid": batch["vids"][0],
|
| 175 |
+
"ref": batch["refs"][0][0],
|
| 176 |
+
"frame": t,
|
| 177 |
+
"max_abs_diff": diff.max().item(),
|
| 178 |
+
"mean_abs_diff": diff.mean().item(),
|
| 179 |
+
"iou_pred_diff": iou_diff.max().item(),
|
| 180 |
+
}
|
| 181 |
+
)
|
| 182 |
+
return rows
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def main():
|
| 186 |
+
set_seed(42)
|
| 187 |
+
torch.set_grad_enabled(False)
|
| 188 |
+
|
| 189 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 190 |
+
args.mllm,
|
| 191 |
+
cache_dir=None,
|
| 192 |
+
model_max_length=2048,
|
| 193 |
+
padding_side="right",
|
| 194 |
+
use_fast=False,
|
| 195 |
+
)
|
| 196 |
+
tokenizer.pad_token = tokenizer.unk_token
|
| 197 |
+
tokenizer.add_tokens("[SEG]")
|
| 198 |
+
seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
|
| 199 |
+
|
| 200 |
+
dataset = REFAVS(args.eval_split, args, tokenizer, input_type="refer")
|
| 201 |
+
loader = DataLoader(
|
| 202 |
+
dataset,
|
| 203 |
+
batch_size=1,
|
| 204 |
+
shuffle=False,
|
| 205 |
+
num_workers=0,
|
| 206 |
+
collate_fn=partial(collate_fn, tokenizer=tokenizer),
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
limit = args.max_eval_rows if args.max_eval_rows > 0 else 1
|
| 210 |
+
print(f"Split: {args.eval_split} | samples to check: {limit}")
|
| 211 |
+
|
| 212 |
+
model = build_model(tokenizer, seg_token_idx)
|
| 213 |
+
|
| 214 |
+
all_rows = []
|
| 215 |
+
for sample_idx, batch in enumerate(loader):
|
| 216 |
+
if sample_idx >= limit:
|
| 217 |
+
break
|
| 218 |
+
batch = dict_to_cuda(batch)
|
| 219 |
+
rows = check_one_sample(model, batch)
|
| 220 |
+
all_rows.extend(rows)
|
| 221 |
+
|
| 222 |
+
print(f"\nSample {sample_idx}: vid={rows[0]['vid']} ref={rows[0]['ref']}")
|
| 223 |
+
print("frame | max_abs_diff | mean_abs_diff | iou_pred_diff")
|
| 224 |
+
for row in rows:
|
| 225 |
+
print(
|
| 226 |
+
f"{row['frame']:02d} | "
|
| 227 |
+
f"{row['max_abs_diff']:.8e} | "
|
| 228 |
+
f"{row['mean_abs_diff']:.8e} | "
|
| 229 |
+
f"{row['iou_pred_diff']:.8e}"
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
if not all_rows:
|
| 233 |
+
raise RuntimeError("No rows were checked. Is the selected split empty?")
|
| 234 |
+
|
| 235 |
+
max_diff = max(row["max_abs_diff"] for row in all_rows)
|
| 236 |
+
mean_diff = sum(row["mean_abs_diff"] for row in all_rows) / len(all_rows)
|
| 237 |
+
max_iou_diff = max(row["iou_pred_diff"] for row in all_rows)
|
| 238 |
+
|
| 239 |
+
print("\nSummary")
|
| 240 |
+
print(f"checked frames: {len(all_rows)}")
|
| 241 |
+
print(f"global max_abs_diff: {max_diff:.8e}")
|
| 242 |
+
print(f"average mean_abs_diff: {mean_diff:.8e}")
|
| 243 |
+
print(f"global max_iou_pred_diff: {max_iou_diff:.8e}")
|
| 244 |
+
|
| 245 |
+
csv_path = os.environ.get("DECODER_INVARIANCE_CSV")
|
| 246 |
+
if csv_path:
|
| 247 |
+
os.makedirs(os.path.dirname(os.path.abspath(csv_path)), exist_ok=True)
|
| 248 |
+
with open(csv_path, "w", newline="") as f:
|
| 249 |
+
writer = csv.DictWriter(f, fieldnames=list(all_rows[0].keys()))
|
| 250 |
+
writer.writeheader()
|
| 251 |
+
writer.writerows(all_rows)
|
| 252 |
+
print(f"Saved CSV: {csv_path}")
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
if __name__ == "__main__":
|
| 256 |
+
main()
|
load_model.py
CHANGED
|
@@ -1,12 +1,3 @@
|
|
| 1 |
-
# Compatibility: transformers==4.30.2 calls hf_hub_download(use_auth_token=...),
|
| 2 |
-
# removed in huggingface_hub>=0.20. Patch before importing transformers so the
|
| 3 |
-
# bound reference inside transformers.utils.hub picks up the fixed version.
|
| 4 |
-
import huggingface_hub as _hfhub
|
| 5 |
-
_hfhub_orig = _hfhub.hf_hub_download
|
| 6 |
-
def _hfhub_compat(*args, use_auth_token=None, token=None, **kwargs):
|
| 7 |
-
return _hfhub_orig(*args, token=token or use_auth_token, **kwargs)
|
| 8 |
-
_hfhub.hf_hub_download = _hfhub_compat
|
| 9 |
-
|
| 10 |
import transformers
|
| 11 |
|
| 12 |
from torch.cuda.amp import autocast, GradScaler
|
|
@@ -217,7 +208,7 @@ def collate_fn(batch, tokenizer=None):
|
|
| 217 |
|
| 218 |
import torch.multiprocessing as mp
|
| 219 |
if __name__ == "__main__":
|
| 220 |
-
mp.set_start_method("spawn"
|
| 221 |
set_seed(42)
|
| 222 |
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 223 |
args.mllm,
|
|
@@ -233,9 +224,17 @@ if __name__ == "__main__":
|
|
| 233 |
print("seg_token_idx: ", seg_token_idx)
|
| 234 |
|
| 235 |
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
|
| 240 |
|
| 241 |
|
|
@@ -341,8 +340,12 @@ if __name__ == "__main__":
|
|
| 341 |
model = model.to("cuda")
|
| 342 |
model.resize_token_embeddings(len(tokenizer))
|
| 343 |
|
| 344 |
-
model.load_state_dict(
|
| 345 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 346 |
|
| 347 |
|
| 348 |
save_root = args.visualization_root
|
|
@@ -401,16 +404,15 @@ if __name__ == "__main__":
|
|
| 401 |
print("visualization finished")
|
| 402 |
|
| 403 |
|
| 404 |
-
def valuate(model, dataloader, name
|
| 405 |
model.eval()
|
| 406 |
|
| 407 |
total_iou = 0
|
| 408 |
total_fscore = 0
|
| 409 |
count = 0
|
| 410 |
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
if 0 < max_rows <= i:
|
| 414 |
break
|
| 415 |
input_dict = dict_to_cuda(batch)
|
| 416 |
|
|
@@ -445,39 +447,40 @@ if __name__ == "__main__":
|
|
| 445 |
total_fscore += fscore * num_seg * T
|
| 446 |
count += num_seg * T
|
| 447 |
|
|
|
|
|
|
|
|
|
|
| 448 |
print(f"\n valuate on {name}: miou: {total_iou/count} fscore: {total_fscore/count}")
|
| 449 |
|
| 450 |
|
| 451 |
-
def valuate_Null(model, dataloader
|
| 452 |
model.eval()
|
| 453 |
|
| 454 |
total_metric = 0
|
| 455 |
count = 0
|
| 456 |
|
| 457 |
-
|
| 458 |
-
|
| 459 |
-
if 0 < max_rows <= i:
|
| 460 |
break
|
| 461 |
input_dict = dict_to_cuda(batch)
|
| 462 |
-
with torch.
|
| 463 |
-
|
| 464 |
-
|
| 465 |
-
|
| 466 |
-
|
| 467 |
-
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
|
| 474 |
-
|
| 475 |
-
|
| 476 |
-
|
| 477 |
-
|
| 478 |
-
|
| 479 |
-
|
| 480 |
-
inference=True)
|
| 481 |
pred_masks = output_dict["pred_masks"] # list[B]:[num_seg, T, H, W]
|
| 482 |
gt_masks = output_dict["gt_masks"] # list[B]:[num_seg, T, H, W]
|
| 483 |
for i in range(len(pred_masks)):
|
|
@@ -488,637 +491,13 @@ if __name__ == "__main__":
|
|
| 488 |
total_metric += null_metric * num_seg * T
|
| 489 |
count += num_seg * T
|
| 490 |
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
|
| 494 |
|
|
|
|
| 495 |
|
| 496 |
-
from seg_ltpo import (
|
| 497 |
-
LTPOConfig, ltpo_optimize, best_of_2_optimize, decode_full_video,
|
| 498 |
-
get_sam_model, get_anchor_indices,
|
| 499 |
-
QLTPOConfig, q_ltpo_autograd, check_grad_connectivity,
|
| 500 |
-
reset_q_ltpo_stats, get_q_ltpo_stats,
|
| 501 |
-
q_ltpo_frame_adaptive, decode_full_video_adaptive,
|
| 502 |
-
_compute_avt_proxy_reward,
|
| 503 |
-
)
|
| 504 |
-
|
| 505 |
-
def print_q_ltpo_stats(name: str) -> None:
|
| 506 |
-
stats = get_q_ltpo_stats()
|
| 507 |
-
if not stats:
|
| 508 |
-
return
|
| 509 |
-
n = len(stats)
|
| 510 |
-
acc_rate = sum(s["accepted"] for s in stats) / n
|
| 511 |
-
mean_gain = sum(s["reward_gain"] for s in stats) / n
|
| 512 |
-
mean_drift = sum(s["drift"] for s in stats) / n
|
| 513 |
-
clip_rate = sum(s["hit_clip"] for s in stats) / n
|
| 514 |
-
mean_iou_init = sum(s["R_iou_pred_init"] for s in stats) / n
|
| 515 |
-
mean_iou_best = sum(s["R_iou_pred_best"] for s in stats) / n
|
| 516 |
-
mean_area_init = sum(s["area_hard_init"] for s in stats) / n
|
| 517 |
-
mean_area_best = sum(s["area_hard_best"] for s in stats) / n
|
| 518 |
-
# Null safety: reward improved but predicted area grew >20 %
|
| 519 |
-
null_risk = sum(
|
| 520 |
-
1 for s in stats
|
| 521 |
-
if s["reward_gain"] > 0 and s["area_hard_best"] > s["area_hard_init"] * 1.2
|
| 522 |
-
) / n
|
| 523 |
-
gains = sorted(s["reward_gain"] for s in stats)
|
| 524 |
-
def _pct(v, p): return v[max(0, int(len(v) * p / 100) - 1)]
|
| 525 |
-
mean_e0 = sum(s["e0"] for s in stats) / n
|
| 526 |
-
mean_mask_iou = sum(s.get("mask_soft_iou", 0.0) for s in stats) / n
|
| 527 |
-
mean_iou_contrib = sum(s.get("R_iou_contrib_gain", 0.0) for s in stats) / n
|
| 528 |
-
mean_soft_area_init = sum(s.get("r_area_soft_init", 0.0) for s in stats) / n
|
| 529 |
-
mean_soft_area_best = sum(s.get("r_area_soft_best", 0.0) for s in stats) / n
|
| 530 |
-
# B1 activation diagnostics
|
| 531 |
-
b1_excesses = sorted(s.get("b1_peak_excess", 0.0) for s in stats)
|
| 532 |
-
b1_act_rate = sum(1 for v in b1_excesses if v > 1e-8) / n
|
| 533 |
-
b1_mean_excess = sum(b1_excesses) / n
|
| 534 |
-
print(f"\n [q-LTPO stats | {name} | n={n}]")
|
| 535 |
-
print(f" acceptance rate : {acc_rate:.3f}")
|
| 536 |
-
print(f" mean e0 (exist prior): {mean_e0:.4f} ← should differ Null vs Seen")
|
| 537 |
-
print(f" mean reward gain : {mean_gain:+.4f}")
|
| 538 |
-
print(f" reward_gain p10/50/90: {_pct(gains,10):+.4f} / {_pct(gains,50):+.4f} / {_pct(gains,90):+.4f}")
|
| 539 |
-
print(f" mean drift ‖q−q₀‖ : {mean_drift:.4f}")
|
| 540 |
-
print(f" hit-clip ratio : {clip_rate:.3f}")
|
| 541 |
-
print(f" R_iou_pred init→best : {mean_iou_init:.4f} → {mean_iou_best:.4f}")
|
| 542 |
-
print(f" R_iou_contrib_gain : {mean_iou_contrib:+.4f} ← λ_iou·e0·Δiou")
|
| 543 |
-
print(f" mask soft-IoU(init,best): {mean_mask_iou:.4f} ← 1.0=mask不变")
|
| 544 |
-
print(f" area (hard) init→best: {mean_area_init:.4f} → {mean_area_best:.4f}")
|
| 545 |
-
print(f" soft area init→best : {mean_soft_area_init:.4f} → {mean_soft_area_best:.4f}")
|
| 546 |
-
print(f" B1 activation rate : {b1_act_rate:.3f} ← frac(peak_area > e0)")
|
| 547 |
-
print(f" B1 mean excess : {b1_mean_excess:.5f} ← mean ReLU(peak_area - e0)")
|
| 548 |
-
print(f" B1 excess p10/50/90 : {_pct(b1_excesses,10):.5f} / {_pct(b1_excesses,50):.5f} / {_pct(b1_excesses,90):.5f}")
|
| 549 |
-
print(f" reward↑ & area+20%↑ : {null_risk:.3f} ← Null safety indicator")
|
| 550 |
-
# Direction II: frame-adaptive delta diagnostics
|
| 551 |
-
delta_norms = [s.get("delta_norm", 0.0) for s in stats]
|
| 552 |
-
if any(v > 0 for v in delta_norms):
|
| 553 |
-
print(f" mean delta ‖Δ‖ : {sum(delta_norms)/n:.4f} ← per-anchor residual norm")
|
| 554 |
-
|
| 555 |
-
def valuate_ltpo(model, dataloader, name, ltpo_cfg, optimize_fn=None,
|
| 556 |
-
max_rows=-1, multimask=False, use_edge=False):
|
| 557 |
-
if optimize_fn is None:
|
| 558 |
-
optimize_fn = ltpo_optimize
|
| 559 |
-
"""
|
| 560 |
-
Evaluate with SEG-LTPO test-time optimisation + optional boundary refinement.
|
| 561 |
-
|
| 562 |
-
decode_mode:
|
| 563 |
-
multimask=False, use_edge=False : original single-mask decode (default)
|
| 564 |
-
multimask=True, use_edge=False : 3 candidates, SAM iou_pred selection (step 1a)
|
| 565 |
-
multimask=True, use_edge=True : 3 candidates, boundary-edge score (step 1b)
|
| 566 |
-
"""
|
| 567 |
-
model.eval()
|
| 568 |
-
sam_model = get_sam_model(model)
|
| 569 |
-
model_dtype = torch.bfloat16
|
| 570 |
-
num_frames = 10
|
| 571 |
-
anchor_indices = get_anchor_indices(num_frames, ltpo_cfg.num_anchors)
|
| 572 |
-
|
| 573 |
-
total_iou = 0
|
| 574 |
-
total_fscore = 0
|
| 575 |
-
count = 0
|
| 576 |
-
|
| 577 |
-
_total = min(max_rows, len(dataloader)) if max_rows > 0 else len(dataloader)
|
| 578 |
-
for i, batch in enumerate(tqdm(dataloader, desc=f"LTPO Evaluating on {name}", total=_total)):
|
| 579 |
-
if 0 < max_rows <= i:
|
| 580 |
-
break
|
| 581 |
-
input_dict = dict_to_cuda(batch)
|
| 582 |
-
|
| 583 |
-
# ── Step 1: standard forward pass (LLM + SAM decode) ──────────
|
| 584 |
-
with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True):
|
| 585 |
-
with torch.no_grad():
|
| 586 |
-
output_dict = model.forward(
|
| 587 |
-
images=input_dict["images"],
|
| 588 |
-
images_clip=input_dict["images_clip"],
|
| 589 |
-
audio_features=input_dict["audio_feats"],
|
| 590 |
-
image_features=input_dict["image_feats"],
|
| 591 |
-
input_ids=input_dict["input_ids"],
|
| 592 |
-
labels=input_dict["labels"],
|
| 593 |
-
attention_masks=input_dict["attention_masks"],
|
| 594 |
-
masks_list=input_dict["masks"],
|
| 595 |
-
resize_list=input_dict["resizes"],
|
| 596 |
-
orgsize_list=input_dict["orgsizes"],
|
| 597 |
-
conversation_list=input_dict["convs"],
|
| 598 |
-
refs_num=input_dict["refs_num"],
|
| 599 |
-
fids=input_dict["fids"],
|
| 600 |
-
vids=input_dict["vids"],
|
| 601 |
-
contrast=args.ct_weight,
|
| 602 |
-
ref_ids=input_dict["ref_ids"],
|
| 603 |
-
inference=True,
|
| 604 |
-
)
|
| 605 |
-
|
| 606 |
-
gt_masks = output_dict["gt_masks"] # list[B]:[num_seg, T, H, W]
|
| 607 |
-
seg_emb_list = output_dict["seg_embeddings"] # list[B]:[num_seg, 256]
|
| 608 |
-
|
| 609 |
-
for b in range(len(input_dict["images"])):
|
| 610 |
-
image_embeds_b = input_dict["image_feats"][b] # [T, 256, 64, 64]
|
| 611 |
-
resize_b = input_dict["resizes"][b]
|
| 612 |
-
orgsize_b = input_dict["orgsizes"][b]
|
| 613 |
-
rgb_b = input_dict["images"][b] if use_edge else None # [T,3,H,W]
|
| 614 |
-
|
| 615 |
-
# Convert initial Fseg to float32 for stable optimisation.
|
| 616 |
-
# seg_emb_list[b]: [num_seg, 256] in bfloat16
|
| 617 |
-
F_init_b = seg_emb_list[b].detach().float() # [num_seg, 256]
|
| 618 |
-
|
| 619 |
-
pred_masks_ltpo = []
|
| 620 |
-
for seg_idx in range(F_init_b.shape[0]):
|
| 621 |
-
fseg_init = F_init_b[seg_idx : seg_idx + 1] # [1, 256]
|
| 622 |
-
|
| 623 |
-
# ── Step 2: optimisation (float32, outside autocast) ──────
|
| 624 |
-
best_fseg = optimize_fn(
|
| 625 |
-
fseg_init, image_embeds_b, anchor_indices,
|
| 626 |
-
sam_model, model_dtype, ltpo_cfg,
|
| 627 |
-
) # [1, 256] float32
|
| 628 |
-
|
| 629 |
-
# ── Step 3: decode full video with best Fseg ──────────────
|
| 630 |
-
pred_mask = decode_full_video(
|
| 631 |
-
best_fseg, image_embeds_b, sam_model,
|
| 632 |
-
resize_b, orgsize_b, model_dtype,
|
| 633 |
-
rgb_frames=rgb_b, multimask=multimask,
|
| 634 |
-
) # [T, H, W]
|
| 635 |
-
pred_masks_ltpo.append(pred_mask)
|
| 636 |
-
|
| 637 |
-
pred_masks_b = torch.stack(pred_masks_ltpo, dim=0) # [num_seg, T, H, W]
|
| 638 |
-
|
| 639 |
-
num_seg = pred_masks_b.shape[0]
|
| 640 |
-
T_ = pred_masks_b.shape[1]
|
| 641 |
-
iou = utility.mask_iou(pred_masks_b, gt_masks[b])
|
| 642 |
-
fscore = utility.Eval_Fmeasure(pred_masks_b, gt_masks[b], None)
|
| 643 |
-
|
| 644 |
-
total_iou += iou * num_seg * T_
|
| 645 |
-
total_fscore += fscore * num_seg * T_
|
| 646 |
-
count += num_seg * T_
|
| 647 |
-
|
| 648 |
-
print(f"\n LTPO valuate on {name}: miou: {total_iou/count:.4f} fscore: {total_fscore/count:.4f}")
|
| 649 |
-
|
| 650 |
-
|
| 651 |
-
def valuate_ltpo_null(model, dataloader, ltpo_cfg, optimize_fn=None, max_rows=-1):
|
| 652 |
-
if optimize_fn is None:
|
| 653 |
-
optimize_fn = ltpo_optimize
|
| 654 |
-
"""LTPO evaluation for Null split: measures S metric (lower = fewer false-positive masks)."""
|
| 655 |
-
model.eval()
|
| 656 |
-
sam_model = get_sam_model(model)
|
| 657 |
-
model_dtype = torch.bfloat16
|
| 658 |
-
num_frames = 10
|
| 659 |
-
anchor_indices = get_anchor_indices(num_frames, ltpo_cfg.num_anchors)
|
| 660 |
-
|
| 661 |
-
total_metric = 0
|
| 662 |
-
count = 0
|
| 663 |
-
|
| 664 |
-
_total = min(max_rows, len(dataloader)) if max_rows > 0 else len(dataloader)
|
| 665 |
-
for i, batch in enumerate(tqdm(dataloader, desc="LTPO Evaluating on Null", total=_total)):
|
| 666 |
-
if 0 < max_rows <= i:
|
| 667 |
-
break
|
| 668 |
-
input_dict = dict_to_cuda(batch)
|
| 669 |
-
|
| 670 |
-
with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True):
|
| 671 |
-
with torch.no_grad():
|
| 672 |
-
output_dict = model.forward(
|
| 673 |
-
images=input_dict["images"],
|
| 674 |
-
images_clip=input_dict["images_clip"],
|
| 675 |
-
audio_features=input_dict["audio_feats"],
|
| 676 |
-
image_features=input_dict["image_feats"],
|
| 677 |
-
input_ids=input_dict["input_ids"],
|
| 678 |
-
labels=input_dict["labels"],
|
| 679 |
-
attention_masks=input_dict["attention_masks"],
|
| 680 |
-
masks_list=input_dict["masks"],
|
| 681 |
-
resize_list=input_dict["resizes"],
|
| 682 |
-
orgsize_list=input_dict["orgsizes"],
|
| 683 |
-
conversation_list=input_dict["convs"],
|
| 684 |
-
refs_num=input_dict["refs_num"],
|
| 685 |
-
fids=input_dict["fids"],
|
| 686 |
-
vids=input_dict["vids"],
|
| 687 |
-
contrast=args.ct_weight,
|
| 688 |
-
ref_ids=input_dict["ref_ids"],
|
| 689 |
-
inference=True,
|
| 690 |
-
)
|
| 691 |
-
|
| 692 |
-
seg_emb_list = output_dict["seg_embeddings"] # list[B]:[num_seg, 256]
|
| 693 |
-
|
| 694 |
-
for b in range(len(input_dict["images"])):
|
| 695 |
-
image_embeds_b = input_dict["image_feats"][b]
|
| 696 |
-
resize_b = input_dict["resizes"][b]
|
| 697 |
-
orgsize_b = input_dict["orgsizes"][b]
|
| 698 |
-
F_init_b = seg_emb_list[b].detach().float()
|
| 699 |
-
|
| 700 |
-
pred_masks_ltpo = []
|
| 701 |
-
for seg_idx in range(F_init_b.shape[0]):
|
| 702 |
-
fseg_init = F_init_b[seg_idx : seg_idx + 1]
|
| 703 |
-
best_fseg = optimize_fn(
|
| 704 |
-
fseg_init, image_embeds_b, anchor_indices,
|
| 705 |
-
sam_model, model_dtype, ltpo_cfg,
|
| 706 |
-
)
|
| 707 |
-
pred_mask = decode_full_video(
|
| 708 |
-
best_fseg, image_embeds_b, sam_model,
|
| 709 |
-
resize_b, orgsize_b, model_dtype,
|
| 710 |
-
)
|
| 711 |
-
pred_masks_ltpo.append(pred_mask)
|
| 712 |
-
|
| 713 |
-
pred_masks_b = torch.stack(pred_masks_ltpo, dim=0) # [num_seg, T, H, W]
|
| 714 |
-
num_seg = pred_masks_b.shape[0]
|
| 715 |
-
T_ = pred_masks_b.shape[1]
|
| 716 |
-
null_metric = utility.metric_s_for_null(pred_masks_b)
|
| 717 |
-
|
| 718 |
-
total_metric += null_metric * num_seg * T_
|
| 719 |
-
count += num_seg * T_
|
| 720 |
-
|
| 721 |
-
print(f"\n LTPO valuate on Null: S metric: {total_metric/count:.4f}")
|
| 722 |
-
|
| 723 |
-
|
| 724 |
-
def valuate_ltpo_adaptive(model, dataloader, name, ltpo_cfg, max_rows=-1):
|
| 725 |
-
"""Evaluate with Direction II frame-adaptive token optimization."""
|
| 726 |
-
model.eval()
|
| 727 |
-
sam_model = get_sam_model(model)
|
| 728 |
-
model_dtype = torch.bfloat16
|
| 729 |
-
num_frames = 10
|
| 730 |
-
anchor_indices = get_anchor_indices(num_frames, ltpo_cfg.num_anchors)
|
| 731 |
-
|
| 732 |
-
total_iou = 0
|
| 733 |
-
total_fscore = 0
|
| 734 |
-
count = 0
|
| 735 |
-
|
| 736 |
-
_total = min(max_rows, len(dataloader)) if max_rows > 0 else len(dataloader)
|
| 737 |
-
for i, batch in enumerate(tqdm(dataloader, desc=f"FA-LTPO Evaluating on {name}", total=_total)):
|
| 738 |
-
if 0 < max_rows <= i:
|
| 739 |
-
break
|
| 740 |
-
input_dict = dict_to_cuda(batch)
|
| 741 |
-
|
| 742 |
-
with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True):
|
| 743 |
-
with torch.no_grad():
|
| 744 |
-
output_dict = model.forward(
|
| 745 |
-
images=input_dict["images"],
|
| 746 |
-
images_clip=input_dict["images_clip"],
|
| 747 |
-
audio_features=input_dict["audio_feats"],
|
| 748 |
-
image_features=input_dict["image_feats"],
|
| 749 |
-
input_ids=input_dict["input_ids"],
|
| 750 |
-
labels=input_dict["labels"],
|
| 751 |
-
attention_masks=input_dict["attention_masks"],
|
| 752 |
-
masks_list=input_dict["masks"],
|
| 753 |
-
resize_list=input_dict["resizes"],
|
| 754 |
-
orgsize_list=input_dict["orgsizes"],
|
| 755 |
-
conversation_list=input_dict["convs"],
|
| 756 |
-
refs_num=input_dict["refs_num"],
|
| 757 |
-
fids=input_dict["fids"],
|
| 758 |
-
vids=input_dict["vids"],
|
| 759 |
-
contrast=args.ct_weight,
|
| 760 |
-
ref_ids=input_dict["ref_ids"],
|
| 761 |
-
inference=True,
|
| 762 |
-
)
|
| 763 |
-
|
| 764 |
-
gt_masks = output_dict["gt_masks"] # list[B]:[num_seg, T, H, W]
|
| 765 |
-
seg_emb_list = output_dict["seg_embeddings"] # list[B]:[num_seg, 256]
|
| 766 |
-
|
| 767 |
-
for b in range(len(input_dict["images"])):
|
| 768 |
-
image_embeds_b = input_dict["image_feats"][b]
|
| 769 |
-
resize_b = input_dict["resizes"][b]
|
| 770 |
-
orgsize_b = input_dict["orgsizes"][b]
|
| 771 |
-
F_init_b = seg_emb_list[b].detach().float()
|
| 772 |
-
|
| 773 |
-
pred_masks_ltpo = []
|
| 774 |
-
for seg_idx in range(F_init_b.shape[0]):
|
| 775 |
-
fseg_init = F_init_b[seg_idx : seg_idx + 1]
|
| 776 |
-
|
| 777 |
-
q_global, delta = q_ltpo_frame_adaptive(
|
| 778 |
-
fseg_init, image_embeds_b, anchor_indices,
|
| 779 |
-
sam_model, model_dtype, ltpo_cfg,
|
| 780 |
-
)
|
| 781 |
-
|
| 782 |
-
pred_mask = decode_full_video_adaptive(
|
| 783 |
-
q_global, delta, anchor_indices,
|
| 784 |
-
image_embeds_b, sam_model,
|
| 785 |
-
resize_b, orgsize_b, model_dtype,
|
| 786 |
-
)
|
| 787 |
-
pred_masks_ltpo.append(pred_mask)
|
| 788 |
-
|
| 789 |
-
pred_masks_b = torch.stack(pred_masks_ltpo, dim=0)
|
| 790 |
-
num_seg = pred_masks_b.shape[0]
|
| 791 |
-
T_ = pred_masks_b.shape[1]
|
| 792 |
-
iou = utility.mask_iou(pred_masks_b, gt_masks[b])
|
| 793 |
-
fscore = utility.Eval_Fmeasure(pred_masks_b, gt_masks[b], None)
|
| 794 |
-
|
| 795 |
-
total_iou += iou * num_seg * T_
|
| 796 |
-
total_fscore += fscore * num_seg * T_
|
| 797 |
-
count += num_seg * T_
|
| 798 |
-
|
| 799 |
-
print(f"\n FA-LTPO valuate on {name}: miou: {total_iou/count:.4f} fscore: {total_fscore/count:.4f}")
|
| 800 |
-
|
| 801 |
-
# ── Step A0: reward–metric correlation study ─────────────────────────
|
| 802 |
-
|
| 803 |
-
def _print_correlation_report(per_sample: list) -> None:
|
| 804 |
-
import numpy as np
|
| 805 |
-
n = len(per_sample)
|
| 806 |
-
if n == 0:
|
| 807 |
-
return
|
| 808 |
-
|
| 809 |
-
r_iou = np.array([s["reward_gain"] for s in per_sample], dtype=float)
|
| 810 |
-
r_avt = np.array([s["r_avt_gain"] for s in per_sample], dtype=float)
|
| 811 |
-
r_avt_c = np.array([s["r_avt_c_gain"] for s in per_sample], dtype=float)
|
| 812 |
-
dm = np.array([s["delta_miou"] for s in per_sample], dtype=float)
|
| 813 |
-
df = np.array([s["delta_f"] for s in per_sample], dtype=float)
|
| 814 |
-
|
| 815 |
-
def pearson(x, y):
|
| 816 |
-
x = x - x.mean(); y = y - y.mean()
|
| 817 |
-
denom = np.sqrt((x ** 2).sum() * (y ** 2).sum())
|
| 818 |
-
return float((x * y).sum() / (denom + 1e-12))
|
| 819 |
-
|
| 820 |
-
def wrong_frac(gains, deltas):
|
| 821 |
-
return sum(1 for g, d in zip(gains, deltas) if g > 0 and d < 0) / n
|
| 822 |
-
|
| 823 |
-
print(f"\n [Step A0: Reward–Metric Correlation | n={n}]")
|
| 824 |
-
print(f" mean ΔmIoU : {dm.mean():+.4f} (std {dm.std():.4f})")
|
| 825 |
-
print(f" mean ΔF : {df.mean():+.4f} (std {df.std():.4f})")
|
| 826 |
-
print(f"\n Pearson r with ΔmIoU :")
|
| 827 |
-
print(f" R_iou_pred_gain : {pearson(r_iou, dm):+.3f} ← current proxy")
|
| 828 |
-
print(f" R_avt_gain : {pearson(r_avt, dm):+.3f} ← cos(z_in, q_init)")
|
| 829 |
-
print(f" R_avt_c_gain : {pearson(r_avt_c, dm):+.3f} ← cos(z_in,q)-β·cos(z_out,q)")
|
| 830 |
-
print(f"\n Pearson r with ΔF :")
|
| 831 |
-
print(f" R_iou_pred_gain : {pearson(r_iou, df):+.3f}")
|
| 832 |
-
print(f" R_avt_gain : {pearson(r_avt, df):+.3f}")
|
| 833 |
-
print(f" R_avt_c_gain : {pearson(r_avt_c, df):+.3f}")
|
| 834 |
-
print(f"\n Wrong direction (gain>0 but Δ<0):")
|
| 835 |
-
print(f" R_iou / ΔmIoU : {wrong_frac(r_iou, dm):.3f}")
|
| 836 |
-
print(f" R_avt / ΔmIoU : {wrong_frac(r_avt, dm):.3f}")
|
| 837 |
-
print(f" R_iou / ΔF : {wrong_frac(r_iou, df):.3f}")
|
| 838 |
-
print(f" R_avt / ΔF : {wrong_frac(r_avt, df):.3f}")
|
| 839 |
-
|
| 840 |
-
def valuate_ltpo_correlation_study(model, dataloader, ltpo_cfg, max_rows=-1):
|
| 841 |
-
"""Step A0: per-sample reward–metric correlation study.
|
| 842 |
-
|
| 843 |
-
For each (video, segment) sample runs:
|
| 844 |
-
1. Baseline decode (q_init → mask → IoU/F)
|
| 845 |
-
2. q-LTPO s1 (q_best → mask → IoU/F)
|
| 846 |
-
Records reward signals and ΔmIoU / ΔF per sample, then prints
|
| 847 |
-
Pearson correlation table to identify which reward best predicts
|
| 848 |
-
actual metric improvement.
|
| 849 |
-
"""
|
| 850 |
-
model.eval()
|
| 851 |
-
sam_model = get_sam_model(model)
|
| 852 |
-
model_dtype = torch.bfloat16
|
| 853 |
-
anchor_indices = get_anchor_indices(10, ltpo_cfg.num_anchors)
|
| 854 |
-
|
| 855 |
-
per_sample = []
|
| 856 |
-
|
| 857 |
-
_total = min(max_rows, len(dataloader)) if max_rows > 0 else len(dataloader)
|
| 858 |
-
for i, batch in enumerate(
|
| 859 |
-
tqdm(dataloader, desc="Correlation study (s1)", total=_total)
|
| 860 |
-
):
|
| 861 |
-
if 0 < max_rows <= i:
|
| 862 |
-
break
|
| 863 |
-
input_dict = dict_to_cuda(batch)
|
| 864 |
-
|
| 865 |
-
with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True):
|
| 866 |
-
with torch.no_grad():
|
| 867 |
-
output_dict = model.forward(
|
| 868 |
-
images=input_dict["images"],
|
| 869 |
-
images_clip=input_dict["images_clip"],
|
| 870 |
-
audio_features=input_dict["audio_feats"],
|
| 871 |
-
image_features=input_dict["image_feats"],
|
| 872 |
-
input_ids=input_dict["input_ids"],
|
| 873 |
-
labels=input_dict["labels"],
|
| 874 |
-
attention_masks=input_dict["attention_masks"],
|
| 875 |
-
masks_list=input_dict["masks"],
|
| 876 |
-
resize_list=input_dict["resizes"],
|
| 877 |
-
orgsize_list=input_dict["orgsizes"],
|
| 878 |
-
conversation_list=input_dict["convs"],
|
| 879 |
-
refs_num=input_dict["refs_num"],
|
| 880 |
-
fids=input_dict["fids"],
|
| 881 |
-
vids=input_dict["vids"],
|
| 882 |
-
contrast=args.ct_weight,
|
| 883 |
-
ref_ids=input_dict["ref_ids"],
|
| 884 |
-
inference=True,
|
| 885 |
-
)
|
| 886 |
-
|
| 887 |
-
gt_masks = output_dict["gt_masks"] # list[B]:[num_seg, T, H, W]
|
| 888 |
-
seg_emb_list = output_dict["seg_embeddings"] # list[B]:[num_seg, 256]
|
| 889 |
-
|
| 890 |
-
for b in range(len(input_dict["images"])):
|
| 891 |
-
image_embeds_b = input_dict["image_feats"][b]
|
| 892 |
-
resize_b = input_dict["resizes"][b]
|
| 893 |
-
orgsize_b = input_dict["orgsizes"][b]
|
| 894 |
-
F_init_b = seg_emb_list[b].detach().float()
|
| 895 |
-
|
| 896 |
-
for seg_idx in range(F_init_b.shape[0]):
|
| 897 |
-
q_init = F_init_b[seg_idx : seg_idx + 1] # [1, 256]
|
| 898 |
-
gt_seg = gt_masks[b][seg_idx : seg_idx + 1] # [1, T, H, W]
|
| 899 |
-
|
| 900 |
-
# Baseline decode (q_init, no LTPO)
|
| 901 |
-
with torch.no_grad():
|
| 902 |
-
pred_base = decode_full_video(
|
| 903 |
-
q_init, image_embeds_b, sam_model,
|
| 904 |
-
resize_b, orgsize_b, model_dtype,
|
| 905 |
-
).unsqueeze(0) # [1, T, H, W]
|
| 906 |
-
iou_base = utility.mask_iou(pred_base, gt_seg)
|
| 907 |
-
f_base = utility.Eval_Fmeasure(pred_base, gt_seg, None)
|
| 908 |
-
|
| 909 |
-
# LTPO (s1) — also computes r_avt inside q_ltpo_autograd
|
| 910 |
-
reset_q_ltpo_stats()
|
| 911 |
-
q_best = q_ltpo_autograd(
|
| 912 |
-
q_init, image_embeds_b, anchor_indices,
|
| 913 |
-
sam_model, model_dtype, ltpo_cfg,
|
| 914 |
-
)
|
| 915 |
-
stat = get_q_ltpo_stats()[0]
|
| 916 |
-
|
| 917 |
-
with torch.no_grad():
|
| 918 |
-
pred_ltpo = decode_full_video(
|
| 919 |
-
q_best, image_embeds_b, sam_model,
|
| 920 |
-
resize_b, orgsize_b, model_dtype,
|
| 921 |
-
).unsqueeze(0)
|
| 922 |
-
iou_ltpo = utility.mask_iou(pred_ltpo, gt_seg)
|
| 923 |
-
f_ltpo = utility.Eval_Fmeasure(pred_ltpo, gt_seg, None)
|
| 924 |
-
|
| 925 |
-
per_sample.append({
|
| 926 |
-
"reward_gain": stat["reward_gain"],
|
| 927 |
-
"r_avt_gain": stat.get("r_avt_gain", 0.0),
|
| 928 |
-
"r_avt_c_gain": stat.get("r_avt_c_gain", 0.0),
|
| 929 |
-
"e0": stat["e0"],
|
| 930 |
-
"accepted": stat["accepted"],
|
| 931 |
-
"delta_miou": float(iou_ltpo - iou_base),
|
| 932 |
-
"delta_f": float(f_ltpo - f_base),
|
| 933 |
-
})
|
| 934 |
-
|
| 935 |
-
_print_correlation_report(per_sample)
|
| 936 |
-
|
| 937 |
-
# ── Stage 0: gradient connectivity check ─────────────────────────────
|
| 938 |
-
# Loads one image_embed directly from disk — no dataloader, no gt_mask,
|
| 939 |
-
# no media frames required. F_init is a unit-scale random vector that
|
| 940 |
-
# mimics the distribution of Fseg (SAM prompt embeddings are in ℝ^256
|
| 941 |
-
# with per-dim std ≈ 0.05–0.3; we use std=0.1 as a neutral initialisation).
|
| 942 |
-
def run_stage0_check():
|
| 943 |
-
import glob
|
| 944 |
-
sam_model = get_sam_model(model)
|
| 945 |
-
model_dtype = torch.bfloat16
|
| 946 |
-
|
| 947 |
-
embed_files = sorted(glob.glob(os.path.join(args.data_dir, "image_embed", "*.pt")))
|
| 948 |
-
if not embed_files:
|
| 949 |
-
print("[Stage 0] ERROR: no .pt files found in data/image_embed/")
|
| 950 |
-
return False
|
| 951 |
-
|
| 952 |
-
img_embs = torch.load(embed_files[0], map_location="cuda") # [T, 256, 64, 64]
|
| 953 |
-
if img_embs.dim() == 3: # [256,64,64] → [1,256,64,64]
|
| 954 |
-
img_embs = img_embs.unsqueeze(0)
|
| 955 |
-
|
| 956 |
-
torch.manual_seed(42)
|
| 957 |
-
F_init = torch.randn(1, 256, device="cuda") * 0.1 # [1, 256] float32
|
| 958 |
-
|
| 959 |
-
anchors = get_anchor_indices(img_embs.shape[0], 4)
|
| 960 |
-
diag = check_grad_connectivity(F_init, img_embs, anchors, sam_model, model_dtype)
|
| 961 |
-
print("\n[Stage 0] Gradient connectivity check:")
|
| 962 |
-
print(f" file used : {os.path.basename(embed_files[0])}")
|
| 963 |
-
print(f" gradient_connected : {diag['gradient_connected']}")
|
| 964 |
-
print(f" grad_norm (step 0) : {diag['grad_norm_step0']:.6f}")
|
| 965 |
-
print(f" reward trajectory : {[f'{r:.4f}' for r in diag['reward_trajectory']]}")
|
| 966 |
-
return diag["gradient_connected"]
|
| 967 |
-
|
| 968 |
-
# ── Bypass equivalence test ───────────────────────────────────────────
|
| 969 |
-
# Three controlled tests to verify that fseg.unsqueeze(1) (bypass) is
|
| 970 |
-
# numerically equivalent to prompt_encoder(text_embeds=fseg.unsqueeze(1)):
|
| 971 |
-
# Test 1 — dense_emb dtype: dense_A.to(bfloat16) vs dense_emb_bf16 (exact 0?)
|
| 972 |
-
# Test 2 — matched-prec anchor decode: same decoder, same inputs, both bfloat16
|
| 973 |
-
# Test 3 — full-video (all T frames) matched-prec decode
|
| 974 |
-
# If all pass, delta_bypass_init = 0 and the +4.22% is purely from optimization.
|
| 975 |
-
def run_bypass_test():
|
| 976 |
-
from seg_ltpo import _precompute_dense_emb
|
| 977 |
-
|
| 978 |
-
sam_model = get_sam_model(model)
|
| 979 |
-
pe = sam_model.prompt_encoder
|
| 980 |
-
mask_dec = sam_model.mask_decoder
|
| 981 |
-
model_dtype = torch.bfloat16
|
| 982 |
-
|
| 983 |
-
# Get one real Fseg via a standard forward pass on the first batch
|
| 984 |
-
batch = next(iter(_dataloader))
|
| 985 |
-
input_dict = dict_to_cuda(batch)
|
| 986 |
-
with torch.cuda.amp.autocast(dtype=torch.bfloat16, enabled=True):
|
| 987 |
-
with torch.no_grad():
|
| 988 |
-
output_dict = model.forward(
|
| 989 |
-
images=input_dict["images"],
|
| 990 |
-
images_clip=input_dict["images_clip"],
|
| 991 |
-
audio_features=input_dict["audio_feats"],
|
| 992 |
-
image_features=input_dict["image_feats"],
|
| 993 |
-
input_ids=input_dict["input_ids"],
|
| 994 |
-
labels=input_dict["labels"],
|
| 995 |
-
attention_masks=input_dict["attention_masks"],
|
| 996 |
-
masks_list=input_dict["masks"],
|
| 997 |
-
resize_list=input_dict["resizes"],
|
| 998 |
-
orgsize_list=input_dict["orgsizes"],
|
| 999 |
-
conversation_list=input_dict["convs"],
|
| 1000 |
-
refs_num=input_dict["refs_num"],
|
| 1001 |
-
fids=input_dict["fids"],
|
| 1002 |
-
vids=input_dict["vids"],
|
| 1003 |
-
contrast=args.ct_weight,
|
| 1004 |
-
ref_ids=input_dict["ref_ids"],
|
| 1005 |
-
inference=True,
|
| 1006 |
-
)
|
| 1007 |
|
| 1008 |
-
|
| 1009 |
-
|
| 1010 |
-
device = fseg.device
|
| 1011 |
-
|
| 1012 |
-
anchor_indices = get_anchor_indices(image_embeds.shape[0], 4)
|
| 1013 |
-
img_anc = image_embeds[anchor_indices] # [A,256,64,64] float32
|
| 1014 |
-
dense_emb_bf16 = _precompute_dense_emb(sam_model, model_dtype, device) # [1,256,64,64] bfloat16
|
| 1015 |
-
dense_pe = pe.get_dense_pe().to(device) # float32
|
| 1016 |
-
|
| 1017 |
-
def _decode(img, sparse_emb, dense_emb):
|
| 1018 |
-
return mask_dec(
|
| 1019 |
-
image_embeddings=img,
|
| 1020 |
-
image_pe=dense_pe,
|
| 1021 |
-
sparse_prompt_embeddings=sparse_emb,
|
| 1022 |
-
dense_prompt_embeddings=dense_emb,
|
| 1023 |
-
multimask_output=False,
|
| 1024 |
-
)
|
| 1025 |
-
|
| 1026 |
-
def _check(label, tensor_a, tensor_b, exact=False):
|
| 1027 |
-
err = (tensor_a.float() - tensor_b.float()).abs().max().item()
|
| 1028 |
-
tol = 0.0 if exact else 1e-4
|
| 1029 |
-
status = "PASS" if err <= tol else "FAIL"
|
| 1030 |
-
print(f" [{status}] {label:50s} max|A-B| = {err:.2e}")
|
| 1031 |
-
return err <= tol
|
| 1032 |
-
|
| 1033 |
-
print(f"\n[Bypass Test] fseg dtype={fseg.dtype} norm={fseg.float().norm().item():.4f}")
|
| 1034 |
-
|
| 1035 |
-
with torch.no_grad():
|
| 1036 |
-
# Get prompt_encoder outputs (called outside autocast → float32)
|
| 1037 |
-
sparse_A, dense_A = pe(points=None, boxes=None, masks=None,
|
| 1038 |
-
text_embeds=fseg.unsqueeze(1))
|
| 1039 |
-
sparse_B = fseg.unsqueeze(1) # bypass sparse: identical tensor
|
| 1040 |
-
|
| 1041 |
-
# ── Test 1: dense_emb dtype artifact ──────────���─────────────────────
|
| 1042 |
-
# Hypothesis: dense_A (float32) and dense_emb_bf16 differ only because
|
| 1043 |
-
# no_mask_embed.weight is float32; casting to bfloat16 should give exact 0.
|
| 1044 |
-
print("\n [Test 1] dense_emb dtype artifact (expected: exact 0)")
|
| 1045 |
-
t1 = _check("dense_A.to(bfloat16) vs dense_emb_bf16",
|
| 1046 |
-
dense_A.to(torch.bfloat16), dense_emb_bf16, exact=True)
|
| 1047 |
-
|
| 1048 |
-
# ── Test 2: matched-precision decode on anchors ──────────────────────
|
| 1049 |
-
# Both paths use bfloat16 sparse + bfloat16 dense.
|
| 1050 |
-
# If sparse_emb is identical and dense_emb is identical (per Test 1),
|
| 1051 |
-
# masks and iou_preds must be identical (same decoder, same inputs).
|
| 1052 |
-
print("\n [Test 2] matched-precision anchor decode (expected: exact 0)")
|
| 1053 |
-
dense_A_bf16 = dense_A.to(model_dtype)
|
| 1054 |
-
masks_A, iou_A = _decode(img_anc, sparse_A, dense_A_bf16)
|
| 1055 |
-
masks_B, iou_B = _decode(img_anc, sparse_B, dense_emb_bf16)
|
| 1056 |
-
_check("sparse_emb", sparse_A, sparse_B, exact=True)
|
| 1057 |
-
t2m = _check("masks (anchors, matched prec)", masks_A, masks_B, exact=True)
|
| 1058 |
-
t2i = _check("iou_preds (anchors, matched prec)", iou_A, iou_B, exact=True)
|
| 1059 |
-
t2 = t2m and t2i
|
| 1060 |
-
|
| 1061 |
-
# ── Test 3: full-video bypass-init baseline (all T frames) ──────────
|
| 1062 |
-
# Extend Test 2 to all T frames; quantifies delta_bypass_init over
|
| 1063 |
-
# the complete video rather than just the 4 anchor frames.
|
| 1064 |
-
print(f"\n [Test 3] full-video matched-precision decode (T={image_embeds.shape[0]} frames)")
|
| 1065 |
-
masks_full_A, _ = _decode(image_embeds, sparse_A, dense_A_bf16)
|
| 1066 |
-
masks_full_B, _ = _decode(image_embeds, sparse_B, dense_emb_bf16)
|
| 1067 |
-
t3 = _check("masks (all frames, matched prec)", masks_full_A, masks_full_B, exact=True)
|
| 1068 |
-
|
| 1069 |
-
print("\n ── Verdict ──────────────────────────────────────────────────────")
|
| 1070 |
-
if t1 and t2 and t3:
|
| 1071 |
-
print(" ALL PASS — bypass is algebraically and numerically equivalent to")
|
| 1072 |
-
print(" prompt_encoder path under matched precision. delta_bypass_init = 0.")
|
| 1073 |
-
print(" The +4.22% mIoU improvement is purely from q-LTPO optimization.")
|
| 1074 |
-
else:
|
| 1075 |
-
failures = []
|
| 1076 |
-
if not t1: failures.append("Test 1 (dense dtype)")
|
| 1077 |
-
if not t2: failures.append("Test 2 (anchor decode)")
|
| 1078 |
-
if not t3: failures.append("Test 3 (full-video decode)")
|
| 1079 |
-
print(f" FAIL in: {', '.join(failures)}")
|
| 1080 |
-
print(" delta_bypass_init ≠ 0; need per-sample mIoU comparison to quantify.")
|
| 1081 |
-
|
| 1082 |
-
# ── Run evaluation ────────────────────────────────────────────────────
|
| 1083 |
-
|
| 1084 |
-
ltpo_cfg = LTPOConfig()
|
| 1085 |
-
q_ltpo_cfg_s1 = QLTPOConfig(stage=1)
|
| 1086 |
-
q_ltpo_cfg_s2 = QLTPOConfig(stage=2)
|
| 1087 |
-
q_ltpo_cfg_s21 = QLTPOConfig(stage=21) # P1a: tether probe
|
| 1088 |
-
q_ltpo_cfg_s22 = QLTPOConfig(stage=22) # P1b: faithful ext-ref
|
| 1089 |
-
|
| 1090 |
-
# ── Direction B: boundary precision probes ──────────────────────────────
|
| 1091 |
-
q_ltpo_cfg_b1_w03 = QLTPOConfig(stage=1, lambda_area_inc=0.3, area_inc_tau=0.0)
|
| 1092 |
-
q_ltpo_cfg_b1_w10 = QLTPOConfig(stage=1, lambda_area_inc=1.0, area_inc_tau=0.0)
|
| 1093 |
-
|
| 1094 |
-
# ── Direction II: Frame-adaptive token optimization ─────────────────────
|
| 1095 |
-
# fa_c03: delta clipped at 0.3×‖q_init‖ — moderate constraint.
|
| 1096 |
-
# First probe to answer: "does constrained frame-adaptive beat shared q?"
|
| 1097 |
-
# If yes → ablate tighter/looser constraints and smoothness in follow-up.
|
| 1098 |
-
q_ltpo_cfg_fa_c03 = QLTPOConfig(stage=1, lambda_residual=0.001, lambda_smooth_temp=0.0, max_delta_drift_scale=0.3)
|
| 1099 |
-
|
| 1100 |
-
max_rows = args.max_eval_rows # -1 = all rows
|
| 1101 |
-
|
| 1102 |
-
# --max_eval_rows 0 → Stage 0 + bypass equivalence check, then exit
|
| 1103 |
-
if max_rows == 0:
|
| 1104 |
-
run_stage0_check()
|
| 1105 |
-
run_bypass_test()
|
| 1106 |
-
elif _split == 'test_n':
|
| 1107 |
-
# Null safety check: baseline + Stage 1 + frame-adaptive
|
| 1108 |
-
valuate_Null(model, _dataloader, max_rows=max_rows)
|
| 1109 |
-
for cfg_name, cfg in [("s1", q_ltpo_cfg_s1)]:
|
| 1110 |
-
reset_q_ltpo_stats()
|
| 1111 |
-
valuate_ltpo_null(model, _dataloader, cfg,
|
| 1112 |
-
optimize_fn=q_ltpo_autograd, max_rows=max_rows)
|
| 1113 |
-
print_q_ltpo_stats(f"null_q_ltpo_{cfg_name}")
|
| 1114 |
-
reset_q_ltpo_stats()
|
| 1115 |
-
valuate_ltpo_adaptive(model, _dataloader, "null_fa_c03",
|
| 1116 |
-
q_ltpo_cfg_fa_c03, max_rows=max_rows)
|
| 1117 |
-
print_q_ltpo_stats("null_fa_c03")
|
| 1118 |
else:
|
| 1119 |
-
valuate(model,
|
| 1120 |
-
# Step A0: reward–metric correlation study (s1 + AVT proxy signals)
|
| 1121 |
-
valuate_ltpo_correlation_study(
|
| 1122 |
-
model, _dataloader, q_ltpo_cfg_s1, max_rows=max_rows
|
| 1123 |
-
)
|
| 1124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import transformers
|
| 2 |
|
| 3 |
from torch.cuda.amp import autocast, GradScaler
|
|
|
|
| 208 |
|
| 209 |
import torch.multiprocessing as mp
|
| 210 |
if __name__ == "__main__":
|
| 211 |
+
mp.set_start_method("spawn")
|
| 212 |
set_seed(42)
|
| 213 |
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 214 |
args.mllm,
|
|
|
|
| 224 |
print("seg_token_idx: ", seg_token_idx)
|
| 225 |
|
| 226 |
|
| 227 |
+
if args.eval_split not in {"test_s", "test_u", "test_n"}:
|
| 228 |
+
raise ValueError(f"Unsupported eval_split: {args.eval_split}")
|
| 229 |
+
|
| 230 |
+
val_dataset = REFAVS(args.eval_split, args, tokenizer, input_type='refer')
|
| 231 |
+
val_dataloader = DataLoader(
|
| 232 |
+
val_dataset,
|
| 233 |
+
batch_size=1,
|
| 234 |
+
shuffle=False,
|
| 235 |
+
num_workers=4,
|
| 236 |
+
collate_fn=partial(collate_fn, tokenizer=tokenizer),
|
| 237 |
+
)
|
| 238 |
|
| 239 |
|
| 240 |
|
|
|
|
| 340 |
model = model.to("cuda")
|
| 341 |
model.resize_token_embeddings(len(tokenizer))
|
| 342 |
|
| 343 |
+
missing, unexpected = model.load_state_dict(
|
| 344 |
+
torch.load(args.saved_model, map_location="cpu"),
|
| 345 |
+
strict=False,
|
| 346 |
+
)
|
| 347 |
+
print(f"saved model loaded: {args.saved_model}")
|
| 348 |
+
print(f"missing keys: {len(missing)} | unexpected keys: {len(unexpected)}")
|
| 349 |
|
| 350 |
|
| 351 |
save_root = args.visualization_root
|
|
|
|
| 404 |
print("visualization finished")
|
| 405 |
|
| 406 |
|
| 407 |
+
def valuate(model, dataloader, name):
|
| 408 |
model.eval()
|
| 409 |
|
| 410 |
total_iou = 0
|
| 411 |
total_fscore = 0
|
| 412 |
count = 0
|
| 413 |
|
| 414 |
+
for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"Evaluating on {name}")):
|
| 415 |
+
if args.max_eval_rows > 0 and batch_idx >= args.max_eval_rows:
|
|
|
|
| 416 |
break
|
| 417 |
input_dict = dict_to_cuda(batch)
|
| 418 |
|
|
|
|
| 447 |
total_fscore += fscore * num_seg * T
|
| 448 |
count += num_seg * T
|
| 449 |
|
| 450 |
+
if count == 0:
|
| 451 |
+
raise RuntimeError(f"No samples were evaluated for {name}")
|
| 452 |
+
|
| 453 |
print(f"\n valuate on {name}: miou: {total_iou/count} fscore: {total_fscore/count}")
|
| 454 |
|
| 455 |
|
| 456 |
+
def valuate_Null(model, dataloader):
|
| 457 |
model.eval()
|
| 458 |
|
| 459 |
total_metric = 0
|
| 460 |
count = 0
|
| 461 |
|
| 462 |
+
for batch_idx, batch in enumerate(tqdm(dataloader, desc=f"Evaluating on Null")):
|
| 463 |
+
if args.max_eval_rows > 0 and batch_idx >= args.max_eval_rows:
|
|
|
|
| 464 |
break
|
| 465 |
input_dict = dict_to_cuda(batch)
|
| 466 |
+
with torch.no_grad():
|
| 467 |
+
output_dict = model.forward(images=input_dict["images"],
|
| 468 |
+
images_clip=input_dict["images_clip"],
|
| 469 |
+
audio_features=input_dict["audio_feats"],
|
| 470 |
+
image_features=input_dict["image_feats"],
|
| 471 |
+
input_ids=input_dict["input_ids"],
|
| 472 |
+
labels=input_dict["labels"],
|
| 473 |
+
attention_masks=input_dict["attention_masks"],
|
| 474 |
+
masks_list=input_dict["masks"],
|
| 475 |
+
resize_list=input_dict["resizes"],
|
| 476 |
+
orgsize_list=input_dict["orgsizes"],
|
| 477 |
+
conversation_list=input_dict["convs"],
|
| 478 |
+
refs_num=input_dict["refs_num"],
|
| 479 |
+
fids=input_dict["fids"],
|
| 480 |
+
vids=input_dict["vids"],
|
| 481 |
+
contrast=args.ct_weight,
|
| 482 |
+
ref_ids=input_dict["ref_ids"],
|
| 483 |
+
inference=True)
|
|
|
|
| 484 |
pred_masks = output_dict["pred_masks"] # list[B]:[num_seg, T, H, W]
|
| 485 |
gt_masks = output_dict["gt_masks"] # list[B]:[num_seg, T, H, W]
|
| 486 |
for i in range(len(pred_masks)):
|
|
|
|
| 491 |
total_metric += null_metric * num_seg * T
|
| 492 |
count += num_seg * T
|
| 493 |
|
| 494 |
+
if count == 0:
|
| 495 |
+
raise RuntimeError("No samples were evaluated for test_n")
|
|
|
|
| 496 |
|
| 497 |
+
print(f"\n valuate on test_n_refer, metric: {total_metric / count}")
|
| 498 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 499 |
|
| 500 |
+
if args.eval_split == "test_n":
|
| 501 |
+
valuate_Null(model, val_dataloader)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 502 |
else:
|
| 503 |
+
valuate(model, val_dataloader, args.eval_split)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
save_audio_feats.py
CHANGED
|
@@ -80,4 +80,3 @@ for vid in vids:
|
|
| 80 |
# print(f"{vid}: {audio_embed.shape}")
|
| 81 |
torch.save(audio_embed, f'{save_dir}/{vid}.pt')
|
| 82 |
print(f'{vid} embedding saved {audio_embed.shape}')
|
| 83 |
-
|
|
|
|
| 80 |
# print(f"{vid}: {audio_embed.shape}")
|
| 81 |
torch.save(audio_embed, f'{save_dir}/{vid}.pt')
|
| 82 |
print(f'{vid} embedding saved {audio_embed.shape}')
|
|
|
setup_simtoken.md
CHANGED
|
@@ -1,12 +1,22 @@
|
|
| 1 |
# SimToken Setup
|
| 2 |
|
|
|
|
|
|
|
| 3 |
---
|
| 4 |
|
| 5 |
## 1. Create Environment
|
| 6 |
|
|
|
|
|
|
|
| 7 |
```bash
|
| 8 |
-
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
python -m pip install --upgrade pip wheel "setuptools<81"
|
| 12 |
|
|
@@ -34,17 +44,63 @@ pip install \
|
|
| 34 |
huggingface_hub
|
| 35 |
```
|
| 36 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 37 |
---
|
| 38 |
|
| 39 |
-
## 2.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 42 |
|
| 43 |
```bash
|
| 44 |
huggingface-cli login
|
| 45 |
```
|
| 46 |
|
| 47 |
-
下载完整 repo
|
| 48 |
|
| 49 |
```bash
|
| 50 |
mkdir -p /workspace/SimToken
|
|
@@ -56,97 +112,108 @@ huggingface-cli download yfan07/SimToken \
|
|
| 56 |
--local-dir-use-symlinks False
|
| 57 |
```
|
| 58 |
|
| 59 |
-
下载完成后解压数据
|
| 60 |
|
| 61 |
```bash
|
| 62 |
cd /workspace/SimToken/data
|
| 63 |
|
| 64 |
-
tar -xf image_embed.tar
|
| 65 |
tar -xzf gt_mask.tar.gz
|
| 66 |
tar -xzf audio_embed.tar.gz
|
| 67 |
tar -xf media.tar
|
| 68 |
```
|
| 69 |
|
| 70 |
-
|
| 71 |
---
|
| 72 |
|
| 73 |
-
##
|
| 74 |
|
| 75 |
-
`transformers==4.30.2` 与新版 `huggingface_hub` 存在
|
| 76 |
-
解决方案:先用 CLI 将模型下载到本地缓存,之后运行实验时加 `TRANSFORMERS_OFFLINE=1`,跳过所有网络请求。
|
| 77 |
|
| 78 |
```bash
|
| 79 |
-
# Chat-UniVi-7B
|
| 80 |
huggingface-cli download Chat-UniVi/Chat-UniVi-7B-v1.5
|
| 81 |
|
| 82 |
-
# CLIP ViT-L
|
| 83 |
huggingface-cli download openai/clip-vit-large-patch14
|
| 84 |
```
|
| 85 |
|
| 86 |
-
下载完成后
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
|
| 88 |
---
|
| 89 |
|
| 90 |
-
##
|
| 91 |
|
| 92 |
-
|
| 93 |
|
| 94 |
```bash
|
| 95 |
cd /workspace/SimToken
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
-
|
| 101 |
-
TRANSFORMERS_OFFLINE=1 python -W ignore load_model.py --eval_split test_s
|
| 102 |
|
| 103 |
-
|
| 104 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
-
|
| 107 |
-
TRANSFORMERS_OFFLINE=1 python -W ignore load_model.py --eval_split test_u --max_eval_rows 50
|
| 108 |
|
| 109 |
-
|
| 110 |
-
|
|
|
|
| 111 |
```
|
| 112 |
|
| 113 |
-
每次评估依次输出:Baseline + q-LTPO Stage 1 两组结果及诊断统计。
|
| 114 |
-
|
| 115 |
---
|
| 116 |
|
| 117 |
-
##
|
| 118 |
-
|
| 119 |
-
数据目录以压缩包形式存储,可大幅减少文件数量,避免 HuggingFace commit 频率限制。
|
| 120 |
|
| 121 |
-
|
| 122 |
|
| 123 |
```bash
|
| 124 |
cd /workspace/SimToken/data
|
| 125 |
|
| 126 |
-
tar -cf image_embed.tar image_embed/
|
| 127 |
tar -czf gt_mask.tar.gz gt_mask/
|
| 128 |
tar -czf audio_embed.tar.gz audio_embed/
|
| 129 |
tar -cf media.tar media/
|
| 130 |
|
| 131 |
-
# 确认压缩包存在后删除原始目录
|
| 132 |
ls -lh *.tar*
|
| 133 |
rm -rf image_embed/ gt_mask/ audio_embed/ media/
|
| 134 |
```
|
| 135 |
|
| 136 |
-
|
| 137 |
|
| 138 |
```bash
|
| 139 |
-
|
| 140 |
-
find /workspace/SimToken -name "*.pyc" -delete
|
| 141 |
|
| 142 |
-
|
|
|
|
|
|
|
|
|
|
| 143 |
|
| 144 |
-
|
| 145 |
python upload_hf.py --repo yfan07/SimToken
|
| 146 |
```
|
| 147 |
-
|
| 148 |
-
**注意事项:**
|
| 149 |
-
- 建议在 `tmux` 里运行,防止 SSH 断开:`tmux new -s upload`,完成后 `Ctrl+B D` detach
|
| 150 |
-
- 支持断点续传:中断后重新执行同一命令会自动跳过已上传文件
|
| 151 |
-
- 遇到 rate limit(HTTP 429)时脚本会自动等待约 1 小时后重试
|
| 152 |
-
- 监控进度:`tail -f /workspace/SimToken/upload.log`
|
|
|
|
| 1 |
# SimToken Setup
|
| 2 |
|
| 3 |
+
本文档用于在新机器上重建 SimToken 环境,并准备后续 A-min 实验。
|
| 4 |
+
|
| 5 |
---
|
| 6 |
|
| 7 |
## 1. Create Environment
|
| 8 |
|
| 9 |
+
先确认 GPU 和 CUDA driver 状态:
|
| 10 |
+
|
| 11 |
```bash
|
| 12 |
+
nvidia-smi
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
创建 conda 环境:
|
| 16 |
+
|
| 17 |
+
```bash
|
| 18 |
+
/opt/miniforge3/condabin/conda create -n simtoken python=3.10 -y
|
| 19 |
+
/opt/miniforge3/condabin/conda activate simtoken
|
| 20 |
|
| 21 |
python -m pip install --upgrade pip wheel "setuptools<81"
|
| 22 |
|
|
|
|
| 44 |
huggingface_hub
|
| 45 |
```
|
| 46 |
|
| 47 |
+
快速验证:
|
| 48 |
+
|
| 49 |
+
```bash
|
| 50 |
+
python - <<'PY'
|
| 51 |
+
import torch
|
| 52 |
+
print("torch:", torch.__version__)
|
| 53 |
+
print("cuda available:", torch.cuda.is_available())
|
| 54 |
+
print("device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "cpu")
|
| 55 |
+
PY
|
| 56 |
+
```
|
| 57 |
+
|
| 58 |
---
|
| 59 |
|
| 60 |
+
## 2. Check Workspace After Migration
|
| 61 |
+
|
| 62 |
+
使用服务器平台的迁移工具完成目录迁移后,在新机器上确认关键文件:
|
| 63 |
+
|
| 64 |
+
```bash
|
| 65 |
+
cd /workspace/SimToken
|
| 66 |
+
|
| 67 |
+
ls -lh checkpoints/simtoken_pretrained.pth
|
| 68 |
+
ls -lh models/segment_anything/sam_vit_h_4b8939.pth
|
| 69 |
+
ls -d data/image_embed data/gt_mask data/audio_embed data/media
|
| 70 |
+
```
|
| 71 |
+
|
| 72 |
+
如果迁移后只有压缩包而没有解压目录,重新解压:
|
| 73 |
+
|
| 74 |
+
```bash
|
| 75 |
+
cd /workspace/SimToken/data
|
| 76 |
+
|
| 77 |
+
tar -xf image_embed.tar
|
| 78 |
+
tar -xzf gt_mask.tar.gz
|
| 79 |
+
tar -xzf audio_embed.tar.gz
|
| 80 |
+
tar -xf media.tar
|
| 81 |
+
```
|
| 82 |
|
| 83 |
+
清理迁移中不需要的缓存:
|
| 84 |
+
|
| 85 |
+
```bash
|
| 86 |
+
cd /workspace/SimToken
|
| 87 |
+
find . -name "__pycache__" -prune -exec rm -rf {} +
|
| 88 |
+
find . -name ".pytest_cache" -prune -exec rm -rf {} +
|
| 89 |
+
find . -name ".cache" -prune -exec rm -rf {} +
|
| 90 |
+
find . -name "*.pyc" -delete
|
| 91 |
+
```
|
| 92 |
+
|
| 93 |
+
---
|
| 94 |
+
|
| 95 |
+
## 3. Download from HuggingFace
|
| 96 |
+
|
| 97 |
+
如果新机器不使用迁移工具,而是从 HuggingFace 重新初始化,先登录:
|
| 98 |
|
| 99 |
```bash
|
| 100 |
huggingface-cli login
|
| 101 |
```
|
| 102 |
|
| 103 |
+
下载完整 repo:
|
| 104 |
|
| 105 |
```bash
|
| 106 |
mkdir -p /workspace/SimToken
|
|
|
|
| 112 |
--local-dir-use-symlinks False
|
| 113 |
```
|
| 114 |
|
| 115 |
+
下载完成后解压数据:
|
| 116 |
|
| 117 |
```bash
|
| 118 |
cd /workspace/SimToken/data
|
| 119 |
|
| 120 |
+
tar -xf image_embed.tar
|
| 121 |
tar -xzf gt_mask.tar.gz
|
| 122 |
tar -xzf audio_embed.tar.gz
|
| 123 |
tar -xf media.tar
|
| 124 |
```
|
| 125 |
|
|
|
|
| 126 |
---
|
| 127 |
|
| 128 |
+
## 4. Pre-download Model Weights
|
| 129 |
|
| 130 |
+
`transformers==4.30.2` 与新版 `huggingface_hub` 可能存在网络/API 兼容问题。建议先用 CLI 将模型下载到本地缓存,实验时再加 `TRANSFORMERS_OFFLINE=1`。
|
|
|
|
| 131 |
|
| 132 |
```bash
|
| 133 |
+
# Chat-UniVi-7B
|
| 134 |
huggingface-cli download Chat-UniVi/Chat-UniVi-7B-v1.5
|
| 135 |
|
| 136 |
+
# CLIP ViT-L
|
| 137 |
huggingface-cli download openai/clip-vit-large-patch14
|
| 138 |
```
|
| 139 |
|
| 140 |
+
下载完成后做离线验证:
|
| 141 |
+
|
| 142 |
+
```bash
|
| 143 |
+
cd /workspace/SimToken
|
| 144 |
+
|
| 145 |
+
TRANSFORMERS_OFFLINE=1 /opt/miniforge3/condabin/conda run -n simtoken \
|
| 146 |
+
python -m py_compile train.py load_model.py decoder_invariance_check.py
|
| 147 |
+
```
|
| 148 |
|
| 149 |
---
|
| 150 |
|
| 151 |
+
## 5. Smoke Test
|
| 152 |
|
| 153 |
+
先跑一个轻量 sanity check,确认 checkpoint、数据和离线模型缓存都能正常读取:
|
| 154 |
|
| 155 |
```bash
|
| 156 |
cd /workspace/SimToken
|
| 157 |
|
| 158 |
+
TRANSFORMERS_OFFLINE=1 /opt/miniforge3/condabin/conda run -n simtoken \
|
| 159 |
+
python decoder_invariance_check.py \
|
| 160 |
+
--eval_split test_s \
|
| 161 |
+
--max_eval_rows 1
|
| 162 |
+
```
|
| 163 |
|
| 164 |
+
如果可以正常加载模型并输出 per-frame diff,就可以启动完整 A-min 训练:
|
|
|
|
| 165 |
|
| 166 |
+
```bash
|
| 167 |
+
cd /workspace/SimToken
|
| 168 |
+
mkdir -p log checkpoints
|
| 169 |
+
|
| 170 |
+
TRANSFORMERS_OFFLINE=1 /opt/miniforge3/condabin/conda run -n simtoken \
|
| 171 |
+
python -W ignore train.py \
|
| 172 |
+
--name amin_full_e1 \
|
| 173 |
+
--init_from_saved_model \
|
| 174 |
+
--epochs 1 \
|
| 175 |
+
--batch_size 2 \
|
| 176 |
+
--lr 1e-4 \
|
| 177 |
+
--saved_model /workspace/SimToken/checkpoints/simtoken_pretrained.pth \
|
| 178 |
+
--log_root /workspace/SimToken/log \
|
| 179 |
+
--checkpoint_root /workspace/SimToken/checkpoints
|
| 180 |
+
```
|
| 181 |
|
| 182 |
+
启动日志中应出现:
|
|
|
|
| 183 |
|
| 184 |
+
```text
|
| 185 |
+
initialized training from saved model: /workspace/SimToken/checkpoints/simtoken_pretrained.pth
|
| 186 |
+
missing keys: ... | unexpected keys: ...
|
| 187 |
```
|
| 188 |
|
|
|
|
|
|
|
| 189 |
---
|
| 190 |
|
| 191 |
+
## 6. Upload to HuggingFace
|
|
|
|
|
|
|
| 192 |
|
| 193 |
+
实验结束后,如需重新上传到 HuggingFace,先将数据目录压缩为归档文件,减少文件数量:
|
| 194 |
|
| 195 |
```bash
|
| 196 |
cd /workspace/SimToken/data
|
| 197 |
|
| 198 |
+
tar -cf image_embed.tar image_embed/
|
| 199 |
tar -czf gt_mask.tar.gz gt_mask/
|
| 200 |
tar -czf audio_embed.tar.gz audio_embed/
|
| 201 |
tar -cf media.tar media/
|
| 202 |
|
|
|
|
| 203 |
ls -lh *.tar*
|
| 204 |
rm -rf image_embed/ gt_mask/ audio_embed/ media/
|
| 205 |
```
|
| 206 |
|
| 207 |
+
清理缓存并上传:
|
| 208 |
|
| 209 |
```bash
|
| 210 |
+
cd /workspace/SimToken
|
|
|
|
| 211 |
|
| 212 |
+
find . -name "__pycache__" -prune -exec rm -rf {} +
|
| 213 |
+
find . -name ".pytest_cache" -prune -exec rm -rf {} +
|
| 214 |
+
find . -name ".cache" -prune -exec rm -rf {} +
|
| 215 |
+
find . -name "*.pyc" -delete
|
| 216 |
|
| 217 |
+
huggingface-cli login
|
| 218 |
python upload_hf.py --repo yfan07/SimToken
|
| 219 |
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
simtoken_experiment.md
ADDED
|
@@ -0,0 +1,369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SimToken 实验路线文档
|
| 2 |
+
|
| 3 |
+
## 0. 当前状态
|
| 4 |
+
|
| 5 |
+
前置诊断已经完成,路线收敛到 **A-min dynamic referent gate training**。
|
| 6 |
+
|
| 7 |
+
已确认结论:
|
| 8 |
+
|
| 9 |
+
1. **SAM decoder 下游是逐帧 batch-parallel 解码**
|
| 10 |
+
`mask_decoder(image_embeddings[0:T])[t]` 与 `mask_decoder(image_embeddings[t:t+1])[0]` 只有混合精度数值噪声差异。旧的 decoder-level joint-frame competition 假设关闭。
|
| 11 |
+
|
| 12 |
+
2. **target_frame sweep 基本无效**
|
| 13 |
+
不同 target frame 生成的 q 几乎相同,`cos_to_q5` 通常在 `0.997+`;Seen/Null 上 oracle gain 约 `+0.0009`。这条 TTO 路线关闭。
|
| 14 |
+
|
| 15 |
+
3. **raw SAM-space D2 失效**
|
| 16 |
+
256 维 `q/Fseg` 与 SAM image embedding 不在可直接 cosine 的语义空间,`real q ≈ shuffled/wrong_ref q`,甚至 random q 更高。该定义关闭。
|
| 17 |
+
|
| 18 |
+
4. **LLM-space D2 有弱诊断信号,但不适合作为主 reward**
|
| 19 |
+
用 4096 维 `[SEG]` hidden state 与 `mm_projector(CLIP patch tokens)` 后的视觉 token 计算 D2,可以得到正相关:
|
| 20 |
+
- `corr(s_pred, frame_iou) ≈ +0.316`
|
| 21 |
+
- bottom 20% `s_pred` 中 failure rate 相比随机 baseline 约 `1.60x`
|
| 22 |
+
- 控制 `iou_pred` / `pred_area` 后偏相关约 `+0.14`
|
| 23 |
+
|
| 24 |
+
结论:`s_pred(beta=1.0)` 可以作为诊断信号或 frame-aware gate 的候选输入,但不能作为核心 TTO reward。
|
| 25 |
+
|
| 26 |
+
5. **margin-D2 无效**
|
| 27 |
+
离线 `s_margin = s(real) - max(s(shuffled), s(wrong_ref))` 的 failure enrichment 约 `0.93x`,会抵消掉有用的通用可见性/质量信号。该路线关闭。
|
| 28 |
+
|
| 29 |
+
当前最干净的解释是:
|
| 30 |
+
|
| 31 |
+
> q 本身通常是稳定的 referent anchor;主要瓶颈不在 q 生成,也不在简单 q selection,而在 SAM decoder 如何使用已有的 `mask_token -> q` sparse self-attention path。
|
| 32 |
+
|
| 33 |
+
2026-04-22 更新:
|
| 34 |
+
|
| 35 |
+
完整训练每个 epoch 约 2-4 小时,瓶颈主要在 7B MLLM forward,而不在 gate 本身。因此当前实验策略已调整为:
|
| 36 |
+
|
| 37 |
+
1. 先缓存固定 checkpoint 下的 `q = seg_embeddings`;
|
| 38 |
+
2. 在 cached q + cached SAM image embeddings 上训练 gate-only;
|
| 39 |
+
3. 用 cached eval split 快速判断 gate 是否有泛化收益;
|
| 40 |
+
4. 只有 gate-only 泛化信号成立后,再跑完整 A-min 联合训练。
|
| 41 |
+
|
| 42 |
+
---
|
| 43 |
+
|
| 44 |
+
## 1. A-min 当前实现
|
| 45 |
+
|
| 46 |
+
已在代码中加入 A-min dynamic referent gate:
|
| 47 |
+
|
| 48 |
+
- 文件:`models/segment_anything/modeling/transformer.py`
|
| 49 |
+
- 模块:`ReferentGate`
|
| 50 |
+
- 插入位置:`TwoWayAttentionBlock` 的 sparse self-attention + `norm1` 之后,token-to-image cross-attention 之前
|
| 51 |
+
- 作用对象:只作用于 `mask_tokens`
|
| 52 |
+
- 不作用于:`iou_token` 和 `q/sparse_prompt` 本身
|
| 53 |
+
|
| 54 |
+
SAM token index:
|
| 55 |
+
|
| 56 |
+
```python
|
| 57 |
+
tokens = [iou_token, mask_tokens..., sparse_prompt(q)]
|
| 58 |
+
```
|
| 59 |
+
|
| 60 |
+
因此:
|
| 61 |
+
|
| 62 |
+
```python
|
| 63 |
+
iou_token index: 0
|
| 64 |
+
mask token range: 1 : 1 + num_mask_tokens
|
| 65 |
+
q token index: 1 + num_mask_tokens
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
A-min gate 形式:
|
| 69 |
+
|
| 70 |
+
```python
|
| 71 |
+
alpha = sigmoid(Linear([mask_token, q, cos(mask_token, q)]))
|
| 72 |
+
mask_token = mask_token + alpha * Linear(q)
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
为保证旧 checkpoint 初始行为不变,`proj(q)` 分支使用零初始化。当前也将 `gate` 分支零初始化,使 alpha 有干净观测基线:
|
| 76 |
+
|
| 77 |
+
```python
|
| 78 |
+
nn.init.zeros_(self.gate.weight)
|
| 79 |
+
nn.init.zeros_(self.gate.bias)
|
| 80 |
+
nn.init.zeros_(self.proj.weight)
|
| 81 |
+
nn.init.zeros_(self.proj.bias)
|
| 82 |
+
```
|
| 83 |
+
|
| 84 |
+
初始时 gate 为 identity:
|
| 85 |
+
|
| 86 |
+
```text
|
| 87 |
+
max_abs_diff(gate(mask, q), mask) = 0.0
|
| 88 |
+
alpha_mean = 0.5
|
| 89 |
+
alpha_std = 0.0
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
当前训练 forward 保持完整链路:`prepare_inputs_labels_for_multimodal -> MLLM forward -> text_hidden_fcs -> SAM mask decoder -> loss`。`--gate_only` 只控制参数冻结范围,不再改变 forward 语义。
|
| 93 |
+
|
| 94 |
+
---
|
| 95 |
+
|
| 96 |
+
## 2. 当前新增工具
|
| 97 |
+
|
| 98 |
+
### 2.1 训练脚本增强
|
| 99 |
+
|
| 100 |
+
`train.py` 已加入:
|
| 101 |
+
|
| 102 |
+
- `--max_steps`
|
| 103 |
+
- `--overfit_samples`
|
| 104 |
+
- `--log_gate_stats_every`
|
| 105 |
+
- `--skip_eval_after_train`
|
| 106 |
+
- `--eval_train_only`
|
| 107 |
+
|
| 108 |
+
启动时会打印 referent gate 参数是否 trainable、是否进入 optimizer,以及初始 `proj_norm/gate_norm`。
|
| 109 |
+
|
| 110 |
+
### 2.2 cached q 路线
|
| 111 |
+
|
| 112 |
+
新增脚本:
|
| 113 |
+
|
| 114 |
+
- `cache_q_features.py`
|
| 115 |
+
- 离线缓存 `q = seg_embeddings`
|
| 116 |
+
- cache 文件很小,因为只保存 q 和少量 metadata
|
| 117 |
+
- `image_embeddings` 仍使用已有 `data/image_embed/{vid}.pt`
|
| 118 |
+
- `gt_masks` 仍使用已有 `data/gt_mask/...`
|
| 119 |
+
|
| 120 |
+
- `train_cached_gate.py`
|
| 121 |
+
- 加载 base model 和 cached q
|
| 122 |
+
- 冻结全部参数,只训练 `referent_gate`
|
| 123 |
+
- 支持 `--eval_only`、`--disable_gate`
|
| 124 |
+
- 支持 `--save_gate_only`,只保存 gate 参数,checkpoint 约 1.6MB
|
| 125 |
+
- 支持 `--gate_checkpoint`,在 base checkpoint 上 overlay gate-only checkpoint
|
| 126 |
+
- gate stats 会记录:
|
| 127 |
+
|
| 128 |
+
```text
|
| 129 |
+
batch_miou
|
| 130 |
+
batch_fscore
|
| 131 |
+
proj_norm
|
| 132 |
+
gate_norm
|
| 133 |
+
proj_grad_norm
|
| 134 |
+
gate_grad_norm
|
| 135 |
+
alpha_mean / alpha_std / alpha_min / alpha_max
|
| 136 |
+
```
|
| 137 |
+
|
| 138 |
+
cached 解码已优化:一个 dataloader batch 会展平成 paired frame batch 调用 `mask_decoder.forward_modified_v3`,避免逐 sample 调 decoder 的主要开销,同时不会产生 prompt/image cross product。
|
| 139 |
+
|
| 140 |
+
---
|
| 141 |
+
|
| 142 |
+
## 3. 已完成实验结果
|
| 143 |
+
|
| 144 |
+
### 3.1 cached identity 与原始 pipeline 一致性
|
| 145 |
+
|
| 146 |
+
先用 `test_s` 前 10 条验证 cached pipeline 是否与原始 `load_model.py` 对齐:
|
| 147 |
+
|
| 148 |
+
```text
|
| 149 |
+
cached identity:
|
| 150 |
+
mIoU = 0.9686462879
|
| 151 |
+
Fscore = 0.9868578851
|
| 152 |
+
|
| 153 |
+
original load_model.py:
|
| 154 |
+
mIoU = 0.9686277151
|
| 155 |
+
Fscore = 0.9868472159
|
| 156 |
+
|
| 157 |
+
diff:
|
| 158 |
+
mIoU = +0.0000186
|
| 159 |
+
Fscore = +0.0000107
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
结论:差异远小于 0.001,cached q pipeline 与原始 eval pipeline 一致,可以用于 gate-only 快速验证。
|
| 163 |
+
|
| 164 |
+
### 3.2 gate probe:梯度路径与 alpha 分化
|
| 165 |
+
|
| 166 |
+
在 cached train128 上跑 50 optimizer steps:
|
| 167 |
+
|
| 168 |
+
```text
|
| 169 |
+
step 5:
|
| 170 |
+
proj_norm=0.074015
|
| 171 |
+
gate_norm=0.064479
|
| 172 |
+
proj_grad_norm=0.052291
|
| 173 |
+
gate_grad_norm=0.000170
|
| 174 |
+
alpha_mean=0.4999
|
| 175 |
+
alpha_std=0.0019
|
| 176 |
+
|
| 177 |
+
step 50:
|
| 178 |
+
proj_norm=0.428711
|
| 179 |
+
gate_norm=0.523223
|
| 180 |
+
proj_grad_norm=0.022453
|
| 181 |
+
gate_grad_norm=0.000504
|
| 182 |
+
alpha_mean=0.5063
|
| 183 |
+
alpha_std=0.0112
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
结论:
|
| 187 |
+
|
| 188 |
+
- `proj_norm` 从 0 稳定增长,注入分支有梯度;
|
| 189 |
+
- `gate_norm` 也开始增长,alpha 控制分支参与学习;
|
| 190 |
+
- `alpha_std` 从 0 增长,说明 gate 对不同输入有分化响应;
|
| 191 |
+
- 计算图、冻结范围、optimizer param groups 均正常。
|
| 192 |
+
|
| 193 |
+
### 3.3 overfit32:表达能力验证
|
| 194 |
+
|
| 195 |
+
cached train32 identity baseline:
|
| 196 |
+
|
| 197 |
+
```text
|
| 198 |
+
mIoU = 0.8814558
|
| 199 |
+
Fscore = 0.9375512
|
| 200 |
+
```
|
| 201 |
+
|
| 202 |
+
cached gate overfit32,200 steps,lr=1e-4:
|
| 203 |
+
|
| 204 |
+
```text
|
| 205 |
+
mIoU = 0.9085821
|
| 206 |
+
Fscore = 0.9444574
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
提升:
|
| 210 |
+
|
| 211 |
+
```text
|
| 212 |
+
mIoU = +0.0271263
|
| 213 |
+
Fscore = +0.0069063
|
| 214 |
+
```
|
| 215 |
+
|
| 216 |
+
结论:在 q、SAM image embeddings、mask decoder 原始参数均固定时,仅训练 A-min gate 就能明显提高训练集 mIoU,说明 gate 机制有表达能力,梯度路径通畅。
|
| 217 |
+
|
| 218 |
+
### 3.4 overfit32 泛化评估
|
| 219 |
+
|
| 220 |
+
对 cached eval split 前 200 条,identity baseline:
|
| 221 |
+
|
| 222 |
+
```text
|
| 223 |
+
test_s mIoU = 0.7390979
|
| 224 |
+
test_s Fscore = 0.8190672
|
| 225 |
+
|
| 226 |
+
test_u mIoU = 0.6732285
|
| 227 |
+
test_u Fscore = 0.7734924
|
| 228 |
+
|
| 229 |
+
test_n metric = 0.0606105
|
| 230 |
+
```
|
| 231 |
+
|
| 232 |
+
overfit32 gate checkpoint:
|
| 233 |
+
|
| 234 |
+
```text
|
| 235 |
+
test_s mIoU = 0.7199481
|
| 236 |
+
test_s Fscore = 0.8045849
|
| 237 |
+
|
| 238 |
+
test_u mIoU = 0.6672303
|
| 239 |
+
test_u Fscore = 0.7663978
|
| 240 |
+
|
| 241 |
+
test_n metric = 0.0648588
|
| 242 |
+
```
|
| 243 |
+
|
| 244 |
+
delta:
|
| 245 |
+
|
| 246 |
+
```text
|
| 247 |
+
test_s mIoU = -0.0191498
|
| 248 |
+
test_s Fscore = -0.0144823
|
| 249 |
+
|
| 250 |
+
test_u mIoU = -0.0059983
|
| 251 |
+
test_u Fscore = -0.0070946
|
| 252 |
+
|
| 253 |
+
test_n metric = +0.0042483
|
| 254 |
+
```
|
| 255 |
+
|
| 256 |
+
结论:
|
| 257 |
+
|
| 258 |
+
- overfit32 gate 没有泛化;
|
| 259 |
+
- Null metric 略升,说明小样本过拟合有轻微放大前景的倾向;
|
| 260 |
+
- 这不是方法失败,而是 32 个样本不足以学到泛化 referent anchoring 的预期结果;
|
| 261 |
+
- 下一步应扩大 cached train 样本量,并降低 lr。
|
| 262 |
+
|
| 263 |
+
---
|
| 264 |
+
|
| 265 |
+
## 4. 当前下一步实验:cached train256 gate-only
|
| 266 |
+
|
| 267 |
+
用户已经完成 train256 的 q 缓存。下一步用 train256 跑更保守的 gate-only 泛化实验。
|
| 268 |
+
|
| 269 |
+
### Step 1:训练 cached gate-only train256
|
| 270 |
+
|
| 271 |
+
```bash
|
| 272 |
+
cd /workspace/SimToken
|
| 273 |
+
mkdir -p log checkpoints
|
| 274 |
+
|
| 275 |
+
TRANSFORMERS_OFFLINE=1 python -u -W ignore train_cached_gate.py \
|
| 276 |
+
--cache_split train \
|
| 277 |
+
--cache_root /workspace/SimToken/cache_q \
|
| 278 |
+
--name cached_gate_train256_s300_lr3e5 \
|
| 279 |
+
--epochs 20 \
|
| 280 |
+
--max_steps 300 \
|
| 281 |
+
--batch_size 8 \
|
| 282 |
+
--lr 3e-5 \
|
| 283 |
+
--saved_model /workspace/SimToken/checkpoints/simtoken_pretrained.pth \
|
| 284 |
+
--log_root /workspace/SimToken/log \
|
| 285 |
+
--checkpoint_root /workspace/SimToken/checkpoints \
|
| 286 |
+
--log_gate_stats_every 50 \
|
| 287 |
+
--skip_eval_after_train \
|
| 288 |
+
--save_gate_only \
|
| 289 |
+
2>&1 | tee /workspace/SimToken/log/cached_gate_train256_s300_lr3e5.stdout
|
| 290 |
+
```
|
| 291 |
+
|
| 292 |
+
训练中重点观察:
|
| 293 |
+
|
| 294 |
+
```text
|
| 295 |
+
batch_miou / batch_fscore 是否逐步改善
|
| 296 |
+
proj_norm 是否持续增长
|
| 297 |
+
alpha_std 是否温和分化
|
| 298 |
+
Null 风险:alpha 是否出现极端偏移
|
| 299 |
+
```
|
| 300 |
+
|
| 301 |
+
如果 `proj_norm` 在前 100 steps 仍接近 0,说明 lr=3e-5 可能过小,可以改回 1e-4 或使用分层 lr。
|
| 302 |
+
|
| 303 |
+
### Step 2:评估 cached train256 gate checkpoint
|
| 304 |
+
|
| 305 |
+
```bash
|
| 306 |
+
for split in test_s test_u test_n; do
|
| 307 |
+
TRANSFORMERS_OFFLINE=1 python -u -W ignore train_cached_gate.py \
|
| 308 |
+
--cache_split $split \
|
| 309 |
+
--cache_root /workspace/SimToken/cache_q \
|
| 310 |
+
--batch_size 8 \
|
| 311 |
+
--saved_model /workspace/SimToken/checkpoints/simtoken_pretrained.pth \
|
| 312 |
+
--gate_checkpoint /workspace/SimToken/checkpoints/cached_gate_train256_s300_lr3e5.pth \
|
| 313 |
+
--eval_only \
|
| 314 |
+
--name cached_gate_train256_s300_lr3e5_${split}_200 \
|
| 315 |
+
2>&1 | tee /workspace/SimToken/log/cached_gate_train256_s300_lr3e5_${split}_200.stdout
|
| 316 |
+
done
|
| 317 |
+
```
|
| 318 |
+
|
| 319 |
+
对比 baseline 使用 3.4 中 identity 200 条结果。
|
| 320 |
+
|
| 321 |
+
### Step 3:根据结果决策
|
| 322 |
+
|
| 323 |
+
判断标准:
|
| 324 |
+
|
| 325 |
+
- Seen / Unseen 都提升:进入更大 cached train 或完整 A-min;
|
| 326 |
+
- Seen 提升、Unseen 不提升:gate 仍可能学 dataset pattern,需要更多 train cache 或更强正则;
|
| 327 |
+
- Seen / Unseen 都下降:不要跑完整 A-min,先调 lr、正则或 gate 容量;
|
| 328 |
+
- Null metric 保持 `< 0.07`:暂不加 area penalty;
|
| 329 |
+
- Null metric 超过 `0.10`:强危险信号,需要 area penalty 或约束预测面积。
|
| 330 |
+
|
| 331 |
+
如果 train256 有弱正收益但幅度小,先看 alpha 分布和 hard/easy frames,而不是立刻扩大。若 alpha 在所有帧上几乎一致,可能只是全局偏置;若 hard frames alpha 系统性更高,说明更像 referent anchoring。
|
| 332 |
+
|
| 333 |
+
---
|
| 334 |
+
|
| 335 |
+
## 5. 成功标准
|
| 336 |
+
|
| 337 |
+
A-min 成功不能只看总体 mIoU,需要同时满足:
|
| 338 |
+
|
| 339 |
+
1. Seen / Unseen mIoU 稳定提升;
|
| 340 |
+
2. Unseen 至少不弱于 Seen 的提升趋势;
|
| 341 |
+
3. Null 指标不恶化,预测面积不膨胀;
|
| 342 |
+
4. hard frames 改善更明显;
|
| 343 |
+
5. 如果记录 gate alpha,hard frames 的 alpha 应系统性高于 easy frames。
|
| 344 |
+
|
| 345 |
+
失败解释:
|
| 346 |
+
|
| 347 |
+
- 如果 Seen 提升、Unseen 不提升:可能是 gate 学到数据集模式,而不是 referent anchoring;
|
| 348 |
+
- 如果 Null 恶化:gate 可能放大了通用前景显著性;
|
| 349 |
+
- 如果 gate-only 无变化但完整 A-min 有收益:说明 gate 需要与 mask decoder / text projection 协同适配;
|
| 350 |
+
- 如果全 split 下降:gate 插入位置、初始化或学习率需要重新检查。
|
| 351 |
+
|
| 352 |
+
---
|
| 353 |
+
|
| 354 |
+
## 6. 后续机制分析
|
| 355 |
+
|
| 356 |
+
如果 A-min 有正收益,再做 hook 分析:
|
| 357 |
+
|
| 358 |
+
1. sparse self-attention 中 `mask_token -> q`;
|
| 359 |
+
2. token-to-image attention 中 mask token 对 image tokens 的关注;
|
| 360 |
+
3. A-min 前后 hard/easy frames 的 gate alpha;
|
| 361 |
+
4. `s_pred(beta=1.0)` 与 gate alpha 的关系。
|
| 362 |
+
|
| 363 |
+
这部分用于论文解释,不作为当前阻塞项。
|
| 364 |
+
|
| 365 |
+
---
|
| 366 |
+
|
| 367 |
+
## 7. 当前一句话结论
|
| 368 |
+
|
| 369 |
+
> A-min gate 的梯度路径、表达能力和 cached pipeline 一致性已经通过验证;overfit32 能显著提升训练集但不能泛化。当前主线是用更大 cached train set(已完成 train256 cache)验证 gate-only 泛化,再决定是否投入完整 A-min 联合训练。
|
target_frame_sweep.py
ADDED
|
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import csv
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
from functools import partial
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import transformers
|
| 10 |
+
from torch.utils.data import DataLoader
|
| 11 |
+
|
| 12 |
+
from configs import args
|
| 13 |
+
from datasets import REFAVS
|
| 14 |
+
from decoder_invariance_check import build_model, set_seed
|
| 15 |
+
from load_model import collate_fn, dict_to_cuda
|
| 16 |
+
from utils import utility
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def decode_with_q(model, batch, q):
|
| 20 |
+
visual_model = model.get_model().visual_model
|
| 21 |
+
image_embeddings = batch["image_feats"][0]
|
| 22 |
+
|
| 23 |
+
sparse, dense = visual_model.prompt_encoder(
|
| 24 |
+
points=None,
|
| 25 |
+
boxes=None,
|
| 26 |
+
masks=None,
|
| 27 |
+
text_embeds=q.unsqueeze(1),
|
| 28 |
+
)
|
| 29 |
+
sparse = sparse.to(q.dtype)
|
| 30 |
+
dense = dense.to(q.dtype)
|
| 31 |
+
|
| 32 |
+
low_res_masks, iou_predictions = visual_model.mask_decoder(
|
| 33 |
+
image_embeddings=image_embeddings,
|
| 34 |
+
image_pe=visual_model.prompt_encoder.get_dense_pe(),
|
| 35 |
+
sparse_prompt_embeddings=sparse,
|
| 36 |
+
dense_prompt_embeddings=dense,
|
| 37 |
+
multimask_output=False,
|
| 38 |
+
)
|
| 39 |
+
pred_masks = visual_model.postprocess_masks(
|
| 40 |
+
low_res_masks,
|
| 41 |
+
input_size=batch["resizes"][0],
|
| 42 |
+
original_size=batch["orgsizes"][0],
|
| 43 |
+
).squeeze(1)
|
| 44 |
+
return pred_masks.unsqueeze(0), iou_predictions.squeeze(-1)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_q_for_target_frame(model, batch, target_frame):
|
| 48 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 49 |
+
output = model.forward(
|
| 50 |
+
images=batch["images"],
|
| 51 |
+
images_clip=batch["images_clip"],
|
| 52 |
+
audio_features=batch["audio_feats"],
|
| 53 |
+
image_features=batch["image_feats"],
|
| 54 |
+
input_ids=batch["input_ids"],
|
| 55 |
+
labels=batch["labels"],
|
| 56 |
+
attention_masks=batch["attention_masks"],
|
| 57 |
+
masks_list=batch["masks"],
|
| 58 |
+
resize_list=batch["resizes"],
|
| 59 |
+
orgsize_list=batch["orgsizes"],
|
| 60 |
+
conversation_list=batch["convs"],
|
| 61 |
+
refs_num=batch["refs_num"],
|
| 62 |
+
fids=batch["fids"],
|
| 63 |
+
vids=batch["vids"],
|
| 64 |
+
contrast=args.ct_weight,
|
| 65 |
+
ref_ids=batch["ref_ids"],
|
| 66 |
+
inference=True,
|
| 67 |
+
target_frame=target_frame,
|
| 68 |
+
)
|
| 69 |
+
return output["seg_embeddings"][0][0:1]
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def mask_area(pred_masks):
|
| 73 |
+
return (torch.sigmoid(pred_masks) > 0.4).float().mean().item()
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def mean_mask_iou_to_others(mask, other_masks):
|
| 77 |
+
if not other_masks:
|
| 78 |
+
return 1.0
|
| 79 |
+
binary = (torch.sigmoid(mask) > 0.4).float()
|
| 80 |
+
other_binary = [(torch.sigmoid(m) > 0.4).float() for m in other_masks]
|
| 81 |
+
vals = []
|
| 82 |
+
for other in other_binary:
|
| 83 |
+
inter = (binary * other).sum()
|
| 84 |
+
union = torch.maximum(binary, other).sum()
|
| 85 |
+
vals.append((inter / (union + 1e-7)).item())
|
| 86 |
+
return float(np.mean(vals))
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def evaluate_one_sample(model, batch, sample_idx):
|
| 90 |
+
rows = []
|
| 91 |
+
qs = []
|
| 92 |
+
pred_masks_by_tf = []
|
| 93 |
+
|
| 94 |
+
gt_masks = batch["masks"][0]
|
| 95 |
+
vid = batch["vids"][0]
|
| 96 |
+
ref = batch["refs"][0][0]
|
| 97 |
+
|
| 98 |
+
for target_frame in range(args.frame_n):
|
| 99 |
+
q = get_q_for_target_frame(model, batch, target_frame)
|
| 100 |
+
qs.append(q.float().squeeze(0))
|
| 101 |
+
|
| 102 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 103 |
+
pred_masks, iou_predictions = decode_with_q(model, batch, q)
|
| 104 |
+
pred_masks_by_tf.append(pred_masks.detach())
|
| 105 |
+
|
| 106 |
+
miou = utility.mask_iou(pred_masks.float(), gt_masks.float())
|
| 107 |
+
fscore = utility.Eval_Fmeasure(pred_masks.float(), gt_masks.float(), None)
|
| 108 |
+
null_metric = utility.metric_s_for_null(pred_masks.float())
|
| 109 |
+
area = mask_area(pred_masks)
|
| 110 |
+
mean_iou_pred = iou_predictions.float().mean().item()
|
| 111 |
+
|
| 112 |
+
rows.append(
|
| 113 |
+
{
|
| 114 |
+
"sample_idx": sample_idx,
|
| 115 |
+
"vid": vid,
|
| 116 |
+
"ref": ref,
|
| 117 |
+
"target_frame": target_frame,
|
| 118 |
+
"mean_iou_pred": mean_iou_pred,
|
| 119 |
+
"mask_area": area,
|
| 120 |
+
"null_metric": float(null_metric),
|
| 121 |
+
"miou": miou,
|
| 122 |
+
"fscore": fscore,
|
| 123 |
+
"cos_to_q5": 0.0,
|
| 124 |
+
"mean_cos_to_other_q": 0.0,
|
| 125 |
+
"mean_mask_iou_to_other_tf": 0.0,
|
| 126 |
+
}
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
q_stack = F.normalize(torch.stack(qs, dim=0), dim=-1)
|
| 130 |
+
q_cos = q_stack @ q_stack.T
|
| 131 |
+
q5_idx = min(5, len(qs) - 1)
|
| 132 |
+
|
| 133 |
+
for i, row in enumerate(rows):
|
| 134 |
+
other = [j for j in range(len(rows)) if j != i]
|
| 135 |
+
row["cos_to_q5"] = q_cos[i, q5_idx].item()
|
| 136 |
+
row["mean_cos_to_other_q"] = q_cos[i, other].mean().item()
|
| 137 |
+
row["mean_mask_iou_to_other_tf"] = mean_mask_iou_to_others(
|
| 138 |
+
pred_masks_by_tf[i], [pred_masks_by_tf[j] for j in other]
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
return rows
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def print_sample_summary(rows):
|
| 145 |
+
print(f"\nSample {rows[0]['sample_idx']}: vid={rows[0]['vid']} ref={rows[0]['ref']}")
|
| 146 |
+
print("tf | miou | fscore | null_s | iou_pred | area | cos_to_q5 | mean_q_cos")
|
| 147 |
+
for row in rows:
|
| 148 |
+
print(
|
| 149 |
+
f"{row['target_frame']:02d} | "
|
| 150 |
+
f"{row['miou']:.4f} | "
|
| 151 |
+
f"{row['fscore']:.4f} | "
|
| 152 |
+
f"{row['null_metric']:.4f} | "
|
| 153 |
+
f"{row['mean_iou_pred']:.4f} | "
|
| 154 |
+
f"{row['mask_area']:.4f} | "
|
| 155 |
+
f"{row['cos_to_q5']:.4f} | "
|
| 156 |
+
f"{row['mean_cos_to_other_q']:.4f}"
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
best_miou = max(rows, key=lambda x: x["miou"])
|
| 160 |
+
best_iou_pred = max(rows, key=lambda x: x["mean_iou_pred"])
|
| 161 |
+
fixed = rows[min(5, len(rows) - 1)]
|
| 162 |
+
miou_values = [row["miou"] for row in rows]
|
| 163 |
+
q5_values = [row["cos_to_q5"] for row in rows]
|
| 164 |
+
print(
|
| 165 |
+
"Best miou tf="
|
| 166 |
+
f"{best_miou['target_frame']} ({best_miou['miou']:.4f}); "
|
| 167 |
+
"best iou_pred tf="
|
| 168 |
+
f"{best_iou_pred['target_frame']} ({best_iou_pred['mean_iou_pred']:.4f}); "
|
| 169 |
+
f"fixed tf=5 miou={fixed['miou']:.4f}"
|
| 170 |
+
)
|
| 171 |
+
print(
|
| 172 |
+
f"target-frame miou range={max(miou_values) - min(miou_values):.4f}; "
|
| 173 |
+
f"min cos_to_q5={min(q5_values):.4f}"
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def main():
|
| 178 |
+
set_seed(42)
|
| 179 |
+
torch.set_grad_enabled(False)
|
| 180 |
+
|
| 181 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 182 |
+
args.mllm,
|
| 183 |
+
cache_dir=None,
|
| 184 |
+
model_max_length=2048,
|
| 185 |
+
padding_side="right",
|
| 186 |
+
use_fast=False,
|
| 187 |
+
)
|
| 188 |
+
tokenizer.pad_token = tokenizer.unk_token
|
| 189 |
+
tokenizer.add_tokens("[SEG]")
|
| 190 |
+
seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
|
| 191 |
+
|
| 192 |
+
dataset = REFAVS(args.eval_split, args, tokenizer, input_type="refer")
|
| 193 |
+
loader = DataLoader(
|
| 194 |
+
dataset,
|
| 195 |
+
batch_size=1,
|
| 196 |
+
shuffle=False,
|
| 197 |
+
num_workers=0,
|
| 198 |
+
collate_fn=partial(collate_fn, tokenizer=tokenizer),
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
limit = args.max_eval_rows if args.max_eval_rows > 0 else 1
|
| 202 |
+
print(f"Split: {args.eval_split} | samples to sweep: {limit}")
|
| 203 |
+
|
| 204 |
+
model = build_model(tokenizer, seg_token_idx)
|
| 205 |
+
|
| 206 |
+
all_rows = []
|
| 207 |
+
for sample_idx, batch in enumerate(loader):
|
| 208 |
+
if sample_idx >= limit:
|
| 209 |
+
break
|
| 210 |
+
batch = dict_to_cuda(batch)
|
| 211 |
+
rows = evaluate_one_sample(model, batch, sample_idx)
|
| 212 |
+
all_rows.extend(rows)
|
| 213 |
+
print_sample_summary(rows)
|
| 214 |
+
|
| 215 |
+
if not all_rows:
|
| 216 |
+
raise RuntimeError("No rows were checked. Is the selected split empty?")
|
| 217 |
+
|
| 218 |
+
fixed_rows = [r for r in all_rows if r["target_frame"] == min(5, args.frame_n - 1)]
|
| 219 |
+
oracle_by_sample = {}
|
| 220 |
+
iou_pred_by_sample = {}
|
| 221 |
+
for row in all_rows:
|
| 222 |
+
key = row["sample_idx"]
|
| 223 |
+
if key not in oracle_by_sample or row["miou"] > oracle_by_sample[key]["miou"]:
|
| 224 |
+
oracle_by_sample[key] = row
|
| 225 |
+
if key not in iou_pred_by_sample or row["mean_iou_pred"] > iou_pred_by_sample[key]["mean_iou_pred"]:
|
| 226 |
+
iou_pred_by_sample[key] = row
|
| 227 |
+
|
| 228 |
+
fixed_miou = np.mean([r["miou"] for r in fixed_rows])
|
| 229 |
+
fixed_null_metric = np.mean([r["null_metric"] for r in fixed_rows])
|
| 230 |
+
oracle_miou = np.mean([r["miou"] for r in oracle_by_sample.values()])
|
| 231 |
+
iou_pred_selected_miou = np.mean([r["miou"] for r in iou_pred_by_sample.values()])
|
| 232 |
+
min_cos_to_q5 = np.mean(
|
| 233 |
+
[min(r["cos_to_q5"] for r in all_rows if r["sample_idx"] == sample_idx) for sample_idx in oracle_by_sample]
|
| 234 |
+
)
|
| 235 |
+
mean_miou_range = np.mean(
|
| 236 |
+
[
|
| 237 |
+
max(r["miou"] for r in all_rows if r["sample_idx"] == sample_idx)
|
| 238 |
+
- min(r["miou"] for r in all_rows if r["sample_idx"] == sample_idx)
|
| 239 |
+
for sample_idx in oracle_by_sample
|
| 240 |
+
]
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
print("\nSummary")
|
| 244 |
+
print(f"samples: {len(fixed_rows)}")
|
| 245 |
+
print(f"fixed target_frame=5 mean miou: {fixed_miou:.4f}")
|
| 246 |
+
print(f"fixed target_frame=5 mean null_s: {fixed_null_metric:.4f}")
|
| 247 |
+
print(f"oracle best-target-frame mean miou: {oracle_miou:.4f}")
|
| 248 |
+
print(f"best-by-iou_pred selected mean miou: {iou_pred_selected_miou:.4f}")
|
| 249 |
+
print(f"oracle gain over fixed: {oracle_miou - fixed_miou:+.4f}")
|
| 250 |
+
print(f"iou_pred-selection gain over fixed: {iou_pred_selected_miou - fixed_miou:+.4f}")
|
| 251 |
+
print(f"mean target-frame miou range: {mean_miou_range:.4f}")
|
| 252 |
+
print(f"mean sample min cos_to_q5: {min_cos_to_q5:.4f}")
|
| 253 |
+
|
| 254 |
+
csv_path = os.environ.get("TARGET_FRAME_SWEEP_CSV")
|
| 255 |
+
if csv_path:
|
| 256 |
+
os.makedirs(os.path.dirname(os.path.abspath(csv_path)), exist_ok=True)
|
| 257 |
+
with open(csv_path, "w", newline="") as f:
|
| 258 |
+
writer = csv.DictWriter(f, fieldnames=list(all_rows[0].keys()))
|
| 259 |
+
writer.writeheader()
|
| 260 |
+
writer.writerows(all_rows)
|
| 261 |
+
print(f"Saved CSV: {csv_path}")
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
if __name__ == "__main__":
|
| 265 |
+
main()
|
train.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
import transformers
|
| 2 |
from datasets import REFAVS
|
| 3 |
from configs import args
|
| 4 |
-
from torch.utils.data import DataLoader
|
| 5 |
from functools import partial
|
| 6 |
from models.llava import conversation as conversation_lib
|
| 7 |
# from models.avs_model import VISAForCausalLM
|
|
@@ -21,6 +21,7 @@ import numpy as np
|
|
| 21 |
import re
|
| 22 |
import time
|
| 23 |
import os
|
|
|
|
| 24 |
|
| 25 |
|
| 26 |
import warnings
|
|
@@ -235,11 +236,19 @@ if __name__ == "__main__":
|
|
| 235 |
val_dataset_u_refer = REFAVS('test_u', args, tokenizer, input_type='refer')
|
| 236 |
val_dataset_n_refer = REFAVS('test_n', args, tokenizer, input_type='refer')
|
| 237 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 238 |
|
| 239 |
g = torch.Generator()
|
| 240 |
g.manual_seed(42)
|
| 241 |
|
| 242 |
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8, worker_init_fn=seed_worker,collate_fn=partial(collate_fn, tokenizer=tokenizer), generator=g)
|
|
|
|
| 243 |
|
| 244 |
val_dataloader_s_refer = DataLoader(val_dataset_s_refer, batch_size=4, shuffle=False, num_workers=0, collate_fn=partial(collate_fn, tokenizer=tokenizer))
|
| 245 |
val_dataloader_u_refer = DataLoader(val_dataset_u_refer, batch_size=4, shuffle=False, num_workers=0, collate_fn=partial(collate_fn, tokenizer=tokenizer))
|
|
@@ -349,6 +358,11 @@ if __name__ == "__main__":
|
|
| 349 |
model = model.to("cuda")
|
| 350 |
model.resize_token_embeddings(len(tokenizer))
|
| 351 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
|
| 353 |
for name, param in model.audio_feature_layer.named_parameters():
|
| 354 |
param.requires_grad = True
|
|
@@ -366,9 +380,113 @@ if __name__ == "__main__":
|
|
| 366 |
):
|
| 367 |
p.requires_grad = True
|
| 368 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
|
| 370 |
print("will save train model")
|
| 371 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
def valuate(model, dataloader, args, name):
|
| 373 |
model.eval()
|
| 374 |
|
|
@@ -420,11 +538,17 @@ if __name__ == "__main__":
|
|
| 420 |
epochs = args.epochs
|
| 421 |
print("init lr:", args.lr)
|
| 422 |
optimizer = AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01)
|
|
|
|
| 423 |
|
| 424 |
-
gradient_accumulation_steps = int(16 // args.batch_size)
|
| 425 |
-
step_per_epoch = len(train_dataloader) // gradient_accumulation_steps
|
| 426 |
-
|
|
|
|
| 427 |
warmup_steps = int(total_steps * 0.1)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 428 |
|
| 429 |
scheduler = get_cosine_schedule_with_warmup(
|
| 430 |
optimizer,
|
|
@@ -433,6 +557,9 @@ if __name__ == "__main__":
|
|
| 433 |
)
|
| 434 |
|
| 435 |
|
|
|
|
|
|
|
|
|
|
| 436 |
for epoch in range(epochs):
|
| 437 |
|
| 438 |
model.train()
|
|
@@ -441,6 +568,9 @@ if __name__ == "__main__":
|
|
| 441 |
|
| 442 |
loop = tqdm(train_dataloader, desc=f"Training Epoch {epoch + 1}/{epochs}")
|
| 443 |
for step, batch in enumerate(loop):
|
|
|
|
|
|
|
|
|
|
| 444 |
input_dict = dict_to_cuda(batch)
|
| 445 |
output_dict = model.forward(images=input_dict["images"],
|
| 446 |
images_clip=input_dict["images_clip"],
|
|
@@ -459,6 +589,7 @@ if __name__ == "__main__":
|
|
| 459 |
contrast=args.ct_weight,
|
| 460 |
ref_ids=input_dict["ref_ids"],
|
| 461 |
epoch=epoch,
|
|
|
|
| 462 |
inference=False)
|
| 463 |
|
| 464 |
loss = output_dict["loss"]
|
|
@@ -468,6 +599,15 @@ if __name__ == "__main__":
|
|
| 468 |
|
| 469 |
|
| 470 |
if (step + 1) % gradient_accumulation_steps == 0:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 471 |
optimizer.step()
|
| 472 |
scheduler.step()
|
| 473 |
optimizer.zero_grad()
|
|
@@ -475,16 +615,33 @@ if __name__ == "__main__":
|
|
| 475 |
current_lr = scheduler.get_lr()[0]
|
| 476 |
loop.set_postfix(lr=current_lr, loss=running_loss / ((step + 1) / gradient_accumulation_steps))
|
| 477 |
|
| 478 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 479 |
|
| 480 |
|
| 481 |
with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
|
| 482 |
-
f.write(f"Epoch {epoch}: running_loss {running_loss /
|
|
|
|
|
|
|
|
|
|
|
|
|
| 483 |
|
| 484 |
|
| 485 |
torch.save(model.state_dict(), os.path.join(args.checkpoint_root, f"{args.name}.pth"))
|
| 486 |
print(f"trained model saved as {args.name}.pth")
|
| 487 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 488 |
# ---------------test on seen & unseen ------------------------------------------
|
| 489 |
model.eval()
|
| 490 |
|
|
@@ -531,4 +688,4 @@ if __name__ == "__main__":
|
|
| 531 |
print(f"\n valuate on test_n_refer, metric: {total_metric/count}")
|
| 532 |
|
| 533 |
with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
|
| 534 |
-
f.write(f"\n valuate on test_n_refer: metric {total_metric/count} \n")
|
|
|
|
| 1 |
import transformers
|
| 2 |
from datasets import REFAVS
|
| 3 |
from configs import args
|
| 4 |
+
from torch.utils.data import DataLoader, Subset
|
| 5 |
from functools import partial
|
| 6 |
from models.llava import conversation as conversation_lib
|
| 7 |
# from models.avs_model import VISAForCausalLM
|
|
|
|
| 21 |
import re
|
| 22 |
import time
|
| 23 |
import os
|
| 24 |
+
import sys
|
| 25 |
|
| 26 |
|
| 27 |
import warnings
|
|
|
|
| 236 |
val_dataset_u_refer = REFAVS('test_u', args, tokenizer, input_type='refer')
|
| 237 |
val_dataset_n_refer = REFAVS('test_n', args, tokenizer, input_type='refer')
|
| 238 |
|
| 239 |
+
if args.overfit_samples > 0:
|
| 240 |
+
overfit_n = min(args.overfit_samples, len(train_dataset))
|
| 241 |
+
train_dataset = Subset(train_dataset, list(range(overfit_n)))
|
| 242 |
+
print(f"overfit_samples enabled: using first {overfit_n} train samples")
|
| 243 |
+
|
| 244 |
+
train_eval_dataset = train_dataset
|
| 245 |
+
|
| 246 |
|
| 247 |
g = torch.Generator()
|
| 248 |
g.manual_seed(42)
|
| 249 |
|
| 250 |
train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=8, worker_init_fn=seed_worker,collate_fn=partial(collate_fn, tokenizer=tokenizer), generator=g)
|
| 251 |
+
train_eval_dataloader = DataLoader(train_eval_dataset, batch_size=4, shuffle=False, num_workers=0, collate_fn=partial(collate_fn, tokenizer=tokenizer))
|
| 252 |
|
| 253 |
val_dataloader_s_refer = DataLoader(val_dataset_s_refer, batch_size=4, shuffle=False, num_workers=0, collate_fn=partial(collate_fn, tokenizer=tokenizer))
|
| 254 |
val_dataloader_u_refer = DataLoader(val_dataset_u_refer, batch_size=4, shuffle=False, num_workers=0, collate_fn=partial(collate_fn, tokenizer=tokenizer))
|
|
|
|
| 358 |
model = model.to("cuda")
|
| 359 |
model.resize_token_embeddings(len(tokenizer))
|
| 360 |
|
| 361 |
+
if args.init_from_saved_model or args.gate_only:
|
| 362 |
+
state = torch.load(args.saved_model, map_location="cpu")
|
| 363 |
+
missing, unexpected = model.load_state_dict(state, strict=False)
|
| 364 |
+
print(f"initialized training from saved model: {args.saved_model}")
|
| 365 |
+
print(f"missing keys: {len(missing)} | unexpected keys: {len(unexpected)}")
|
| 366 |
|
| 367 |
for name, param in model.audio_feature_layer.named_parameters():
|
| 368 |
param.requires_grad = True
|
|
|
|
| 380 |
):
|
| 381 |
p.requires_grad = True
|
| 382 |
|
| 383 |
+
if args.gate_only:
|
| 384 |
+
for p in model.parameters():
|
| 385 |
+
p.requires_grad = False
|
| 386 |
+
for n, p in model.named_parameters():
|
| 387 |
+
if "referent_gate" in n:
|
| 388 |
+
p.requires_grad = True
|
| 389 |
+
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 390 |
+
total = sum(p.numel() for p in model.parameters())
|
| 391 |
+
print(f"gate_only enabled: trainable params {trainable} / {total}")
|
| 392 |
|
| 393 |
print("will save train model")
|
| 394 |
|
| 395 |
+
def _total_norm(values):
|
| 396 |
+
if not values:
|
| 397 |
+
return 0.0
|
| 398 |
+
return float(sum(v * v for v in values) ** 0.5)
|
| 399 |
+
|
| 400 |
+
def collect_referent_gate_stats(model):
|
| 401 |
+
gate_modules = [(n, m) for n, m in model.named_modules() if n.endswith("referent_gate")]
|
| 402 |
+
proj_norms = []
|
| 403 |
+
gate_norms = []
|
| 404 |
+
proj_grad_norms = []
|
| 405 |
+
gate_grad_norms = []
|
| 406 |
+
alpha_tensors = []
|
| 407 |
+
|
| 408 |
+
for _, module in gate_modules:
|
| 409 |
+
proj_norms.append(module.proj.weight.detach().float().norm().item())
|
| 410 |
+
gate_norms.append(module.gate.weight.detach().float().norm().item())
|
| 411 |
+
if module.proj.weight.grad is not None:
|
| 412 |
+
proj_grad_norms.append(module.proj.weight.grad.detach().float().norm().item())
|
| 413 |
+
if module.gate.weight.grad is not None:
|
| 414 |
+
gate_grad_norms.append(module.gate.weight.grad.detach().float().norm().item())
|
| 415 |
+
if module.last_alpha is not None:
|
| 416 |
+
alpha_tensors.append(module.last_alpha.detach().float().reshape(-1))
|
| 417 |
+
|
| 418 |
+
stats = {
|
| 419 |
+
"modules": len(gate_modules),
|
| 420 |
+
"proj_norm": _total_norm(proj_norms),
|
| 421 |
+
"gate_norm": _total_norm(gate_norms),
|
| 422 |
+
"proj_grad_norm": _total_norm(proj_grad_norms),
|
| 423 |
+
"gate_grad_norm": _total_norm(gate_grad_norms),
|
| 424 |
+
}
|
| 425 |
+
|
| 426 |
+
if alpha_tensors:
|
| 427 |
+
alpha = torch.cat(alpha_tensors)
|
| 428 |
+
stats.update(
|
| 429 |
+
{
|
| 430 |
+
"alpha_mean": alpha.mean().item(),
|
| 431 |
+
"alpha_std": alpha.std(unbiased=False).item(),
|
| 432 |
+
"alpha_min": alpha.min().item(),
|
| 433 |
+
"alpha_max": alpha.max().item(),
|
| 434 |
+
}
|
| 435 |
+
)
|
| 436 |
+
else:
|
| 437 |
+
stats.update(
|
| 438 |
+
{
|
| 439 |
+
"alpha_mean": float("nan"),
|
| 440 |
+
"alpha_std": float("nan"),
|
| 441 |
+
"alpha_min": float("nan"),
|
| 442 |
+
"alpha_max": float("nan"),
|
| 443 |
+
}
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
return stats
|
| 447 |
+
|
| 448 |
+
def print_referent_gate_optimizer_sanity(model, optimizer):
|
| 449 |
+
optimizer_param_ids = {id(p) for group in optimizer.param_groups for p in group["params"]}
|
| 450 |
+
gate_params = [(n, p) for n, p in model.named_parameters() if "referent_gate" in n]
|
| 451 |
+
trainable_gate = [(n, p) for n, p in gate_params if p.requires_grad]
|
| 452 |
+
optimizer_gate = [(n, p) for n, p in gate_params if id(p) in optimizer_param_ids]
|
| 453 |
+
optimizer_trainable_gate = [
|
| 454 |
+
(n, p) for n, p in gate_params if p.requires_grad and id(p) in optimizer_param_ids
|
| 455 |
+
]
|
| 456 |
+
print(
|
| 457 |
+
"referent_gate sanity: "
|
| 458 |
+
f"params={sum(p.numel() for _, p in gate_params)} | "
|
| 459 |
+
f"trainable={sum(p.numel() for _, p in trainable_gate)} | "
|
| 460 |
+
f"in_optimizer={sum(p.numel() for _, p in optimizer_gate)} | "
|
| 461 |
+
f"trainable_in_optimizer={sum(p.numel() for _, p in optimizer_trainable_gate)}"
|
| 462 |
+
)
|
| 463 |
+
|
| 464 |
+
stats = collect_referent_gate_stats(model)
|
| 465 |
+
print(
|
| 466 |
+
"referent_gate init stats: "
|
| 467 |
+
f"modules={stats['modules']} | "
|
| 468 |
+
f"proj_norm={stats['proj_norm']:.6f} | "
|
| 469 |
+
f"gate_norm={stats['gate_norm']:.6f}"
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
def log_referent_gate_stats(global_step, loss_value):
|
| 473 |
+
stats = collect_referent_gate_stats(model)
|
| 474 |
+
message = (
|
| 475 |
+
f"gate_stats step={global_step} "
|
| 476 |
+
f"loss={loss_value:.6f} "
|
| 477 |
+
f"proj_norm={stats['proj_norm']:.6f} "
|
| 478 |
+
f"gate_norm={stats['gate_norm']:.6f} "
|
| 479 |
+
f"proj_grad_norm={stats['proj_grad_norm']:.6f} "
|
| 480 |
+
f"gate_grad_norm={stats['gate_grad_norm']:.6f} "
|
| 481 |
+
f"alpha_mean={stats['alpha_mean']:.4f} "
|
| 482 |
+
f"alpha_std={stats['alpha_std']:.4f} "
|
| 483 |
+
f"alpha_min={stats['alpha_min']:.4f} "
|
| 484 |
+
f"alpha_max={stats['alpha_max']:.4f}"
|
| 485 |
+
)
|
| 486 |
+
print(message)
|
| 487 |
+
with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
|
| 488 |
+
f.write(message + "\n")
|
| 489 |
+
|
| 490 |
def valuate(model, dataloader, args, name):
|
| 491 |
model.eval()
|
| 492 |
|
|
|
|
| 538 |
epochs = args.epochs
|
| 539 |
print("init lr:", args.lr)
|
| 540 |
optimizer = AdamW(model.parameters(), lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01)
|
| 541 |
+
print_referent_gate_optimizer_sanity(model, optimizer)
|
| 542 |
|
| 543 |
+
gradient_accumulation_steps = max(1, int(16 // args.batch_size))
|
| 544 |
+
step_per_epoch = max(1, len(train_dataloader) // gradient_accumulation_steps)
|
| 545 |
+
full_total_steps = epochs * step_per_epoch
|
| 546 |
+
total_steps = min(args.max_steps, full_total_steps) if args.max_steps > 0 else full_total_steps
|
| 547 |
warmup_steps = int(total_steps * 0.1)
|
| 548 |
+
print(
|
| 549 |
+
f"training schedule: grad_accum={gradient_accumulation_steps} | "
|
| 550 |
+
f"step_per_epoch={step_per_epoch} | total_optimizer_steps={total_steps}"
|
| 551 |
+
)
|
| 552 |
|
| 553 |
scheduler = get_cosine_schedule_with_warmup(
|
| 554 |
optimizer,
|
|
|
|
| 557 |
)
|
| 558 |
|
| 559 |
|
| 560 |
+
optimizer_step_count = 0
|
| 561 |
+
stop_training = False
|
| 562 |
+
|
| 563 |
for epoch in range(epochs):
|
| 564 |
|
| 565 |
model.train()
|
|
|
|
| 568 |
|
| 569 |
loop = tqdm(train_dataloader, desc=f"Training Epoch {epoch + 1}/{epochs}")
|
| 570 |
for step, batch in enumerate(loop):
|
| 571 |
+
if args.max_steps > 0 and optimizer_step_count >= args.max_steps:
|
| 572 |
+
stop_training = True
|
| 573 |
+
break
|
| 574 |
input_dict = dict_to_cuda(batch)
|
| 575 |
output_dict = model.forward(images=input_dict["images"],
|
| 576 |
images_clip=input_dict["images_clip"],
|
|
|
|
| 589 |
contrast=args.ct_weight,
|
| 590 |
ref_ids=input_dict["ref_ids"],
|
| 591 |
epoch=epoch,
|
| 592 |
+
gate_only=args.gate_only,
|
| 593 |
inference=False)
|
| 594 |
|
| 595 |
loss = output_dict["loss"]
|
|
|
|
| 599 |
|
| 600 |
|
| 601 |
if (step + 1) % gradient_accumulation_steps == 0:
|
| 602 |
+
optimizer_step_count += 1
|
| 603 |
+
if (
|
| 604 |
+
args.log_gate_stats_every > 0
|
| 605 |
+
and optimizer_step_count % args.log_gate_stats_every == 0
|
| 606 |
+
):
|
| 607 |
+
log_referent_gate_stats(
|
| 608 |
+
optimizer_step_count,
|
| 609 |
+
loss.item() * gradient_accumulation_steps,
|
| 610 |
+
)
|
| 611 |
optimizer.step()
|
| 612 |
scheduler.step()
|
| 613 |
optimizer.zero_grad()
|
|
|
|
| 615 |
current_lr = scheduler.get_lr()[0]
|
| 616 |
loop.set_postfix(lr=current_lr, loss=running_loss / ((step + 1) / gradient_accumulation_steps))
|
| 617 |
|
| 618 |
+
if args.max_steps > 0 and optimizer_step_count >= args.max_steps:
|
| 619 |
+
stop_training = True
|
| 620 |
+
break
|
| 621 |
+
|
| 622 |
+
denom = max(1, optimizer_step_count)
|
| 623 |
+
print(f" Epoch {epoch + 1}, Loss:{running_loss / denom :.4f}, Learning Rate:{scheduler.get_last_lr()[0]:.6f}")
|
| 624 |
|
| 625 |
|
| 626 |
with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
|
| 627 |
+
f.write(f"Epoch {epoch}: running_loss {running_loss / denom} Learning Rate:{scheduler.get_last_lr()[0]:.6f}\n")
|
| 628 |
+
|
| 629 |
+
if stop_training:
|
| 630 |
+
print(f"stopped early at optimizer step {optimizer_step_count}")
|
| 631 |
+
break
|
| 632 |
|
| 633 |
|
| 634 |
torch.save(model.state_dict(), os.path.join(args.checkpoint_root, f"{args.name}.pth"))
|
| 635 |
print(f"trained model saved as {args.name}.pth")
|
| 636 |
|
| 637 |
+
if args.skip_eval_after_train:
|
| 638 |
+
print("skip_eval_after_train enabled: exiting after checkpoint save")
|
| 639 |
+
sys.exit(0)
|
| 640 |
+
|
| 641 |
+
if args.eval_train_only:
|
| 642 |
+
valuate(model, train_eval_dataloader, args, 'train_overfit')
|
| 643 |
+
sys.exit(0)
|
| 644 |
+
|
| 645 |
# ---------------test on seen & unseen ------------------------------------------
|
| 646 |
model.eval()
|
| 647 |
|
|
|
|
| 688 |
print(f"\n valuate on test_n_refer, metric: {total_metric/count}")
|
| 689 |
|
| 690 |
with open(os.path.join(args.log_root, f'{args.name}.txt'), "a") as f:
|
| 691 |
+
f.write(f"\n valuate on test_n_refer: metric {total_metric/count} \n")
|
train_cached_gate.py
ADDED
|
@@ -0,0 +1,439 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
|
| 5 |
+
import cv2
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import transformers
|
| 9 |
+
from torch.optim import AdamW
|
| 10 |
+
from torch.utils.data import DataLoader, Dataset, Subset
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
|
| 13 |
+
from configs import args
|
| 14 |
+
from decoder_invariance_check import build_model, set_seed
|
| 15 |
+
from models.avs_model import dice_loss, sigmoid_ce_loss
|
| 16 |
+
from utils import utility
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def _total_norm(values):
|
| 20 |
+
if not values:
|
| 21 |
+
return 0.0
|
| 22 |
+
return float(sum(v * v for v in values) ** 0.5)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def collect_referent_gate_stats(model):
|
| 26 |
+
gate_modules = [(n, m) for n, m in model.named_modules() if n.endswith("referent_gate")]
|
| 27 |
+
proj_norms = []
|
| 28 |
+
gate_norms = []
|
| 29 |
+
proj_grad_norms = []
|
| 30 |
+
gate_grad_norms = []
|
| 31 |
+
alpha_tensors = []
|
| 32 |
+
|
| 33 |
+
for _, module in gate_modules:
|
| 34 |
+
proj_norms.append(module.proj.weight.detach().float().norm().item())
|
| 35 |
+
gate_norms.append(module.gate.weight.detach().float().norm().item())
|
| 36 |
+
if module.proj.weight.grad is not None:
|
| 37 |
+
proj_grad_norms.append(module.proj.weight.grad.detach().float().norm().item())
|
| 38 |
+
if module.gate.weight.grad is not None:
|
| 39 |
+
gate_grad_norms.append(module.gate.weight.grad.detach().float().norm().item())
|
| 40 |
+
if module.last_alpha is not None:
|
| 41 |
+
alpha_tensors.append(module.last_alpha.detach().float().reshape(-1))
|
| 42 |
+
|
| 43 |
+
stats = {
|
| 44 |
+
"modules": len(gate_modules),
|
| 45 |
+
"proj_norm": _total_norm(proj_norms),
|
| 46 |
+
"gate_norm": _total_norm(gate_norms),
|
| 47 |
+
"proj_grad_norm": _total_norm(proj_grad_norms),
|
| 48 |
+
"gate_grad_norm": _total_norm(gate_grad_norms),
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
if alpha_tensors:
|
| 52 |
+
alpha = torch.cat(alpha_tensors)
|
| 53 |
+
stats.update(
|
| 54 |
+
{
|
| 55 |
+
"alpha_mean": alpha.mean().item(),
|
| 56 |
+
"alpha_std": alpha.std(unbiased=False).item(),
|
| 57 |
+
"alpha_min": alpha.min().item(),
|
| 58 |
+
"alpha_max": alpha.max().item(),
|
| 59 |
+
}
|
| 60 |
+
)
|
| 61 |
+
else:
|
| 62 |
+
stats.update(
|
| 63 |
+
{
|
| 64 |
+
"alpha_mean": float("nan"),
|
| 65 |
+
"alpha_std": float("nan"),
|
| 66 |
+
"alpha_min": float("nan"),
|
| 67 |
+
"alpha_max": float("nan"),
|
| 68 |
+
}
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
return stats
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def zero_referent_gate(model):
|
| 75 |
+
with torch.no_grad():
|
| 76 |
+
for _, module in model.named_modules():
|
| 77 |
+
if not _.endswith("referent_gate"):
|
| 78 |
+
continue
|
| 79 |
+
module.gate.weight.zero_()
|
| 80 |
+
module.gate.bias.zero_()
|
| 81 |
+
module.proj.weight.zero_()
|
| 82 |
+
module.proj.bias.zero_()
|
| 83 |
+
module.last_alpha = None
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def referent_gate_state_dict(model):
|
| 87 |
+
return {
|
| 88 |
+
name: param.detach().cpu()
|
| 89 |
+
for name, param in model.state_dict().items()
|
| 90 |
+
if "referent_gate" in name
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def load_referent_gate_checkpoint(model, path):
|
| 95 |
+
checkpoint = torch.load(path, map_location="cpu")
|
| 96 |
+
if isinstance(checkpoint, dict) and checkpoint.get("type") == "referent_gate_only":
|
| 97 |
+
checkpoint = checkpoint["state_dict"]
|
| 98 |
+
gate_state = {k: v for k, v in checkpoint.items() if "referent_gate" in k}
|
| 99 |
+
if not gate_state:
|
| 100 |
+
raise RuntimeError(f"No referent_gate parameters found in {path}")
|
| 101 |
+
current = model.state_dict()
|
| 102 |
+
missing_shape = [
|
| 103 |
+
k
|
| 104 |
+
for k, v in gate_state.items()
|
| 105 |
+
if k not in current or tuple(current[k].shape) != tuple(v.shape)
|
| 106 |
+
]
|
| 107 |
+
if missing_shape:
|
| 108 |
+
raise RuntimeError(f"Gate checkpoint has incompatible keys: {missing_shape[:5]}")
|
| 109 |
+
current.update(gate_state)
|
| 110 |
+
model.load_state_dict(current, strict=True)
|
| 111 |
+
print(f"loaded referent gate checkpoint: {path} ({len(gate_state)} tensors)")
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def log_gate_stats(model, step, loss_value, batch_metrics=None):
|
| 115 |
+
stats = collect_referent_gate_stats(model)
|
| 116 |
+
metric_text = ""
|
| 117 |
+
if batch_metrics is not None:
|
| 118 |
+
metric_text = (
|
| 119 |
+
f"batch_miou={batch_metrics['miou']:.4f} "
|
| 120 |
+
f"batch_fscore={batch_metrics['fscore']:.4f} "
|
| 121 |
+
)
|
| 122 |
+
message = (
|
| 123 |
+
f"gate_stats step={step} "
|
| 124 |
+
f"loss={loss_value:.6f} "
|
| 125 |
+
f"{metric_text}"
|
| 126 |
+
f"proj_norm={stats['proj_norm']:.6f} "
|
| 127 |
+
f"gate_norm={stats['gate_norm']:.6f} "
|
| 128 |
+
f"proj_grad_norm={stats['proj_grad_norm']:.6f} "
|
| 129 |
+
f"gate_grad_norm={stats['gate_grad_norm']:.6f} "
|
| 130 |
+
f"alpha_mean={stats['alpha_mean']:.4f} "
|
| 131 |
+
f"alpha_std={stats['alpha_std']:.4f} "
|
| 132 |
+
f"alpha_min={stats['alpha_min']:.4f} "
|
| 133 |
+
f"alpha_max={stats['alpha_max']:.4f}"
|
| 134 |
+
)
|
| 135 |
+
print(message)
|
| 136 |
+
os.makedirs(args.log_root, exist_ok=True)
|
| 137 |
+
with open(os.path.join(args.log_root, f"{args.name}.txt"), "a") as f:
|
| 138 |
+
f.write(message + "\n")
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
class CachedQDataset(Dataset):
|
| 142 |
+
def __init__(self, split, cfg):
|
| 143 |
+
self.split = split
|
| 144 |
+
self.cfg = cfg
|
| 145 |
+
self.root = os.path.join(cfg.cache_root, split)
|
| 146 |
+
self.index_path = os.path.join(self.root, "index.jsonl")
|
| 147 |
+
if not os.path.exists(self.index_path):
|
| 148 |
+
raise FileNotFoundError(f"Missing cache index: {self.index_path}")
|
| 149 |
+
with open(self.index_path) as f:
|
| 150 |
+
self.rows = [json.loads(line) for line in f if line.strip()]
|
| 151 |
+
|
| 152 |
+
def __len__(self):
|
| 153 |
+
return len(self.rows)
|
| 154 |
+
|
| 155 |
+
def _load_masks(self, vid, fids):
|
| 156 |
+
masks = []
|
| 157 |
+
for fid in fids:
|
| 158 |
+
frames = []
|
| 159 |
+
for frame_idx in range(self.cfg.frame_n):
|
| 160 |
+
path = os.path.join(
|
| 161 |
+
self.cfg.data_dir,
|
| 162 |
+
"gt_mask",
|
| 163 |
+
vid,
|
| 164 |
+
f"fid_{int(fid)}",
|
| 165 |
+
f"0000{frame_idx}.png",
|
| 166 |
+
)
|
| 167 |
+
mask = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
|
| 168 |
+
if mask is None:
|
| 169 |
+
raise FileNotFoundError(path)
|
| 170 |
+
frames.append(torch.as_tensor(mask > 0, dtype=torch.float32))
|
| 171 |
+
masks.append(torch.stack(frames, dim=0))
|
| 172 |
+
return torch.stack(masks, dim=0)
|
| 173 |
+
|
| 174 |
+
def __getitem__(self, idx):
|
| 175 |
+
row = self.rows[idx]
|
| 176 |
+
cache = torch.load(os.path.join(self.root, row["path"]), map_location="cpu")
|
| 177 |
+
vid = cache["vid"]
|
| 178 |
+
return {
|
| 179 |
+
"sample_idx": cache["sample_idx"],
|
| 180 |
+
"vid": vid,
|
| 181 |
+
"refs": cache["refs"],
|
| 182 |
+
"fids": cache["fids"],
|
| 183 |
+
"q": cache["q"].float(),
|
| 184 |
+
"image_embeddings": torch.load(
|
| 185 |
+
os.path.join(self.cfg.data_dir, "image_embed", f"{vid}.pt"),
|
| 186 |
+
map_location="cpu",
|
| 187 |
+
).float(),
|
| 188 |
+
"gt_masks": self._load_masks(vid, cache["fids"]),
|
| 189 |
+
"resize": tuple(cache["resize"]),
|
| 190 |
+
"orgsize": tuple(cache["orgsize"]),
|
| 191 |
+
}
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def collate_cached(batch):
|
| 195 |
+
return batch
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
def decode_batch(visual_model, batch, device):
|
| 199 |
+
image_pe = visual_model.prompt_encoder.get_dense_pe()
|
| 200 |
+
frame_qs = []
|
| 201 |
+
frame_image_embeddings = []
|
| 202 |
+
prompt_spans = []
|
| 203 |
+
|
| 204 |
+
for sample_idx, sample in enumerate(batch):
|
| 205 |
+
q = sample["q"].to(device=device, dtype=torch.float32)
|
| 206 |
+
image_embeddings = sample["image_embeddings"].to(device=device, dtype=torch.float32)
|
| 207 |
+
frames = image_embeddings.shape[0]
|
| 208 |
+
for prompt_idx in range(q.shape[0]):
|
| 209 |
+
start = len(frame_qs) * frames
|
| 210 |
+
frame_qs.append(q[prompt_idx].unsqueeze(0).expand(frames, -1))
|
| 211 |
+
frame_image_embeddings.append(image_embeddings)
|
| 212 |
+
prompt_spans.append((sample_idx, prompt_idx, start, start + frames))
|
| 213 |
+
|
| 214 |
+
if not frame_qs:
|
| 215 |
+
raise RuntimeError("No cached prompts were provided for decoding.")
|
| 216 |
+
|
| 217 |
+
frame_qs = torch.cat(frame_qs, dim=0)
|
| 218 |
+
frame_image_embeddings = torch.cat(frame_image_embeddings, dim=0)
|
| 219 |
+
sparse_embeddings, dense_embeddings = visual_model.prompt_encoder(
|
| 220 |
+
points=None,
|
| 221 |
+
boxes=None,
|
| 222 |
+
masks=None,
|
| 223 |
+
text_embeds=frame_qs.unsqueeze(1),
|
| 224 |
+
)
|
| 225 |
+
sparse_embeddings = sparse_embeddings.to(frame_qs.dtype)
|
| 226 |
+
dense_embeddings = dense_embeddings.to(frame_qs.dtype)
|
| 227 |
+
|
| 228 |
+
low_res_masks = visual_model.mask_decoder.forward_modified_v3(
|
| 229 |
+
image_embeddings=frame_image_embeddings,
|
| 230 |
+
image_pe=image_pe,
|
| 231 |
+
sparse_prompt_embeddings=sparse_embeddings,
|
| 232 |
+
dense_prompt_embeddings=dense_embeddings,
|
| 233 |
+
).unsqueeze(1)
|
| 234 |
+
|
| 235 |
+
pred_by_sample = [[] for _ in batch]
|
| 236 |
+
for sample_idx, _, start, end in prompt_spans:
|
| 237 |
+
sample = batch[sample_idx]
|
| 238 |
+
pred_mask = visual_model.postprocess_masks(
|
| 239 |
+
low_res_masks[start:end],
|
| 240 |
+
input_size=sample["resize"],
|
| 241 |
+
original_size=sample["orgsize"],
|
| 242 |
+
)
|
| 243 |
+
pred_by_sample[sample_idx].append(pred_mask.squeeze(1))
|
| 244 |
+
|
| 245 |
+
return [torch.stack(pred_masks, dim=0) for pred_masks in pred_by_sample]
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def decode_sample(visual_model, sample, device):
|
| 249 |
+
return decode_batch(visual_model, [sample], device)[0]
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
def compute_mask_loss(pred_masks, gt_masks):
|
| 253 |
+
mask_bce_loss = 0.0
|
| 254 |
+
mask_dice_loss = 0.0
|
| 255 |
+
num_masks = 0
|
| 256 |
+
|
| 257 |
+
for pred_mask, gt_mask in zip(pred_masks, gt_masks):
|
| 258 |
+
gt_mask = gt_mask.to(device=pred_mask.device, dtype=pred_mask.dtype)
|
| 259 |
+
num_seg, frames, height, width = gt_mask.shape
|
| 260 |
+
gt_flat = gt_mask.view(num_seg * frames, height, width)
|
| 261 |
+
pred_flat = pred_mask.view(num_seg * frames, height, width)
|
| 262 |
+
|
| 263 |
+
mask_bce_loss = mask_bce_loss + (
|
| 264 |
+
sigmoid_ce_loss(pred_flat, gt_flat, num_masks=gt_flat.shape[0])
|
| 265 |
+
* gt_flat.shape[0]
|
| 266 |
+
)
|
| 267 |
+
mask_dice_loss = mask_dice_loss + (
|
| 268 |
+
dice_loss(pred_flat, gt_flat, num_masks=gt_flat.shape[0])
|
| 269 |
+
* gt_flat.shape[0]
|
| 270 |
+
)
|
| 271 |
+
num_masks += gt_flat.shape[0]
|
| 272 |
+
|
| 273 |
+
mask_bce_loss = 2.0 * mask_bce_loss / (num_masks + 1e-8)
|
| 274 |
+
mask_dice_loss = 0.5 * mask_dice_loss / (num_masks + 1e-8)
|
| 275 |
+
return mask_bce_loss + mask_dice_loss
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
def compute_batch_metrics(pred_masks, gt_masks):
|
| 279 |
+
total_iou = 0.0
|
| 280 |
+
total_fscore = 0.0
|
| 281 |
+
count = 0
|
| 282 |
+
for pred_mask, gt_mask in zip(pred_masks, gt_masks):
|
| 283 |
+
gt_mask = gt_mask.to(device=pred_mask.device, dtype=pred_mask.dtype)
|
| 284 |
+
num_seg, frames = pred_mask.shape[:2]
|
| 285 |
+
weight = num_seg * frames
|
| 286 |
+
total_iou += utility.mask_iou(pred_mask.detach().float(), gt_mask.float()) * weight
|
| 287 |
+
total_fscore += utility.Eval_Fmeasure(pred_mask.detach().float(), gt_mask.float(), None) * weight
|
| 288 |
+
count += weight
|
| 289 |
+
return {
|
| 290 |
+
"miou": total_iou / max(1, count),
|
| 291 |
+
"fscore": total_fscore / max(1, count),
|
| 292 |
+
}
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def evaluate(model, loader):
|
| 296 |
+
model.eval()
|
| 297 |
+
visual_model = model.get_model().visual_model
|
| 298 |
+
total_iou = 0.0
|
| 299 |
+
total_fscore = 0.0
|
| 300 |
+
total_null = 0.0
|
| 301 |
+
count = 0
|
| 302 |
+
|
| 303 |
+
with torch.no_grad():
|
| 304 |
+
for batch in tqdm(loader, desc=f"Cached eval {args.cache_split}"):
|
| 305 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 306 |
+
batch_pred = decode_batch(visual_model, batch, "cuda")
|
| 307 |
+
for sample, pred in zip(batch, batch_pred):
|
| 308 |
+
gt = sample["gt_masks"].to(device=pred.device, dtype=pred.dtype)
|
| 309 |
+
num_seg, frames = pred.shape[:2]
|
| 310 |
+
weight = num_seg * frames
|
| 311 |
+
if args.cache_split == "test_n":
|
| 312 |
+
total_null += float(utility.metric_s_for_null(pred.float())) * weight
|
| 313 |
+
else:
|
| 314 |
+
total_iou += utility.mask_iou(pred.float(), gt.float()) * weight
|
| 315 |
+
total_fscore += utility.Eval_Fmeasure(pred.float(), gt.float(), None) * weight
|
| 316 |
+
count += weight
|
| 317 |
+
|
| 318 |
+
if count == 0:
|
| 319 |
+
raise RuntimeError("No cached samples were evaluated.")
|
| 320 |
+
|
| 321 |
+
if args.cache_split == "test_n":
|
| 322 |
+
print(f"cached valuate on test_n_refer, metric: {total_null / count}")
|
| 323 |
+
else:
|
| 324 |
+
print(
|
| 325 |
+
f"cached valuate on {args.cache_split}: "
|
| 326 |
+
f"miou: {total_iou / count} fscore: {total_fscore / count}"
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
|
| 330 |
+
def train(model, loader):
|
| 331 |
+
if args.disable_gate:
|
| 332 |
+
raise ValueError("--disable_gate is only valid with --eval_only")
|
| 333 |
+
|
| 334 |
+
for p in model.parameters():
|
| 335 |
+
p.requires_grad = False
|
| 336 |
+
for name, p in model.named_parameters():
|
| 337 |
+
if "referent_gate" in name:
|
| 338 |
+
p.requires_grad = True
|
| 339 |
+
|
| 340 |
+
gate_params = [p for p in model.parameters() if p.requires_grad]
|
| 341 |
+
optimizer = AdamW(gate_params, lr=args.lr, betas=(0.9, 0.95), weight_decay=0.01)
|
| 342 |
+
stats = collect_referent_gate_stats(model)
|
| 343 |
+
print(
|
| 344 |
+
"cached gate init: "
|
| 345 |
+
f"modules={stats['modules']} "
|
| 346 |
+
f"proj_norm={stats['proj_norm']:.6f} "
|
| 347 |
+
f"gate_norm={stats['gate_norm']:.6f} "
|
| 348 |
+
f"trainable_params={sum(p.numel() for p in gate_params)}"
|
| 349 |
+
)
|
| 350 |
+
|
| 351 |
+
visual_model = model.get_model().visual_model
|
| 352 |
+
step = 0
|
| 353 |
+
for epoch in range(args.epochs):
|
| 354 |
+
model.train()
|
| 355 |
+
order_loader = loader
|
| 356 |
+
for batch in tqdm(order_loader, desc=f"Cached gate train {epoch + 1}/{args.epochs}"):
|
| 357 |
+
if args.max_steps > 0 and step >= args.max_steps:
|
| 358 |
+
break
|
| 359 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 360 |
+
pred_masks = decode_batch(visual_model, batch, "cuda")
|
| 361 |
+
gt_masks = [sample["gt_masks"] for sample in batch]
|
| 362 |
+
|
| 363 |
+
loss = compute_mask_loss(pred_masks, gt_masks)
|
| 364 |
+
optimizer.zero_grad()
|
| 365 |
+
loss.backward()
|
| 366 |
+
step += 1
|
| 367 |
+
if args.log_gate_stats_every > 0 and step % args.log_gate_stats_every == 0:
|
| 368 |
+
batch_metrics = compute_batch_metrics(pred_masks, gt_masks)
|
| 369 |
+
log_gate_stats(model, step, loss.item(), batch_metrics)
|
| 370 |
+
optimizer.step()
|
| 371 |
+
|
| 372 |
+
if args.max_steps > 0 and step >= args.max_steps:
|
| 373 |
+
print(f"stopped early at cached optimizer step {step}")
|
| 374 |
+
break
|
| 375 |
+
|
| 376 |
+
os.makedirs(args.checkpoint_root, exist_ok=True)
|
| 377 |
+
ckpt_path = os.path.join(args.checkpoint_root, f"{args.name}.pth")
|
| 378 |
+
if args.save_gate_only:
|
| 379 |
+
torch.save(
|
| 380 |
+
{
|
| 381 |
+
"type": "referent_gate_only",
|
| 382 |
+
"base_model": args.saved_model,
|
| 383 |
+
"state_dict": referent_gate_state_dict(model),
|
| 384 |
+
},
|
| 385 |
+
ckpt_path,
|
| 386 |
+
)
|
| 387 |
+
else:
|
| 388 |
+
torch.save(model.state_dict(), ckpt_path)
|
| 389 |
+
print(f"cached gate model saved as {ckpt_path}")
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def main():
|
| 393 |
+
set_seed(42)
|
| 394 |
+
random.seed(42)
|
| 395 |
+
np.random.seed(42)
|
| 396 |
+
|
| 397 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(
|
| 398 |
+
args.mllm,
|
| 399 |
+
cache_dir=None,
|
| 400 |
+
model_max_length=2048,
|
| 401 |
+
padding_side="right",
|
| 402 |
+
use_fast=False,
|
| 403 |
+
)
|
| 404 |
+
tokenizer.pad_token = tokenizer.unk_token
|
| 405 |
+
tokenizer.add_tokens("[SEG]")
|
| 406 |
+
seg_token_idx = tokenizer("[SEG]", add_special_tokens=False).input_ids[0]
|
| 407 |
+
|
| 408 |
+
dataset = CachedQDataset(args.cache_split, args)
|
| 409 |
+
if args.overfit_samples > 0:
|
| 410 |
+
n = min(args.overfit_samples, len(dataset))
|
| 411 |
+
dataset = Subset(dataset, list(range(n)))
|
| 412 |
+
print(f"cached overfit_samples enabled: using first {n} samples")
|
| 413 |
+
|
| 414 |
+
loader = DataLoader(
|
| 415 |
+
dataset,
|
| 416 |
+
batch_size=args.batch_size,
|
| 417 |
+
shuffle=not args.eval_only,
|
| 418 |
+
num_workers=4,
|
| 419 |
+
collate_fn=collate_cached,
|
| 420 |
+
)
|
| 421 |
+
|
| 422 |
+
model = build_model(tokenizer, seg_token_idx)
|
| 423 |
+
if args.gate_checkpoint:
|
| 424 |
+
load_referent_gate_checkpoint(model, args.gate_checkpoint)
|
| 425 |
+
if args.disable_gate:
|
| 426 |
+
zero_referent_gate(model)
|
| 427 |
+
print("disable_gate enabled: referent gate forced to identity")
|
| 428 |
+
|
| 429 |
+
if args.eval_only:
|
| 430 |
+
evaluate(model, loader)
|
| 431 |
+
return
|
| 432 |
+
|
| 433 |
+
train(model, loader)
|
| 434 |
+
if not args.skip_eval_after_train:
|
| 435 |
+
evaluate(model, loader)
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
if __name__ == "__main__":
|
| 439 |
+
main()
|