hsq12138's picture
Upload CMGUI stage3 screen-grounded summarizer checkpoint
2f0e115 verified
#!/usr/bin/env python
"""Train rich Chinese screenshot summarization models on CMGUI-style data."""
from __future__ import annotations
import argparse
import gc
import json
import math
import os
import random
import re
import shutil
import time
from collections import deque
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple
import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.nn.parallel import DistributedDataParallel
from torch.utils.checkpoint import checkpoint
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
from transformers import Adafactor, AutoImageProcessor, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, T5Tokenizer
from transformers.modeling_outputs import BaseModelOutput
# Default training configuration: data paths, model choice, sequence lengths,
# optimisation, loss weights, checkpointing/evaluation, generation and
# structured-head options.
# NOTE(review): how these defaults get overridden (CLI flags / config file) is
# not visible in this chunk — confirm against the argument parser.
DEFAULT_CONFIG = {
    # Data files. Full training uses train_rich/valid_rich; get the pipeline
    # working on the smoke files first.
    "train_file": "data/rich_cmgui/processed/smoke_train_rich.jsonl",
    "valid_file": "data/rich_cmgui/processed/smoke_valid_rich.jsonl",
    "output_dir": "runs/rich_cmgui_20260502/rich_grounded_siglip2_mt5_v1",
    "init_checkpoint": "",
    # Model. mT5-base is the default for 2x V100 32GB; switch to
    # google/mt5-small if GPU memory is tight.
    "model_variant": "full",
    "vision_model": "models/siglip2-base-patch16-224",
    "decoder_model": "google/mt5-base",
    # Input size. 384 plus vertical crops works better for small Chinese text;
    # if SigLIP2 position-embedding interpolation fails, drop to 224 first.
    "image_size": 384,
    "num_vertical_crops": 3,
    # 0 keeps all visual patch tokens. Set a cap for high-resolution/crop runs
    # to borrow Pix2Struct-style richer screenshots without quadratic fusion blowup.
    "max_visual_tokens": 0,
    "max_elements": 80,
    "max_element_tokens": 16,
    "max_context_tokens": 64,
    # "text_only" makes build_context_text emit "app:/task:/ocr:/ui:" lines;
    # any other value uses the Chinese single-line layout.
    "context_text_format": "rich",
    "context_include_screen_text": False,
    "context_screen_text_items": 32,
    "context_screen_text_dropout_rate": 0.0,
    # mean preserves old checkpoints; direct modes expose task/app tokens directly to cross-attention.
    "context_mode": "mean",
    # max_target_tokens only truncates training labels; eval_max_new_tokens
    # only controls the generation length during validation.
    "max_target_tokens": 384,
    "eval_max_new_tokens": 384,
    # Training. On two V100 32GB: torchrun --nproc_per_node=2 train_rich.py
    "batch_size": 4,
    "eval_batch_size": 0,
    "grad_accum": 8,
    "epochs": 6,
    # 0 means use epochs. Set this higher for short early-stop runs so the LR
    # schedule matches the longer run whose early checkpoint is being reproduced.
    "scheduler_epochs": 0,
    "lr_new": 1e-4,
    "lr_fusion": 5e-5,
    "lr_decoder": 1e-5,
    "lr_ui_function_head": 0.0,
    "weight_decay": 0.01,
    "optimizer_name": "adamw",
    "lr_scheduler_type": "linear",
    "warmup_ratio": 0.05,
    "fp16": True,
    "amp_dtype": "auto",
    "generation_loss_chunk_size": 32,
    "activation_checkpointing": False,
    "cuda_empty_cache_steps": 0,
    "cuda_memory_fraction": 0.0,
    "decoder_gradient_checkpointing": False,
    "vision_gradient_checkpointing": False,
    "freeze_decoder": False,
    "freeze_vision": True,
    "unfreeze_vision_last_ratio": 0.3,
    # Loss weights. Generation is the main objective; the other losses are
    # used to constrain evidence grounding and output structure.
    "evidence_loss_weight": 0.2,
    "section_loss_weight": 0.1,
    "numeric_loss_weight": 0.1,
    "ui_function_loss_weight": 0.0,
    "search_function_loss_weight": 0.0,
    "search_function_pos_weight": 1.0,
    # Checkpointing and evaluation.
    "save_every_steps": 1000,
    "save_checkpoints": True,
    "eval_every_steps": 0,
    "model_selection_metric": "rich_quality_score",
    "model_selection_mode": "max",
    "early_stopping_patience": 0,
    "early_stopping_min_delta": 0.0,
    "max_train_samples": 0,
    "max_valid_samples": 800,
    "num_beams": 4,
    "generation_no_repeat_ngram_size": 0,
    "generation_repetition_penalty": 1.0,
    "generation_min_new_tokens": 0,
    "generation_block_extra_ids": False,
    "generation_block_title_prefix": False,
    "generation_force_json_start": False,
    "context_summary_repair": False,
    "canonicalize_targets": False,
    "target_schema": "zh",
    "task_intent_context": False,
    "drop_bare_search_functions": False,
    # "decoder" keeps decoder-generated functions untouched (see
    # apply_structured_function_predictions); other modes use the heads' scores.
    "structured_function_mode": "decoder",
    "structured_function_threshold": 0.5,
    "structured_search_threshold": 0.5,
    "structured_max_functions": 12,
    "structured_strict_search_candidates": False,
    "structured_evidence_mode": "decoder",
    "structured_evidence_threshold": 0.5,
    "structured_max_evidence": 8,
    "structured_evidence_fallback_top1": True,
    # Vision-memory options. Defaults preserve old checkpoints/experiments.
    "direct_visual_tokens": False,
    "direct_element_tokens": False,
    "direct_context_passthrough": False,
    "include_pooled_memory": True,
    "native_context_forward": False,
    "disable_vision": False,
    "init_resize_mismatched_non_decoder": False,
    "grad_clip_strategy": "global",
    "max_grad_norm": 1.0,
    "function_signal_to_decoder": False,
    "function_signal_scale": 1.0,
    "search_signal_to_decoder": False,
    "search_signal_scale": 1.0,
    "visual_memory_scale": 1.0,
    "element_memory_scale": 1.0,
    "pooled_memory_scale": 1.0,
    "decoder_memory_scale": 1.0,
    "data_parallel": False,
    "strict_data_checks": True,
    "max_target_truncation_rate": 0.01,
    "seed": 20260502,
    "num_workers": 4,
}
# English names of the four list-valued target sections (the same source keys
# that target_to_text reads from a target dict).
SECTION_NAMES = ["visible_text", "interaction_data", "ui_functions", "key_ui_clues"]
def read_jsonl(path: Path) -> Iterable[Dict[str, Any]]:
    """Lazily yield one parsed JSON value per non-blank line of a UTF-8 JSONL file.

    Each line is stripped before parsing and whitespace-only lines are skipped.
    The file handle is closed when the generator finishes or is closed.
    """
    with path.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            stripped = raw_line.strip()
            if not stripped:
                continue
            yield json.loads(stripped)
def write_json(path: Path, obj: Dict[str, Any]) -> None:
    """Write ``obj`` to ``path`` as 2-space-indented UTF-8 JSON.

    Missing parent directories are created first. Non-ASCII characters are
    written verbatim rather than ``\\u``-escaped.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(obj, ensure_ascii=False, indent=2)
    path.write_text(serialized, encoding="utf-8")
def append_jsonl(path: Path, obj: Dict[str, Any]) -> None:
    """Append ``obj`` as one compact JSON line to ``path``.

    Parent directories are created if needed. Lines always end with a bare
    ``"\\n"`` regardless of platform, and non-ASCII text is kept verbatim.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("a", encoding="utf-8", newline="\n") as sink:
        sink.write(f"{json.dumps(obj, ensure_ascii=False)}\n")
def set_seed(seed: int) -> None:
    """Seed every RNG the script uses: Python ``random``, NumPy and PyTorch (CPU + all CUDA devices)."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
def normalize_model_reference(value: Any) -> str:
    """Normalize a model path or hub id so equivalent references compare equal.

    Falsy values become ``""``, backslashes become forward slashes, and runs of
    slashes collapse to one unless the string looks like a URL (contains
    ``"://"``), in which case it is returned with only the slash conversion.
    """
    reference = str(value or "").replace("\\", "/")
    if "://" in reference:
        return reference
    return re.sub(r"/+", "/", reference)
def init_distributed() -> Tuple[bool, int, int, int]:
    """Initialise torch.distributed from torchrun-style environment variables.

    Returns ``(is_distributed, rank, world_size, local_rank)``. If either
    ``RANK`` or ``WORLD_SIZE`` is absent, nothing is initialised and
    ``(False, 0, 1, 0)`` is returned. Otherwise an NCCL process group is
    created and the current CUDA device is set to ``LOCAL_RANK`` (default 0).
    """
    env = os.environ
    if not all(name in env for name in ("RANK", "WORLD_SIZE")):
        return False, 0, 1, 0
    rank, world_size = int(env["RANK"]), int(env["WORLD_SIZE"])
    local_rank = int(env.get("LOCAL_RANK", 0))
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    return True, rank, world_size, local_rank
def is_main(rank: int) -> bool:
    """Return True when ``rank`` is the main (rank 0) process."""
    main_rank = 0
    return rank == main_rank
def unwrap_parallel_model(model: nn.Module) -> nn.Module:
    """Strip a DistributedDataParallel / DataParallel wrapper, if present.

    Wrapped models return their ``.module``; anything else is returned as-is.
    """
    is_wrapped = isinstance(model, (DistributedDataParallel, nn.DataParallel))
    return model.module if is_wrapped else model
def safe_text(value: Any) -> str:
    """Coerce ``value`` to a trimmed single-line string.

    ``None`` becomes ``""``; any other value is ``str()``-ed, every whitespace
    run (including newlines/tabs) collapses to one space, and the ends are trimmed.
    """
    if value is None:
        return ""
    collapsed = re.sub(r"\s+", " ", str(value))
    return collapsed.strip()
def target_to_text(target: Dict[str, Any], target_schema: str = "zh") -> str:
    """Serialize a rich target dict into the decoder's training text.

    The (case-insensitive) ``target_schema`` selects the format:
    - summary schemas: only the cleaned ``summary_zh``;
    - summary+visible schemas: three natural-text lines;
    - other natural-text schemas: five natural-text lines;
    - "alias"/"aliases"/"en"/"english": compact JSON with English keys;
    - anything else (default "zh"): compact JSON with Chinese keys.
    JSON lists are truncated to 16 visible texts and 12 items for each other section.
    """
    schema = str(target_schema or "zh").lower()
    if target_schema_is_summary(schema):
        return safe_text(target.get("summary_zh"))
    if target_schema_is_summary_visible(schema):
        return target_to_summary_visible_text(target)
    if target_schema_is_natural_text(schema):
        return target_to_natural_text(target)
    if schema in {"alias", "aliases", "en", "english"}:
        labels = ("summary_zh", "visible_text", "interaction_data", "ui_functions", "key_ui_clues")
    else:
        labels = ("画面总结", "可见文字", "互动数据", "功能入口", "关键证据")
    section_limits = (
        ("visible_text", 16),
        ("interaction_data", 12),
        ("ui_functions", 12),
        ("key_ui_clues", 12),
    )
    payload = {labels[0]: safe_text(target.get("summary_zh"))}
    for label, (source_key, limit) in zip(labels[1:], section_limits):
        payload[label] = target.get(source_key, [])[:limit]
    return json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
def target_schema_is_summary(target_schema: str) -> bool:
    """Return True for summary-only target schemas (case-insensitive; falsy -> False)."""
    normalized = str(target_schema or "").lower()
    return normalized in ("summary", "summary_zh", "summary-only", "summary_only")
def target_schema_is_natural_text(target_schema: str) -> bool:
    """Return True for schemas rendered as natural Chinese text lines (case-insensitive).

    The summary+visible schemas are also members; target_to_text checks those first.
    """
    natural_schemas = frozenset(
        (
            "natural_zh",
            "rich_text_zh",
            "zh_text",
            "text_zh",
            "summary_visible_zh",
            "natural_summary_visible_zh",
        )
    )
    return str(target_schema or "").lower() in natural_schemas
def target_schema_is_summary_visible(target_schema: str) -> bool:
    """Return True for the summary + visible-text + interaction schemas (case-insensitive)."""
    normalized = str(target_schema or "").lower()
    return normalized in ("summary_visible_zh", "natural_summary_visible_zh")
def format_evidence_suffix(values: Any) -> str:
    """Render evidence ids as a trailing "(证据:E01、E02)" annotation.

    Ids are cleaned and de-duplicated by ``dedupe_texts`` (at most 8 kept);
    an empty string is returned when nothing survives.
    """
    ids = dedupe_texts(values or [], max_items=8)
    return f"(证据:{'、'.join(ids)})" if ids else ""
def format_natural_entry(value: Any) -> str:
    """Render one target list entry as natural text.

    Dict entries become ``name=value`` followed by an evidence suffix. The name
    falls back to ``text`` then ``label``, and either half may be missing.
    Every other value is just passed through ``safe_text``.
    """
    if not isinstance(value, dict):
        return safe_text(value)
    label = safe_text(value.get("name") or value.get("text") or value.get("label"))
    detail = safe_text(value.get("value"))
    if label and detail:
        label = f"{label}={detail}"
    elif detail:
        # Only reachable when the label is empty: use the value alone.
        label = detail
    suffix = format_evidence_suffix(value.get("evidence_ids", []))
    return safe_text(label + suffix)
def format_natural_list(values: Any, max_items: int = 16) -> str:
    """Join up to ``max_items`` rendered entries with "、".

    Empty renderings and duplicates are dropped, keeping first-occurrence order.
    Returns "无" ("none") when ``values`` is not a list or nothing remains.
    """
    if not isinstance(values, list):
        return "无"
    rendered = (format_natural_entry(entry) for entry in values[:max_items])
    unique_in_order = list(dict.fromkeys(text for text in rendered if text))
    return "、".join(unique_in_order) or "无"
def target_to_natural_text(target: Dict[str, Any]) -> str:
    """Render a target as five "label:content" lines.

    The lines are summary, visible text (<=16), interaction data (<=12),
    function entries (<=12) and key evidence (<=12).
    """
    sections = (
        ("画面总结", safe_text(target.get("summary_zh"))),
        ("可见文字", format_natural_list(target.get("visible_text", []), max_items=16)),
        ("互动数据", format_natural_list(target.get("interaction_data", []), max_items=12)),
        ("功能入口", format_natural_list(target.get("ui_functions", []), max_items=12)),
        ("关键证据", format_natural_list(target.get("key_ui_clues", []), max_items=12)),
    )
    return "\n".join(f"{label}:{content}" for label, content in sections)
def target_to_summary_visible_text(target: Dict[str, Any]) -> str:
    """Render only the summary, visible-text (<=16) and interaction (<=12) lines of a target."""
    sections = (
        ("画面总结", safe_text(target.get("summary_zh"))),
        ("可见文字", format_natural_list(target.get("visible_text", []), max_items=16)),
        ("互动数据", format_natural_list(target.get("interaction_data", []), max_items=12)),
    )
    return "\n".join(f"{label}:{content}" for label, content in sections)
# Top-level keys of the Chinese JSON target schema, in the order target_to_text emits them.
JSONISH_OUTPUT_KEYS = (
    "画面总结",
    "可见文字",
    "互动数据",
    "功能入口",
    "关键证据",
)
# Top-level keys plus the keys used inside function-entry objects; these are
# the keys repair_generated_json_text re-quotes.
JSONISH_ALL_KEYS = (*JSONISH_OUTPUT_KEYS, "name", "value", "evidence_ids")
def normalize_evidence_id(value: Any) -> str:
    """Canonicalize a whole-string evidence id such as "e 0 1" to "E01".

    The entire (safe_text-cleaned) value must be an E/e followed by 2-8
    digits, optionally separated by whitespace. Any other value is returned
    as its safe_text form.
    """
    text = safe_text(value)
    found = re.fullmatch(r"[Ee]\s*((?:\d\s*){2,8})", text)
    if found is None:
        return text
    return "E" + re.sub(r"\s+", "", found.group(1))
def repair_generated_json_text(text: str) -> str:
    """Heuristically repair common damage in generated JSON-like text.

    Steps, applied in order:
    1. Collapse whitespace via ``safe_text`` and drop everything before the
       first "{" (if any).
    2. Re-join split key spellings like ``evidence _ ids`` into ``evidence_ids``.
    3. Re-join spaced evidence ids (e.g. ``E 0 1`` or ``e01``) into ``E01``.
    4. For every key in ``JSONISH_ALL_KEYS`` that follows "{", "[" or ",",
       rewrite it as a properly double-quoted ``"key":``.

    The result is not guaranteed to be valid JSON; safe_json_loads_with_repair
    re-parses it and falls back to field-by-field salvage.
    """
    candidate = safe_text(text)
    start = candidate.find("{")
    if start >= 0:
        # Discard any preamble the decoder emitted before the JSON object.
        candidate = candidate[start:]
    candidate = re.sub(r"evidence\s*_\s*ids", "evidence_ids", candidate)
    # Spaced/lower-case evidence ids -> canonical "E" + digits.
    candidate = re.sub(
        r"\b[Ee]\s*((?:\d\s*){2,8})\b",
        lambda match: "E" + re.sub(r"\s+", "", match.group(1)),
        candidate,
    )
    for key in JSONISH_ALL_KEYS:
        # Pass 1: the key has a closing quote (`"key":` or `key":`).
        candidate = re.sub(
            r"([\{\[,])\s*\"?\s*" + re.escape(key) + r"\s*\"\s*:",
            lambda match, key=key: f'{match.group(1)}"{key}":',
            candidate,
        )
        # Pass 2: the key has no closing quote (`key:` or `"key:`). Already
        # repaired keys no longer match, because a quote sits before the colon.
        candidate = re.sub(
            r"([\{\[,])\s*\"?\s*" + re.escape(key) + r"\s*:",
            lambda match, key=key: f'{match.group(1)}"{key}":',
            candidate,
        )
    return candidate
def json_object_from_text(text: str) -> Tuple[Optional[Dict[str, Any]], bool]:
    """Parse a JSON object out of ``text``.

    The whole stripped string is tried first. Only if that raises a decode
    error is the span from the first "{" to the last "}" tried. The result is
    ``(obj, True)`` only when the parsed value is a dict; otherwise it is
    ``(None, False)``.
    """
    stripped = text.strip()
    if not stripped:
        return None, False
    try:
        parsed = json.loads(stripped)
    except json.JSONDecodeError:
        open_at, close_at = stripped.find("{"), stripped.rfind("}")
        if open_at < 0 or close_at <= open_at:
            return None, False
        try:
            parsed = json.loads(stripped[open_at : close_at + 1])
        except json.JSONDecodeError:
            return None, False
    if isinstance(parsed, dict):
        return parsed, True
    return None, False
def jsonish_field_payload(text: str, key: str, following_keys: Iterable[str]) -> str:
    """Return the raw text after ``"key":`` up to the earliest later ``"other":`` key.

    The cut point is the earliest match among ``following_keys`` (including an
    optional preceding comma). The slice is stripped, trailing commas are
    removed, and ``""`` is returned if ``key`` is absent.
    """
    header = re.search(r'"' + re.escape(key) + r'"\s*:', text)
    if header is None:
        return ""
    tail = text[header.end() :]
    hits = (re.search(r',?\s*"' + re.escape(nxt) + r'"\s*:', tail) for nxt in following_keys)
    cut = min((hit.start() for hit in hits if hit), default=len(tail))
    return tail[:cut].strip().rstrip(",")
def clean_jsonish_string(value: Any) -> str:
    """Strip stray surrounding double quotes from a JSON-ish fragment and unescape quotes.

    Leading and trailing quote characters are peeled off one at a time, with
    whitespace re-trimmed after each. Escaped quotes are then turned into
    plain quotes and the result is normalized with ``safe_text``.
    """
    text = safe_text(value).strip()
    while text[:1] == '"':
        text = text[1:].strip()
    while text[-1:] == '"':
        text = text[:-1].strip()
    return safe_text(text.replace('\\"', '"'))
def dedupe_texts(values: Iterable[Any], max_items: int = 64) -> List[str]:
    """Clean, de-duplicate and cap a sequence of text values.

    Each value goes through ``normalize_evidence_id`` and then
    ``clean_jsonish_string``; empty and repeated results are skipped while
    first-occurrence order is kept. Consumption stops once ``max_items``
    results are collected (the limit is checked after each append).
    """
    unique: List[str] = []
    seen: set = set()
    for raw in values:
        cleaned = clean_jsonish_string(normalize_evidence_id(raw))
        if cleaned and cleaned not in seen:
            seen.add(cleaned)
            unique.append(cleaned)
            if len(unique) >= max_items:
                break
    return unique
def jsonish_string_list(payload: str, max_items: int = 64) -> List[str]:
    """Recover a list of strings from a possibly malformed JSON array payload.

    The payload is narrowed to the text between the first "[" and the last "]"
    when those are present. Double-quoted strings (escape-aware) are then
    collected. If there are none, bare ``E<digits>`` evidence ids are
    collected instead. The result is cleaned and capped by ``dedupe_texts``.
    """
    body = payload.strip()
    if "[" in body:
        body = body.partition("[")[2]
    if "]" in body:
        body = body.rpartition("]")[0]
    quoted = re.findall(r'"([^"\\]*(?:\\.[^"\\]*)*)"', body)
    candidates = quoted or re.findall(r"\bE\d{2,8}\b", body)
    return dedupe_texts(candidates, max_items=max_items)
def normalize_function_entries(values: Any) -> List[Any]:
    """Clean a parsed "功能入口" list.

    Dict entries become ``{"name", "evidence_ids"}``, with names cleaned and
    at most 8 de-duplicated ids. Non-dict entries become cleaned strings.
    Entries whose cleaned name is empty are dropped, and non-list input
    yields ``[]``.
    """
    if not isinstance(values, list):
        return []
    entries: List[Any] = []
    for entry in values:
        if not isinstance(entry, dict):
            label = clean_jsonish_string(entry)
            if label:
                entries.append(label)
            continue
        label = clean_jsonish_string(entry.get("name"))
        ids = dedupe_texts(entry.get("evidence_ids", []) or [], max_items=8)
        if label:
            entries.append({"name": label, "evidence_ids": ids})
    return entries
def jsonish_function_entries(payload: str) -> List[Any]:
    """Recover function entries from a (possibly malformed) "功能入口" payload.

    If the payload looks like a JSON array and parses into at least one usable
    entry, that result is returned. Otherwise every innermost ``{...}`` chunk
    (or the whole payload when there is none) is scanned with regexes. Each
    chunk yields a quoted ``"name"`` value and its ``E<digits>`` ids (max 8).
    """
    body = payload.strip()
    if body.startswith("[") and body.endswith("]"):
        try:
            strict_entries = normalize_function_entries(json.loads(body))
        except json.JSONDecodeError:
            strict_entries = []
        if strict_entries:
            return strict_entries
    name_pattern = r'"name"\s*:\s*"([^"\\]*(?:\\.[^"\\]*)*)"'
    chunks = re.findall(r"\{([^{}]*)\}", body) or [body]
    recovered: List[Any] = []
    for chunk in chunks:
        name_hit = re.search(name_pattern, chunk)
        if name_hit is None:
            continue
        label = clean_jsonish_string(name_hit.group(1))
        ids = dedupe_texts(re.findall(r"\bE\d{2,8}\b", chunk), max_items=8)
        if label:
            recovered.append({"name": label, "evidence_ids": ids})
    return recovered
def jsonish_prediction_object(text: str) -> Optional[Dict[str, Any]]:
    """Salvage the five Chinese schema fields from broken JSON, field by field.

    The text is first passed through ``repair_generated_json_text``. Each
    field's payload then runs from its key up to the next later schema key.
    Caps are 32 visible texts, 16 interaction items and 16 evidence ids.
    Returns ``None`` when every field comes back empty.
    """
    repaired = repair_generated_json_text(text)
    order = ("画面总结", "可见文字", "互动数据", "功能入口", "关键证据")
    raw = {key: jsonish_field_payload(repaired, key, order[pos + 1 :]) for pos, key in enumerate(order)}
    salvaged = {
        "画面总结": clean_jsonish_string(raw["画面总结"]),
        "可见文字": jsonish_string_list(raw["可见文字"], max_items=32),
        "互动数据": jsonish_string_list(raw["互动数据"], max_items=16),
        "功能入口": jsonish_function_entries(raw["功能入口"]),
        "关键证据": jsonish_string_list(raw["关键证据"], max_items=16),
    }
    if not any(salvaged.values()):
        return None
    return salvaged
# Section labels of the natural-text output format, in emission order; each
# section's payload ends where the next later label begins.
NATURAL_OUTPUT_KEYS = (
    "画面总结",
    "可见文字",
    "互动数据",
    "功能入口",
    "关键证据",
)
def has_natural_title_prefix(text: str) -> bool:
    """Return True when ``text`` starts with a "Title"/"title"/"标题" label.

    The label must be followed by a colon or by whitespace; leading whitespace
    before it is allowed.
    """
    candidate = str(text or "")
    return bool(re.match(r"^\s*(?:Title|title|标题)(?:\s*[::]|\s+)", candidate))
def strip_natural_title_prefix(text: str) -> Tuple[str, bool]:
    """Remove a leading "Title:" style label from natural-text output.

    Line endings are normalized to "\\n" and the text is stripped. Returns
    ``(text, stripped_flag)``:
    - no title prefix -> unchanged text, False;
    - no schema label anywhere -> the prefix is simply removed;
    - first schema label is the summary -> everything from it onward;
    - otherwise the words between the title and the first label become a
      synthesized summary line, if there are any.
    """
    title_pattern = r"^\s*(?:Title|title|标题)(?:\s*[::]|\s+)"
    raw = str(text or "").strip().replace("\r\n", "\n").replace("\r", "\n")
    if not raw or not has_natural_title_prefix(raw):
        return raw, False
    prefix = re.match(title_pattern, raw)
    if prefix is None:
        return raw, False
    label_pattern = r"(" + "|".join(map(re.escape, NATURAL_OUTPUT_KEYS)) + r")\s*[::]"
    first_label = re.search(label_pattern, raw)
    if first_label is None:
        return re.sub(title_pattern, "", raw).strip(), True
    remainder = raw[first_label.start() :].strip()
    if first_label.group(1) == "画面总结":
        return remainder, True
    title_words = safe_text(raw[prefix.end() : first_label.start()])
    if not title_words:
        return remainder, True
    return f"画面总结:{title_words}\n{remainder}", True
def natural_field_payload(text: str, key: str, following_keys: Iterable[str]) -> str:
    """Return the stripped text after ``key:`` up to the earliest ``later_key:`` label.

    Both ASCII and full-width colons are accepted (as the pattern allows).
    Returns ``""`` when ``key`` is absent.
    """
    opener = re.search(re.escape(key) + r"\s*[::]", text)
    if not opener:
        return ""
    tail = text[opener.end() :]
    cut = len(tail)
    for later_key in following_keys:
        hit = re.search(re.escape(later_key) + r"\s*[::]", tail)
        if hit and hit.start() < cut:
            cut = hit.start()
    return tail[:cut].strip()
def split_natural_items(payload: str, max_items: int = 64) -> List[str]:
    """Split a natural-text section payload into cleaned, de-duplicated items.

    Placeholder payloads meaning "none" give ``[]``. The primary separators are
    newlines, semicolons, "|" and "、". A "、" directly after a digit is not
    split, so id lists like "E01、E02" stay together. If that yields a single
    piece, commas are tried instead.
    """
    text = safe_text(payload)
    if not text or text in ("无", "暂无", "没有", "[]"):
        return []
    pieces = re.split(r"[\n;;|]+|(?<!\d)、", text)
    if len(pieces) == 1:
        pieces = re.split(r"\s*,\s*", text)
    return dedupe_texts(pieces, max_items=max_items)
# Trailing evidence annotation on a function entry, e.g. "(证据:E01、E02)".
# Accepts ASCII or full-width parentheses/colon, with the opening and closing
# parens both optional, and removes everything from there to the end of the item.
# This replaces a pattern starting with "(?证据", which `re` rejects as an
# unknown "(?" extension and which made every non-empty payload raise re.error.
_FUNCTION_EVIDENCE_SUFFIX_RE = re.compile(r"\s*[(\uff08]?\s*证据\s*[:\uff1a].*$")


def _strip_function_evidence_suffix(item: str) -> str:
    """Return ``item`` without its trailing "(证据:...)" annotation, stripped."""
    return _FUNCTION_EVIDENCE_SUFFIX_RE.sub("", item).strip()


def natural_function_entries(payload: str) -> List[Any]:
    """Parse the natural-text "功能入口" payload into ``{"name", "evidence_ids"}`` dicts.

    The payload is split with ``split_natural_items`` (at most 32 items). For
    each item, ``E<digits>`` ids are collected (max 8, de-duplicated) and the
    evidence annotation is removed to form the name. Items whose name ends up
    empty are dropped.
    """
    output: List[Any] = []
    for item in split_natural_items(payload, max_items=32):
        evidence_ids = dedupe_texts(re.findall(r"\bE\d{2,8}\b", normalize_evidence_id(item)), max_items=8)
        name = _strip_function_evidence_suffix(item)
        if name:
            output.append({"name": name, "evidence_ids": evidence_ids})
    return output
def natural_prediction_from_text(text: str) -> Dict[str, Any]:
    """Parse natural-text model output into the five-key Chinese prediction dict.

    Any "Title:" prefix is handled first. Each section runs from its label to
    the next later label. A missing summary falls back to the first line or
    sentence of the text. Evidence ids come from the key-evidence section; if
    there are none, they are collected from the function entries.
    """
    body, _ = strip_natural_title_prefix(text)
    fields = {
        key: natural_field_payload(body, key, NATURAL_OUTPUT_KEYS[pos + 1 :])
        for pos, key in enumerate(NATURAL_OUTPUT_KEYS)
    }
    summary = safe_text(fields.get("画面总结")) or safe_text(re.split(r"[\n。]", body, maxsplit=1)[0])
    visible_text = split_natural_items(fields.get("可见文字", ""), max_items=32)
    interaction_data = split_natural_items(fields.get("互动数据", ""), max_items=16)
    functions = natural_function_entries(fields.get("功能入口", ""))
    evidence_source = normalize_evidence_id(fields.get("关键证据", ""))
    evidence_ids = dedupe_texts(re.findall(r"\bE\d{2,8}\b", evidence_source), max_items=16)
    if not evidence_ids:
        evidence_ids = dedupe_texts(
            eid for entry in functions if isinstance(entry, dict) for eid in entry.get("evidence_ids", [])
        )
    return {
        "画面总结": summary,
        "可见文字": visible_text,
        "互动数据": interaction_data,
        "功能入口": functions,
        "关键证据": evidence_ids,
    }
def safe_json_loads_with_repair(text: str) -> Tuple[Optional[Dict[str, Any]], bool, bool, bool]:
    """Parse model output into a prediction dict, escalating through repairs.

    Returns ``(obj, ok, was_repaired, strict_ok)``:
    1. strict JSON object          -> ``(obj, True, False, True)``
    2. JSON after text repair      -> ``(obj, True, <repaired text differs>, False)``
    3. field-by-field salvage      -> ``(obj, True, True, False)``
    4. nothing recoverable         -> ``(None, False, False, False)``
    """
    strict_obj, strict_ok = json_object_from_text(text)
    if strict_ok:
        return strict_obj, True, False, True
    repaired_text = repair_generated_json_text(text)
    repaired_obj, repaired_ok = json_object_from_text(repaired_text)
    if repaired_ok:
        text_changed = repaired_text.strip() != str(text or "").strip()
        return repaired_obj, True, text_changed, False
    salvaged = jsonish_prediction_object(repaired_text)
    if salvaged is None:
        return None, False, False, False
    return salvaged, True, True, False
def safe_json_loads(text: str) -> Tuple[Optional[Dict[str, Any]], bool]:
    """Return only ``(obj, ok)`` from ``safe_json_loads_with_repair``."""
    result = safe_json_loads_with_repair(text)
    return result[0], result[1]
def load_seq2seq_tokenizer(model_name_or_path: str, model_hint: str = ""):
    """Load the tokenizer for a seq2seq decoder checkpoint.

    A slow ``T5Tokenizer`` is used in two cases: the path is a directory
    containing ``spiece.model``, or "t5"/"mt5" appears in the lower-cased
    path + hint. If the path is a directory without ``spiece.model`` but
    ``model_hint`` is a directory that has one, loading happens from the hint
    instead. Anything else falls back to the slow ``AutoTokenizer``.
    """
    combined_name = f"{model_name_or_path} {model_hint}".lower()
    primary = Path(str(model_name_or_path))
    primary_has_spm = primary.is_dir() and (primary / "spiece.model").exists()
    is_t5_family = primary_has_spm or "mt5" in combined_name or "t5" in combined_name
    if not is_t5_family:
        return AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)
    source = str(model_name_or_path)
    if model_hint:
        hint_dir = Path(str(model_hint))
        primary_lacks_spm = primary.is_dir() and not (primary / "spiece.model").exists()
        if primary_lacks_spm and (hint_dir / "spiece.model").exists():
            source = str(hint_dir)
    return T5Tokenizer.from_pretrained(source, fix_mistral_regex=True)
def char_lcs(a: str, b: str) -> int:
    """Length of the longest common character subsequence of ``a`` and ``b``.

    Standard dynamic programme in O(len(a) * len(b)) time, keeping only the
    previous row. Returns 0 if either string is empty.
    """
    if not a or not b:
        return 0
    width = len(b) + 1
    above = [0] * width
    for left_char in a:
        row = [0] * width
        for j in range(1, width):
            if left_char == b[j - 1]:
                row[j] = above[j - 1] + 1
            else:
                row[j] = max(row[j - 1], above[j])
        above = row
    return above[-1]
def rouge_l_char(pred: str, ref: str) -> float:
    """Character-level ROUGE-L F1 between ``pred`` and ``ref`` (0.0 if either is empty)."""
    if not (pred and ref):
        return 0.0
    overlap = char_lcs(pred, ref)
    precision = overlap / max(1, len(pred))
    recall = overlap / max(1, len(ref))
    denominator = precision + recall
    return 0.0 if denominator == 0 else 2 * precision * recall / denominator
def extract_summary(obj: Optional[Dict[str, Any]]) -> str:
    """Return the cleaned summary from a prediction/target dict ("" if none).

    Looks up "画面总结", then "summary_zh", then "summary", taking the first
    truthy value.
    """
    if not obj:
        return ""
    candidate = obj.get("画面总结") or obj.get("summary_zh") or obj.get("summary")
    return safe_text(candidate)
def extract_evidence_ids(obj: Optional[Dict[str, Any]]) -> List[str]:
    """Collect normalized evidence ids from a prediction/target's evidence list.

    The list is read from "关键证据", "key_ui_clues" or "evidence". String
    entries are normalized directly, dict entries contribute their
    ``evidence_ids``, and empty results are dropped. Duplicates are kept.
    """
    if not obj:
        return []
    entries = obj.get("关键证据") or obj.get("key_ui_clues") or obj.get("evidence") or []
    collected: List[str] = []
    if not isinstance(entries, list):
        return collected
    for entry in entries:
        if isinstance(entry, str):
            raw_ids = [entry]
        elif isinstance(entry, dict):
            raw_ids = entry.get("evidence_ids", [])
        else:
            continue
        collected.extend(eid for eid in map(normalize_evidence_id, raw_ids) if eid)
    return collected
def extract_function_entries(obj: Optional[Dict[str, Any]]) -> List[Any]:
    """Return the raw function-entry list ("功能入口" or "ui_functions"), or [] if absent/non-list."""
    if not isinstance(obj, dict):
        return []
    entries = obj.get("功能入口") or obj.get("ui_functions") or []
    if isinstance(entries, list):
        return entries
    return []
def extract_function_names(obj: Optional[Dict[str, Any]]) -> List[str]:
    """Return the non-empty cleaned names of all function entries (dict "name" or the string itself)."""

    def entry_name(entry: Any) -> str:
        raw = entry.get("name") if isinstance(entry, dict) else entry
        return safe_text(raw)

    return [name for name in map(entry_name, extract_function_entries(obj)) if name]
def extract_function_evidence_ids(target: Dict[str, Any]) -> List[str]:
    """Collect normalized, non-empty evidence ids from every dict function entry (duplicates kept)."""
    collected: List[str] = []
    for entry in extract_function_entries(target):
        if not isinstance(entry, dict):
            continue
        for raw_id in entry.get("evidence_ids", []):
            normalized = normalize_evidence_id(raw_id)
            if normalized:
                collected.append(normalized)
    return collected
def extract_named_function_evidence_ids(target: Dict[str, Any], keyword: str) -> List[str]:
    """Like extract_function_evidence_ids, but only for dict entries whose name contains ``keyword``."""
    collected: List[str] = []
    for entry in extract_function_entries(target):
        if not isinstance(entry, dict) or keyword not in safe_text(entry.get("name")):
            continue
        for raw_id in entry.get("evidence_ids", []):
            normalized = normalize_evidence_id(raw_id)
            if normalized:
                collected.append(normalized)
    return collected
def has_search_function(obj: Optional[Dict[str, Any]]) -> bool:
    """Return True if any function entry name mentions "搜索" (search)."""
    for name in extract_function_names(obj):
        if "搜索" in name:
            return True
    return False
def count_search_functions(obj: Optional[Dict[str, Any]]) -> int:
    """Count function entries whose name mentions "搜索" (search)."""
    matching = [name for name in extract_function_names(obj) if "搜索" in name]
    return len(matching)
def count_bare_search_functions(obj: Optional[Dict[str, Any]]) -> int:
    """Count function entries named exactly "搜索" (no qualifier)."""
    total = 0
    for name in extract_function_names(obj):
        if is_bare_search_function_name(name):
            total += 1
    return total
def normalized_prediction_obj(obj: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Map a prediction with Chinese or English keys onto the five Chinese keys.

    For each field the first truthy alias wins. List fields default to ``[]``,
    and the summary is cleaned with ``safe_text``. Non-dict input gives ``None``.
    """
    if not isinstance(obj, dict):
        return None
    summary = obj.get("画面总结") or obj.get("summary_zh") or obj.get("summary")
    visible = obj.get("可见文字") or obj.get("visible_text") or []
    interactions = obj.get("互动数据") or obj.get("interaction_data") or []
    functions = obj.get("功能入口") or obj.get("ui_functions") or []
    evidence = obj.get("关键证据") or obj.get("key_ui_clues") or obj.get("evidence") or []
    return {
        "画面总结": safe_text(summary),
        "可见文字": visible,
        "互动数据": interactions,
        "功能入口": functions,
        "关键证据": evidence,
    }
def element_id(item: Dict[str, Any]) -> str:
    """Return the element's cleaned "id", falling back to "ocr_id" when "id" is falsy."""
    primary = item.get("id")
    return safe_text(primary if primary else item.get("ocr_id"))
def visible_text_from_row(row: Dict[str, Any], max_items: int = 16) -> List[str]:
    """Gather distinct on-screen texts from a row's OCR items, then its UI items.

    Skipped texts: empty ones, duplicates, bare evidence ids, single
    non-CJK/non-alphanumeric characters, and pure punctuation/separator
    strings. Kept texts are truncated to 80 chars. Collection stops once
    ``max_items`` texts are gathered (checked after each append).
    """

    def items_in_order():
        for group_name in ("ocr_items", "ui_items"):
            yield from row.get(group_name, []) or []

    collected: List[str] = []
    seen = set()
    for item in items_in_order():
        text = safe_text(item.get("text"))
        if not text or text in seen or re.fullmatch(r"E\d+", text):
            continue
        if len(text) == 1 and not re.search(r"[\u4e00-\u9fffA-Za-z0-9]", text):
            continue
        if re.fullmatch(r"[\s\-_.::/]+", text):
            continue
        collected.append(text[:80])
        # De-duplication is keyed on the untruncated text.
        seen.add(text)
        if len(collected) >= max_items:
            break
    return collected
def visible_texts_from_items(items: Iterable[Dict[str, Any]], max_items: int = 16) -> List[str]:
    """Gather distinct on-screen texts from one item list.

    Uses the same filters and the 80-char truncation as visible_text_from_row,
    and stops once ``max_items`` texts are collected.
    """
    picked: List[str] = []
    seen = set()
    for item in items or []:
        text = safe_text(item.get("text"))
        rejected = (
            not text
            or text in seen
            or re.fullmatch(r"E\d+", text) is not None
            or (len(text) == 1 and not re.search(r"[\u4e00-\u9fffA-Za-z0-9]", text))
            or re.fullmatch(r"[\s\-_.::/]+", text) is not None
        )
        if rejected:
            continue
        picked.append(text[:80])
        seen.add(text)
        if len(picked) >= max_items:
            break
    return picked
def maybe_dropout_texts(values: List[str], dropout_rate: float) -> List[str]:
    """Randomly drop each text with probability ``dropout_rate`` using the ``random`` module.

    ``values`` itself is returned unchanged when the rate is <= 0 or the list
    is empty. When every item would be dropped, one randomly chosen item is
    kept, so non-empty input never becomes empty.
    """
    if dropout_rate <= 0.0 or not values:
        return values
    survivors = list(filter(lambda _text: random.random() >= dropout_rate, values))
    return survivors if survivors else [random.choice(values)]
def build_context_text(row: Dict[str, Any], args: argparse.Namespace, screen_text_dropout_rate: float = 0.0) -> str:
    """Build the textual context (app + task, optionally screen text) fed to the model.

    With ``context_text_format == "text_only"`` the context is newline-separated
    "app:", "task:" and optional "ocr:"/"ui:" lines. Any other format gives a
    single Chinese line, optionally followed by the row's combined screen text.
    Screen text is only added when ``context_include_screen_text`` is set, and
    it is randomly thinned with ``maybe_dropout_texts``.
    """
    app_text = safe_text(row.get("app"))
    instruction_text = safe_text(row.get("instruction"))
    context_format = str(getattr(args, "context_text_format", "rich") or "rich").lower()
    include_screen = bool(getattr(args, "context_include_screen_text", False))
    if context_format == "text_only":
        lines = [f"app: {app_text}", f"task: {instruction_text}"]
        if include_screen:
            limit = int(getattr(args, "context_screen_text_items", 32) or 32)
            ocr_text = visible_texts_from_items(row.get("ocr_items") or [], max_items=limit)
            ui_text = visible_texts_from_items(row.get("ui_items") or [], max_items=limit)
            # Dropout order (OCR first, then UI) fixes the RNG consumption order.
            ocr_text = maybe_dropout_texts(ocr_text, screen_text_dropout_rate)
            ui_text = maybe_dropout_texts(ui_text, screen_text_dropout_rate)
            for prefix, texts in (("ocr: ", ocr_text), ("ui: ", ui_text)):
                if texts:
                    lines.append(prefix + " | ".join(texts))
        return "\n".join(lines)
    context_text = f"应用:{app_text} 任务:{instruction_text}"
    if not include_screen:
        return context_text
    screen_text = visible_text_from_row(row, max_items=int(getattr(args, "context_screen_text_items", 32) or 32))
    screen_text = maybe_dropout_texts(screen_text, screen_text_dropout_rate)
    if screen_text:
        context_text = f"{context_text} 屏幕文字:{' | '.join(screen_text)}"
    return context_text
def prediction_from_summary(row: Dict[str, Any], summary_text: str) -> Dict[str, Any]:
    """Wrap a summary-only decoder output in the full five-key prediction schema.

    In summary mode the decoder generates only the summary ("画面总结"), so every
    other field is deliberately left empty. Back-filling those fields from the
    input row (OCR/UI items) would contaminate evaluation and distort every
    metric other than ROUGE. Structured fields, if needed, must be written
    explicitly afterwards by the apply_structured_*_predictions helpers, which
    are driven by the auxiliary heads. ``row`` is accepted for interface
    compatibility but is not read.
    """
    prediction: Dict[str, Any] = {"画面总结": safe_text(summary_text)}
    for empty_key in ("可见文字", "互动数据", "功能入口", "关键证据"):
        prediction[empty_key] = []
    return prediction
def repair_prediction_with_context(row: Dict[str, Any], pred_obj: Optional[Dict[str, Any]]) -> Tuple[Dict[str, Any], bool]:
    """Patch a parsed prediction using its input row so it can be scored.

    Returns ``(repaired, changed)``. ``repaired`` always carries the five
    Chinese schema keys, and ``changed`` is True when any repair below fired:
    - The summary is replaced by a templated sentence (mentioning the app and,
      if present, the instruction) in four cases: it is empty; it lacks the app
      name; it looks generic ("App的App", "手机App的App", or 3+ occurrences of
      "App"); or it claims a search page/result although the instruction never
      mentions "搜索" (search).
    - Visible text drops bare evidence-id strings and is capped at 16 items;
      a non-list value is reset to [].
    - Non-list interaction data is reset to [].
    - Function entries with generic names are dropped, as are search entries,
      unless the row has both a search task and visible search UI.
    - Empty or non-list evidence is back-filled from ``row["weak_evidence_ids"]``
      (at most 8).
    """
    # Fallback app label "移动应用" ("mobile app") when the row has none.
    app = safe_text(row.get("app")) or "移动应用"
    instruction = safe_text(row.get("instruction"))
    has_search_task = row_has_search_task(row)
    has_search_evidence = row_has_visible_search_evidence(row)
    repaired = normalized_prediction_obj(pred_obj)
    # normalized_prediction_obj returns None only for non-dict input.
    changed = repaired is None
    if repaired is None:
        repaired = {"画面总结": "", "可见文字": [], "互动数据": [], "功能入口": [], "关键证据": []}
    summary = safe_text(repaired.get("画面总结"))
    # Mentions a search page/result without a search instruction (likely hallucinated).
    search_overuse = ("搜索页面" in summary or "搜索结果" in summary) and "搜索" not in instruction
    generic_app = "App的App" in summary or "手机App的App" in summary or summary.count("App") >= 3
    needs_summary = not summary or app not in summary or generic_app or search_overuse
    if needs_summary:
        if instruction:
            summary = f"这是一个{app}界面,页面包含若干文字内容和可操作的 UI 入口。 当前任务语境是:{instruction}。"
        else:
            summary = f"这是一个{app}界面,页面包含若干文字内容和可操作的 UI 入口。"
        repaired["画面总结"] = summary
        changed = True
    visible = repaired.get("可见文字") or []
    if isinstance(visible, list):
        # Bare evidence ids (E + digits) are not real visible text.
        cleaned_visible = [value for value in visible if not re.fullmatch(r"E\d+", safe_text(value))][:16]
        if cleaned_visible != visible:
            repaired["可见文字"] = cleaned_visible
            changed = True
    else:
        repaired["可见文字"] = []
        changed = True
    if not isinstance(repaired.get("互动数据"), list):
        repaired["互动数据"] = []
        changed = True
    functions = repaired.get("功能入口") or []
    if isinstance(functions, list):
        cleaned_functions = []
        for function in functions:
            name = safe_text(function.get("name") if isinstance(function, dict) else function)
            if is_generic_function_name(name, row):
                changed = True
                continue
            # Search entries are only trusted with both a search task and visible search UI.
            if "搜索" in name and not (has_search_task and has_search_evidence):
                changed = True
                continue
            cleaned_functions.append(function)
        if cleaned_functions != functions:
            repaired["功能入口"] = cleaned_functions
    else:
        repaired["功能入口"] = []
        changed = True
    evidence = repaired.get("关键证据") or []
    if not isinstance(evidence, list):
        evidence = []
    if not evidence:
        # NOTE(review): this marks `changed` even when the prediction already had
        # no evidence and the row has no weak_evidence_ids (nothing actually
        # changes) — confirm whether callers rely on that.
        evidence = list(row.get("weak_evidence_ids") or [])[:8]
        repaired["关键证据"] = evidence
        changed = True
    return repaired, changed
def row_has_visible_search_evidence(row: Dict[str, Any]) -> bool:
    """Return True if any UI or OCR item on the screen looks like a search control.

    An item qualifies when ``is_search_ui_item`` accepts it. That covers text
    containing "搜索" or "输入", text equal to "搜" or "搜索框", or a type
    mentioning "search". UI items are checked before OCR items. Missing or
    ``None`` item lists are treated as empty, matching the ``or []`` handling
    used by row_has_search_context.
    """
    ui_items = row.get("ui_items") or []
    ocr_items = row.get("ocr_items") or []
    return any(is_search_ui_item(item) for item in list(ui_items) + list(ocr_items))
def row_has_search_context(row: Dict[str, Any]) -> bool:
    """Return True if "搜索" (search) appears in the instruction or in any UI/OCR item text."""
    instruction = safe_text(row.get("instruction"))
    screen_text = "".join(
        safe_text(item.get("text"))
        for group_name in ("ui_items", "ocr_items")
        for item in (row.get(group_name) or [])
    )
    return "搜索" in instruction or "搜索" in screen_text
def row_has_search_task(row: Dict[str, Any]) -> bool:
    """Return True if the instruction mentions searching ("搜索", "查找" or just "搜")."""
    instruction = safe_text(row.get("instruction"))
    return any(keyword in instruction for keyword in ("搜索", "查找", "搜"))
def is_generic_function_name(name: str, row: Dict[str, Any]) -> bool:
    """Return True when ``name`` is too generic to count as a real UI function.

    Generic means one of: the name is empty; it is one of a fixed set of
    placeholder names; it is an "enter/open <app>" phrase for this row's app;
    or it is any short "进入/打开 + 1-12 chars (+ App/应用)" phrase.
    """
    text = safe_text(name)
    app = safe_text(row.get("app"))
    if not text:
        return True
    placeholders = {"入口", "功能入口", "功能反馈", "打开应用", "进入应用", "打开App", "进入App"}
    if text in placeholders:
        return True
    if app:
        app_phrases = {f"{app}入口", f"进入{app}", f"打开{app}", f"进入{app}App", f"打开{app}App"}
        if text in app_phrases:
            return True
    return re.fullmatch(r"(进入|打开).{1,12}(App|应用)?", text) is not None
def is_bare_search_function_name(name: str) -> bool:
    """Return True when the cleaned name is exactly "搜索" (search) with no qualifier."""
    cleaned = safe_text(name)
    return cleaned == "搜索"
def is_search_ui_item(item: Dict[str, Any]) -> bool:
    """Return True if a UI/OCR item looks like a search control.

    That is: its text contains "搜索" or "输入", its text is "搜" or "搜索框",
    or its lower-cased type contains "search".
    """
    text = safe_text(item.get("text"))
    item_type = safe_text(item.get("type")).lower()
    return (
        "搜索" in text
        or "search" in item_type
        or text in {"搜", "搜索框"}
        or "输入" in text
    )
def is_structured_search_candidate_item(item: Dict[str, Any], strict: bool = False) -> bool:
    """Decide whether a search-looking item may become a structured search function.

    Non-search items are always rejected. In non-strict mode every search item
    is accepted. In strict mode the rules are applied in order:
    - exact "搜索"/"搜"/"搜索框" is accepted;
    - passive mentions (history, results, recommendations, ...) are rejected;
    - known input-control phrases are accepted;
    - otherwise only short texts (<= 8 chars) that start with "搜索" are accepted.
    """
    if not is_search_ui_item(item):
        return False
    if not strict:
        return True
    text = safe_text(item.get("text"))
    if text in {"搜索", "搜", "搜索框"}:
        return True
    passive_terms = (
        "历史搜索",
        "搜索历史",
        "搜索发现",
        "深度搜索",
        "AI搜索",
        "热门搜索",
        "最近搜索",
        "相关搜索",
        "搜索记录",
        "清除搜索记录",
        "搜索结果",
        "搜索来源",
        "搜索推荐",
        "根据你的搜索",
    )
    if any(term in text for term in passive_terms):
        return False
    control_terms = (
        "搜索框",
        "搜索、提问",
        "搜索地点",
        "搜索店内",
        "搜索商品",
        "搜索景点",
        "搜索机票",
        "附近搜索",
        "请输入",
        "输入",
    )
    return any(term in text for term in control_terms) or (text.startswith("搜索") and len(text) <= 8)
def structured_function_name(item: Dict[str, Any], row: Dict[str, Any]) -> str:
    """Derive a display name for a UI item selected by the structured function head.

    Rules, in order:
    - search items become "搜索功能入口";
    - texts of up to 12 chars are used as-is;
    - longer texts map to cart/message/settings labels when they contain
      those words, and are otherwise truncated to 12 chars;
    - text-less action targets take a keyword from the instruction;
    - the final fallback is "按钮入口" for buttons and "功能入口" for anything else.
    """
    text = safe_text(item.get("text"))
    item_type = safe_text(item.get("type")).lower()
    if is_search_ui_item(item):
        return "搜索功能入口"
    if text:
        if len(text) <= 12:
            return text
        for needle, label in (("购物车", "购物车"), ("消息", "消息入口"), ("设置", "设置入口")):
            if needle in text:
                return label
        return text[:12]
    if item.get("is_action_target"):
        instruction = safe_text(row.get("instruction"))
        # These keywords are used verbatim; the rest get an "入口" (entry) suffix.
        verbatim_keywords = {"收藏", "关注", "分享", "返回"}
        for keyword in ("商城", "购物车", "消息", "订单", "收藏", "关注", "分享", "返回"):
            if keyword in instruction:
                return keyword if keyword in verbatim_keywords else f"{keyword}入口"
    return "按钮入口" if "button" in item_type else "功能入口"
def apply_structured_function_predictions(
    row: Dict[str, Any],
    pred_obj: Optional[Dict[str, Any]],
    function_scores: torch.Tensor,
    search_scores: torch.Tensor,
    args: argparse.Namespace,
) -> Dict[str, Any]:
    """Rebuild the "功能入口" section of a prediction from per-element head scores.

    When ``structured_function_mode`` is "decoder", the decoder prediction is
    returned unchanged if it is a dict, otherwise normalized.

    In any other mode, each UI item (capped at ``max_elements``) is scored:
    - Search-like items use ``search_scores``. They are kept only when the row
      has both a search task and visible search evidence.
    - Other items use ``function_scores``.

    Kept items are named with ``structured_function_name``. Generic names and
    items without an id are dropped. The rest are sorted by score,
    de-duplicated and capped at ``structured_max_functions``.

    Args:
        row: Source sample; reads ``ui_items`` plus whatever the search helpers read.
        pred_obj: Decoder prediction, possibly None.
        function_scores: Per-element function scores, indexed like ``ui_items``.
        search_scores: Per-element search scores; missing entries count as 0.0.
        args: Reads the ``structured_*`` options and ``max_elements``.

    Returns:
        The structured prediction with "功能入口" replaced.
    """
    mode = str(getattr(args, "structured_function_mode", "decoder") or "decoder").lower()
    if mode == "decoder":
        return pred_obj if isinstance(pred_obj, dict) else normalized_prediction_obj(pred_obj)
    structured = normalized_prediction_obj(pred_obj)
    if structured is None:
        structured = {"画面总结": "", "可见文字": [], "互动数据": [], "功能入口": [], "关键证据": []}
    # Only a missing/None threshold falls back to its default. The previous
    # ``value or default`` idiom also replaced an explicit 0.0 ("keep everything")
    # with the default.
    function_threshold = getattr(args, "structured_function_threshold", None)
    function_threshold = 0.5 if function_threshold is None else float(function_threshold)
    search_threshold = getattr(args, "structured_search_threshold", None)
    search_threshold = function_threshold if search_threshold is None else float(search_threshold)
    max_functions = int(getattr(args, "structured_max_functions", 12) or 12)
    has_search_task = row_has_search_task(row)
    has_search_evidence = row_has_visible_search_evidence(row)
    strict_search_candidates = bool(getattr(args, "structured_strict_search_candidates", False))
    items = (row.get("ui_items") or [])[: int(getattr(args, "max_elements", len(function_scores)))]
    candidates: List[Tuple[float, Dict[str, Any]]] = []
    for idx, item in enumerate(items[: len(function_scores)]):
        function_score = float(function_scores[idx])
        search_score = float(search_scores[idx]) if idx < len(search_scores) else 0.0
        raw_item_is_search = is_search_ui_item(item)
        item_is_search = is_structured_search_candidate_item(item, strict=strict_search_candidates)
        # Search-like text rejected by the strict filter (history, results, ...) is never a function.
        if raw_item_is_search and not item_is_search:
            continue
        if item_is_search:
            keep = has_search_task and has_search_evidence and search_score >= search_threshold
        else:
            keep = function_score >= function_threshold
        if not keep:
            continue
        evidence_id = safe_text(item.get("id") or item.get("ocr_id"))
        if not evidence_id:
            continue
        name = structured_function_name(item, row)
        if is_generic_function_name(name, row):
            continue
        if "搜索" in name and not (has_search_task and has_search_evidence):
            continue
        score = search_score if item_is_search else function_score
        candidates.append((score, {"name": name, "evidence_ids": [evidence_id]}))
    candidates.sort(key=lambda value: value[0], reverse=True)
    functions: List[Dict[str, Any]] = []
    seen = set()
    for _, function in candidates:
        key = (function["name"], tuple(function.get("evidence_ids", [])))
        if key in seen:
            continue
        seen.add(key)
        functions.append(function)
        if len(functions) >= max_functions:
            break
    structured["功能入口"] = functions
    return structured
def apply_structured_evidence_predictions(
    row: Dict[str, Any],
    pred_obj: Optional[Dict[str, Any]],
    evidence_scores: torch.Tensor,
    args: argparse.Namespace,
) -> Dict[str, Any]:
    """Rebuild the "关键证据" section of a prediction from per-element evidence scores.

    In "decoder" mode the normalized decoder prediction is returned. In "heads"
    mode, element ids whose score reaches ``structured_evidence_threshold`` are
    selected in descending score order, de-duplicated and capped at
    ``structured_max_evidence``. If nothing passes and
    ``structured_evidence_fallback_top1`` is set, the best-scoring element is
    used.

    Args:
        row: Source sample; reads ``ui_items``.
        pred_obj: Decoder prediction, possibly None.
        evidence_scores: Per-element scores, indexed like ``ui_items``.
        args: Reads the ``structured_evidence_*``/``structured_max_evidence``
            options and ``max_elements``.

    Returns:
        The structured prediction with "关键证据" replaced (heads mode only).

    Raises:
        ValueError: If ``structured_evidence_mode`` is neither "decoder" nor "heads".
    """
    mode = str(getattr(args, "structured_evidence_mode", "decoder") or "decoder").lower()
    structured = normalized_prediction_obj(pred_obj)
    if structured is None:
        structured = {"画面总结": "", "可见文字": [], "互动数据": [], "功能入口": [], "关键证据": []}
    if mode == "decoder":
        return structured
    if mode != "heads":
        raise ValueError("structured_evidence_mode must be one of: decoder, heads")
    # Only a missing/None threshold falls back to 0.5; an explicit 0.0 is honored
    # (the old ``value or 0.5`` idiom silently turned 0.0 into 0.5).
    threshold = getattr(args, "structured_evidence_threshold", None)
    threshold = 0.5 if threshold is None else float(threshold)
    max_evidence = int(getattr(args, "structured_max_evidence", 8) or 8)
    fallback_top1 = bool(getattr(args, "structured_evidence_fallback_top1", True))
    items = (row.get("ui_items") or [])[: int(getattr(args, "max_elements", len(evidence_scores)))]
    candidates: List[Tuple[float, str]] = []
    for idx, item in enumerate(items[: len(evidence_scores)]):
        eid = element_id(item)
        if not eid:
            continue
        score = float(evidence_scores[idx])
        candidates.append((score, eid))
    candidates.sort(key=lambda value: value[0], reverse=True)
    selected: List[str] = []
    seen = set()
    for score, eid in candidates:
        if score < threshold:
            continue
        if eid in seen:
            continue
        selected.append(eid)
        seen.add(eid)
        if len(selected) >= max_evidence:
            break
    if not selected and fallback_top1 and candidates:
        selected.append(candidates[0][1])
    structured["关键证据"] = selected
    return structured
def build_context_summary(row: Dict[str, Any]) -> str:
    """Build the templated Chinese screen summary used when a target lacks one.

    Falls back to "移动应用" when the row has no app name. Appends the task
    instruction as context when one is present.
    """
    app = safe_text(row.get("app")) or "移动应用"
    summary = f"这是一个{app}界面,页面包含若干文字内容和可操作的 UI 入口。"
    instruction = safe_text(row.get("instruction"))
    if not instruction:
        return summary
    return f"{summary} 当前任务语境是:{instruction}。"
def canonicalize_target_with_context(
    row: Dict[str, Any],
    target: Dict[str, Any],
    drop_bare_search_functions: bool = False,
) -> Dict[str, Any]:
    """Return a cleaned shallow copy of ``target`` consistent with its source ``row``.

    - ``summary_zh`` is backfilled from a template only when missing or empty.
    - ``visible_text`` drops bare element ids ("E12", ...) and keeps at most 16 entries.
    - ``ui_functions`` drops generic names, optionally bare "搜索", and any
      search function when the row lacks a search task or visible search
      evidence; at most 12 entries are kept.
    - ``interaction_data`` is reset to [] when it is not a list.
    - ``key_ui_clues`` falls back to ``row["weak_evidence_ids"]`` (first 8)
      when empty; at most 12 entries are kept.
    """
    canonical = dict(target or {})
    # Backfill summary_zh only when it is missing; never overwrite an existing label.
    # Overwriting it unconditionally used to replace every real summary with the
    # "这是一个X界面...当前任务语境是..." template. That was one of the root causes of
    # 100% homogeneous training targets, where the model only learned the template.
    if not safe_text(canonical.get("summary_zh")):
        canonical["summary_zh"] = build_context_summary(row)

    visible = canonical.get("visible_text") or []
    if isinstance(visible, list):
        canonical["visible_text"] = [
            entry for entry in visible if not re.fullmatch(r"E\d+", safe_text(entry))
        ][:16]
    else:
        canonical["visible_text"] = []

    functions = canonical.get("ui_functions") or []
    if not isinstance(functions, list):
        canonical["ui_functions"] = []
    else:
        has_task = row_has_search_task(row)
        has_evidence = row_has_visible_search_evidence(row)
        search_allowed = has_task and has_evidence
        kept = []
        for function in functions:
            label = safe_text(function.get("name") if isinstance(function, dict) else function)
            if is_generic_function_name(label, row):
                continue
            if drop_bare_search_functions and is_bare_search_function_name(label):
                continue
            if "搜索" in label and not search_allowed:
                continue
            kept.append(function)
        canonical["ui_functions"] = kept[:12]

    if not isinstance(canonical.get("interaction_data"), list):
        canonical["interaction_data"] = []

    clues = canonical.get("key_ui_clues") or []
    if not isinstance(clues, list):
        clues = []
    if not clues:
        clues = list(row.get("weak_evidence_ids") or [])[:8]
    canonical["key_ui_clues"] = clues[:12]
    return canonical
class RichScreenshotDataset(Dataset):
    """Map-style dataset that holds every JSONL row of a split in memory.

    Args:
        path: JSONL file, read with ``read_jsonl``.
        max_samples: Optional cap on the number of rows. 0, None or any
            non-positive value keeps all rows.
        sample_seed: When capping, None keeps the first ``max_samples`` rows.
            Otherwise a seeded random subset is taken and kept in file order.

    Raises:
        ValueError: If no rows were loaded.
    """

    def __init__(self, path: str, max_samples: int = 0, sample_seed: Optional[int] = None):
        self.path = Path(path)
        self.rows = list(read_jsonl(self.path))
        # Require a positive cap. A negative value used to slice rows off the end
        # (unseeded) or make random.sample raise (seeded) instead of meaning "no cap".
        if max_samples and 0 < max_samples < len(self.rows):
            if sample_seed is None:
                self.rows = self.rows[:max_samples]
            else:
                rng = random.Random(sample_seed)
                indices = sorted(rng.sample(range(len(self.rows)), max_samples))
                self.rows = [self.rows[index] for index in indices]
        if not self.rows:
            raise ValueError(f"No rows loaded from {path}")

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        return self.rows[idx]
def dataset_diagnostics(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Collect data-quality counters and size statistics for a list of rows.

    A row whose ``target`` is not a dict counts as missing a target. Its
    summary is treated as "", so it is also counted as a short (< 8 chars)
    summary and contributes 0 to the summary-length stats. Image paths are
    checked for presence and for existence on disk.
    """
    ocr_counts: List[int] = []
    ui_counts: List[int] = []
    summary_lengths: List[int] = []
    n_missing_target = 0
    n_short_summary = 0
    n_missing_image_path = 0
    n_missing_image_file = 0
    n_missing_ocr = 0
    n_missing_ui = 0
    for row in rows:
        target = row.get("target")
        if isinstance(target, dict):
            summary = safe_text(target.get("summary_zh"))
        else:
            n_missing_target += 1
            summary = ""
        summary_lengths.append(len(summary))
        n_short_summary += int(len(summary) < 8)

        image_path = safe_text(row.get("image_path"))
        if not image_path:
            n_missing_image_path += 1
        elif not Path(image_path).exists():
            n_missing_image_file += 1

        ocr_count = len(row.get("ocr_items") or [])
        ui_count = len(row.get("ui_items") or [])
        ocr_counts.append(ocr_count)
        ui_counts.append(ui_count)
        n_missing_ocr += int(ocr_count <= 0)
        n_missing_ui += int(ui_count <= 0)

    total = len(rows)
    denominator = max(1, total)

    def _mean(values: List[int]) -> float:
        return float(np.mean(values)) if values else 0.0

    def _max(values: List[int]) -> int:
        return int(max(values)) if values else 0

    return {
        "rows": total,
        "missing_target": n_missing_target,
        "short_summary": n_short_summary,
        "missing_image_path": n_missing_image_path,
        "missing_image_file": n_missing_image_file,
        "missing_ocr_items": n_missing_ocr,
        "missing_ui_items": n_missing_ui,
        "missing_ocr_rate": n_missing_ocr / denominator,
        "missing_ui_rate": n_missing_ui / denominator,
        "ocr_items_mean": _mean(ocr_counts),
        "ocr_items_max": _max(ocr_counts),
        "ui_items_mean": _mean(ui_counts),
        "ui_items_max": _max(ui_counts),
        "summary_chars_mean": _mean(summary_lengths),
        "summary_chars_max": _max(summary_lengths),
    }
def validate_dataset_for_training(split_name: str, diagnostics: Dict[str, Any], args: argparse.Namespace) -> None:
    """Raise if ``diagnostics`` (from ``dataset_diagnostics``) shows unusable data.

    Checks only run when ``args.strict_data_checks`` is truthy (the default).
    Missing image files only matter when vision is enabled. Missing ui/ocr
    items only matter when the configuration consumes them, and then only
    above a 5% rate.

    Raises:
        ValueError: Listing every failed check, joined with "; ".
    """
    if not bool(getattr(args, "strict_data_checks", True)):
        return
    problems: List[str] = []
    for key, message in (
        ("missing_target", "rows are missing target"),
        ("short_summary", "rows have summary_zh shorter than 8 chars"),
        ("missing_image_path", "rows are missing image_path"),
    ):
        if diagnostics[key]:
            problems.append(f"{diagnostics[key]} {message}")
    uses_vision = (not bool(getattr(args, "disable_vision", False))) and (
        str(getattr(args, "model_variant", "")) != "annotation_only"
    )
    if uses_vision and diagnostics["missing_image_file"]:
        problems.append(f"{diagnostics['missing_image_file']} image files do not exist")
    wants_ui_items = bool(getattr(args, "direct_element_tokens", False)) or int(getattr(args, "max_elements", 0) or 0) > 0
    wants_screen_text = bool(getattr(args, "context_include_screen_text", False))
    if wants_ui_items and diagnostics["missing_ui_rate"] > 0.05:
        problems.append(f"{diagnostics['missing_ui_items']} rows are missing ui_items")
    if wants_screen_text and diagnostics["missing_ocr_rate"] > 0.05:
        problems.append(f"{diagnostics['missing_ocr_items']} rows are missing ocr_items")
    if problems:
        raise ValueError(f"{split_name} dataset failed strict data checks: " + "; ".join(problems))
def token_length_summary(values: List[int], max_length: int) -> Dict[str, Any]:
    """Summarize token lengths against a configured maximum length.

    Percentiles use the nearest rank on the sorted values, with the index
    rounded by Python's ``round``. ``over_max`` counts lengths strictly above
    ``max_length`` (these get truncated). ``at_or_over_max`` also counts
    lengths equal to it.

    Both the empty and non-empty results expose the same keys. The empty
    result previously omitted ``configured_max``.
    """
    if not values:
        return {
            "count": 0,
            "mean": 0.0,
            "p50": 0,
            "p90": 0,
            "p95": 0,
            "p99": 0,
            "max": 0,
            "configured_max": int(max_length),
            "over_max": 0,
            "over_max_rate": 0.0,
            "at_or_over_max": 0,
            "at_or_over_max_rate": 0.0,
        }
    ordered = sorted(values)

    def pct(percent: float) -> int:
        index = int(round((len(ordered) - 1) * percent / 100.0))
        return int(ordered[max(0, min(len(ordered) - 1, index))])

    over_max = sum(value > max_length for value in values)
    at_or_over_max = sum(value >= max_length for value in values)
    return {
        "count": len(values),
        "mean": float(np.mean(values)),
        "p50": pct(50),
        "p90": pct(90),
        "p95": pct(95),
        "p99": pct(99),
        "max": int(max(values)),
        "configured_max": int(max_length),
        "over_max": int(over_max),
        "over_max_rate": over_max / max(1, len(values)),
        "at_or_over_max": int(at_or_over_max),
        "at_or_over_max_rate": at_or_over_max / max(1, len(values)),
    }
def tokenizer_diagnostics(rows: List[Dict[str, Any]], tokenizer, args: argparse.Namespace) -> Dict[str, Any]:
    """Measure target/context token lengths (special tokens included) for ``rows``.

    Targets are canonicalized first when ``args.canonicalize_targets`` is set,
    mirroring what the collator feeds the model. Returns ``token_length_summary``
    results keyed by "target_tokens" and "context_tokens".
    """
    schema = getattr(args, "target_schema", "zh")
    target_cap = int(getattr(args, "max_target_tokens", 384) or 384)
    context_cap = int(getattr(args, "max_context_tokens", 64) or 64)
    canonicalize = bool(getattr(args, "canonicalize_targets", False))
    drop_bare = bool(getattr(args, "drop_bare_search_functions", False))

    def _token_count(text: str) -> int:
        return len(tokenizer(text, add_special_tokens=True).input_ids)

    target_lengths: List[int] = []
    context_lengths: List[int] = []
    for row in rows:
        target = row.get("target") or {}
        if canonicalize:
            target = canonicalize_target_with_context(row, target, drop_bare_search_functions=drop_bare)
        target_lengths.append(_token_count(target_to_text(target, schema)))
        context_lengths.append(_token_count(build_context_text(row, args)))
    return {
        "target_tokens": token_length_summary(target_lengths, target_cap),
        "context_tokens": token_length_summary(context_lengths, context_cap),
    }
def validate_token_lengths(split_name: str, diagnostics: Dict[str, Any], args: argparse.Namespace) -> None:
    """Raise when target truncation or the generation budget is misconfigured.

    Only runs when ``args.strict_data_checks`` is truthy (the default).
    ``diagnostics`` is the output of ``tokenizer_diagnostics``.

    Raises:
        ValueError: If the target over-max rate exceeds
            ``max_target_truncation_rate``, or if ``eval_max_new_tokens`` is
            smaller than ``max_target_tokens``.
    """
    if not bool(getattr(args, "strict_data_checks", True)):
        return
    allowed_rate = float(getattr(args, "max_target_truncation_rate", 0.01) or 0.0)
    stats = diagnostics.get("target_tokens") or {}
    observed_rate = float(stats.get("over_max_rate", 0.0) or 0.0)
    if observed_rate > allowed_rate:
        raise ValueError(
            f"{split_name} target token truncation rate {observed_rate:.2%} exceeds "
            f"max_target_truncation_rate={allowed_rate:.2%}; "
            f"max_target_tokens={stats.get('configured_max')}, "
            f"p95={stats.get('p95')}, p99={stats.get('p99')}, max={stats.get('max')}"
        )
    generation_budget = int(getattr(args, "eval_max_new_tokens", 0) or 0)
    target_budget = int(getattr(args, "max_target_tokens", 0) or 0)
    if generation_budget < target_budget:
        raise ValueError("eval_max_new_tokens must be >= max_target_tokens for natural_zh validation.")
class RichCollator:
    """Collate dataset rows into the tensor batch consumed by the grounded model.

    Produces:
    - pixel values for each screenshot plus its vertical crops (unless vision
      is disabled);
    - per-element token ids and metadata tensors padded to ``args.max_elements``;
    - tokenized context and target text;
    - per-element and per-section supervision labels.
    """

    # Categorical vocabularies for element metadata. Unknown values map to
    # 1 ("text") for types and to 0 ("unknown") for sources and locations.
    TYPE_VOCAB: Dict[str, int] = {"visual": 0, "text": 1, "button": 2, "label": 3, "text_number": 4}
    SOURCE_VOCAB: Dict[str, int] = {"unknown": 0, "cmgui": 1, "ocr": 2, "rule": 3}
    LOC_VOCAB: Dict[str, int] = {
        "unknown": 0,
        "top-left": 1,
        "top-center": 2,
        "top-right": 3,
        "middle-left": 4,
        "middle-center": 5,
        "middle-right": 6,
        "bottom-left": 7,
        "bottom-center": 8,
        "bottom-right": 9,
    }

    def __init__(self, tokenizer, image_processor, args: argparse.Namespace, is_training: bool = False):
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.args = args
        self.is_training = is_training
        self.use_vision = not bool(getattr(args, "disable_vision", False)) and str(getattr(args, "model_variant", "")) != "annotation_only"

    def _load_images(self, image_path: str) -> List[Image.Image]:
        """Return the full screenshot followed by ``num_vertical_crops`` horizontal bands.

        Unreadable or missing files are replaced by a light-gray placeholder of
        ``image_size`` x ``image_size`` pixels (best-effort loading).
        """
        path = Path(image_path)
        try:
            # Use a context manager so the file handle is released immediately.
            # convert() returns an independent, fully loaded copy.
            with Image.open(path) as opened:
                image = opened.convert("RGB")
        except Exception:
            image = Image.new("RGB", (self.args.image_size, self.args.image_size), color=(245, 245, 245))
        images = [image]
        if self.args.num_vertical_crops > 0:
            width, height = image.size
            crops = self.args.num_vertical_crops
            for i in range(crops):
                top = int(height * i / crops)
                bottom = int(height * (i + 1) / crops)
                images.append(image.crop((0, top, width, bottom)))
        return images

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Collate ``batch`` into model inputs and labels.

        Target label positions equal to the tokenizer pad id are set to -100.
        The raw rows are passed through under "rows".
        """
        args = self.args
        bsz = len(batch)
        max_elements = args.max_elements
        if self.use_vision:
            flat_images: List[Image.Image] = []
            crop_counts: List[int] = []
            for row in batch:
                images = self._load_images(row.get("image_path", ""))
                crop_counts.append(len(images))
                flat_images.extend(images)
            image_kwargs = {
                "return_tensors": "pt",
                "do_resize": True,
                "size": {"height": args.image_size, "width": args.image_size},
            }
            pixel_values = self.image_processor(images=flat_images, **image_kwargs)["pixel_values"]
            # Every row yields the same number of views (1 + num_vertical_crops).
            crops = crop_counts[0]
            pixel_values = pixel_values.view(bsz, crops, *pixel_values.shape[1:])
        else:
            pixel_values = torch.empty(bsz, 0, 3, args.image_size, args.image_size, dtype=torch.float32)
        all_elements: List[List[Dict[str, Any]]] = []
        flat_element_texts: List[str] = []
        element_mask = torch.zeros(bsz, max_elements, dtype=torch.bool)
        bboxes = torch.zeros(bsz, max_elements, 4, dtype=torch.float32)
        type_ids = torch.zeros(bsz, max_elements, dtype=torch.long)
        source_ids = torch.zeros(bsz, max_elements, dtype=torch.long)
        loc_ids = torch.zeros(bsz, max_elements, dtype=torch.long)
        confs = torch.zeros(bsz, max_elements, 1, dtype=torch.float32)
        action_flags = torch.zeros(bsz, max_elements, 1, dtype=torch.float32)
        evidence_labels = torch.zeros(bsz, max_elements, dtype=torch.float32)
        numeric_labels = torch.zeros(bsz, max_elements, dtype=torch.float32)
        ui_function_labels = torch.zeros(bsz, max_elements, dtype=torch.float32)
        search_function_labels = torch.zeros(bsz, max_elements, dtype=torch.float32)
        target_texts: List[str] = []
        context_texts: List[str] = []
        section_labels = torch.zeros(bsz, len(SECTION_NAMES), dtype=torch.float32)
        # Batch-invariant settings, read once instead of once per row.
        canonicalize = bool(getattr(args, "canonicalize_targets", False))
        drop_bare = bool(getattr(args, "drop_bare_search_functions", False))
        target_schema = getattr(args, "target_schema", "zh")
        screen_text_dropout_rate = (
            float(getattr(args, "context_screen_text_dropout_rate", 0.0) or 0.0) if self.is_training else 0.0
        )
        task_intent_context = bool(getattr(args, "task_intent_context", False))
        for row_idx, row in enumerate(batch):
            target = row.get("target") or {}
            if canonicalize:
                target = canonicalize_target_with_context(
                    row,
                    target,
                    drop_bare_search_functions=drop_bare,
                )
            target_texts.append(target_to_text(target, target_schema))
            context_texts.append(build_context_text(row, args, screen_text_dropout_rate=screen_text_dropout_rate))
            evidence_ids = set(target.get("key_ui_clues", []) or row.get("weak_evidence_ids", []))
            function_evidence = set(extract_function_evidence_ids(target))
            search_function_evidence = set(extract_named_function_evidence_ids(target, "搜索"))
            interaction_evidence = set()
            for interaction in target.get("interaction_data", []) or []:
                interaction_evidence.update(interaction.get("evidence_ids", []) or [])
            for sec_idx, name in enumerate(SECTION_NAMES):
                if target.get(name):
                    section_labels[row_idx, sec_idx] = 1.0
            if task_intent_context:
                task_type = "检索" if row_has_search_task(row) else "常规"
                context_texts[-1] = f"{context_texts[-1]} 任务类别:{task_type}"
            elements = (row.get("ui_items") or [])[:max_elements]
            all_elements.append(elements)
            for elem_idx in range(max_elements):
                if elem_idx >= len(elements):
                    # Padding slot: empty text, mask stays False.
                    flat_element_texts.append("")
                    continue
                elem = elements[elem_idx]
                text = safe_text(elem.get("text"))
                flat_element_texts.append(f"{elem.get('type','text')} {elem.get('source','')} {text}")
                element_mask[row_idx, elem_idx] = True
                bbox = elem.get("bbox") or [0, 0, 1, 1]
                if len(bbox) == 4:
                    bboxes[row_idx, elem_idx] = torch.tensor(bbox, dtype=torch.float32)
                type_ids[row_idx, elem_idx] = self.TYPE_VOCAB.get(safe_text(elem.get("type")), 1)
                source_ids[row_idx, elem_idx] = self.SOURCE_VOCAB.get(safe_text(elem.get("source")), 0)
                loc_ids[row_idx, elem_idx] = self.LOC_VOCAB.get(safe_text(elem.get("location")), 0)
                confs[row_idx, elem_idx, 0] = float(elem.get("ocr_conf", elem.get("conf", 1.0)) or 1.0)
                action_flags[row_idx, elem_idx, 0] = 1.0 if elem.get("is_action_target") else 0.0
                elem_id = elem.get("id")
                ocr_id = elem.get("ocr_id")
                if elem_id in evidence_ids or ocr_id in evidence_ids:
                    evidence_labels[row_idx, elem_idx] = 1.0
                if elem_id in function_evidence or ocr_id in function_evidence:
                    ui_function_labels[row_idx, elem_idx] = 1.0
                if elem_id in search_function_evidence or ocr_id in search_function_evidence:
                    search_function_labels[row_idx, elem_idx] = 1.0
                # Match interaction evidence on ocr_id as well, like the other labels.
                # Previously only "id" was checked, so OCR-referenced numbers were missed.
                if elem_id in interaction_evidence or ocr_id in interaction_evidence or re.search(r"\d", text):
                    numeric_labels[row_idx, elem_idx] = 1.0
        elem_tokens = self.tokenizer(
            flat_element_texts,
            padding=True,
            truncation=True,
            max_length=args.max_element_tokens,
            return_tensors="pt",
        )
        elem_input_ids = elem_tokens.input_ids.view(bsz, max_elements, -1)
        elem_attention_mask = elem_tokens.attention_mask.view(bsz, max_elements, -1)
        context_tokens = self.tokenizer(
            context_texts,
            padding=True,
            truncation=True,
            max_length=args.max_context_tokens,
            return_tensors="pt",
        )
        target_tokens = self.tokenizer(
            target_texts,
            padding=True,
            truncation=True,
            max_length=args.max_target_tokens,
            return_tensors="pt",
        )
        labels = target_tokens.input_ids
        labels[labels == self.tokenizer.pad_token_id] = -100
        return {
            "rows": batch,
            "pixel_values": pixel_values,
            "element_input_ids": elem_input_ids,
            "element_attention_mask": elem_attention_mask,
            "element_mask": element_mask,
            "bboxes": bboxes,
            "type_ids": type_ids,
            "source_ids": source_ids,
            "loc_ids": loc_ids,
            "confs": confs,
            "action_flags": action_flags,
            "context_input_ids": context_tokens.input_ids,
            "context_attention_mask": context_tokens.attention_mask,
            "labels": labels,
            "evidence_labels": evidence_labels,
            "ui_function_labels": ui_function_labels,
            "search_function_labels": search_function_labels,
            "section_labels": section_labels,
            "numeric_labels": numeric_labels,
        }
class BottleneckPooler(nn.Module):
    """Compress a variable-length memory into ``num_queries`` learned query tokens.

    The learned queries cross-attend over the memory. A residual connection
    back to the queries is followed by LayerNorm.
    """

    def __init__(self, hidden_size: int, num_queries: int = 64, num_heads: int = 8):
        super().__init__()
        # Attribute names and creation order stay fixed. They define the state_dict
        # keys and the order in which initialization consumes the RNG.
        initial_queries = torch.randn(num_queries, hidden_size) * 0.02
        self.queries = nn.Parameter(initial_queries)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, memory: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Pool ``memory`` (batch-first) into (batch, num_queries, hidden) tokens."""
        expanded = self.queries.unsqueeze(0).expand(memory.size(0), -1, -1)
        attended = self.attn(expanded, memory, memory, key_padding_mask=key_padding_mask)[0]
        return self.norm(attended + expanded)
class RichGroundedModel(nn.Module):
def __init__(self, args: argparse.Namespace):
    """Build every sub-module of the grounded screenshot summarizer.

    Loads the vision backbone from ``args.vision_model``. It is skipped when
    ``disable_vision`` is set or ``model_variant == "annotation_only"``. Also
    loads the seq2seq decoder from ``args.decoder_model``. Then creates:
    - the projections;
    - the layout and fusion Transformer encoders;
    - the bottleneck pooler;
    - the per-element and per-section prediction heads used by
      ``build_memory`` and the training losses.

    Args:
        args: Run configuration.
            Read as plain attributes: ``model_variant``, ``decoder_model`` and
            ``bottleneck_queries``. ``vision_model``, ``freeze_vision`` and
            ``unfreeze_vision_last_ratio`` are only read when vision is enabled.
            Optional flags read via ``getattr``: ``disable_vision``,
            ``decoder_gradient_checkpointing``, ``freeze_decoder`` and
            ``vision_gradient_checkpointing``.
    """
    super().__init__()
    self.args = args
    self.model_variant = args.model_variant
    # Vision is dropped entirely when disabled or for the annotation-only ablation.
    self.use_vision = not bool(getattr(args, "disable_vision", False)) and self.model_variant != "annotation_only"
    self.vision = AutoModel.from_pretrained(args.vision_model) if self.use_vision else None
    self.decoder = AutoModelForSeq2SeqLM.from_pretrained(args.decoder_model)
    if bool(getattr(args, "decoder_gradient_checkpointing", False)) and hasattr(self.decoder, "gradient_checkpointing_enable"):
        self.decoder.gradient_checkpointing_enable()
        # Turn off the decoder's KV cache in its config.
        if hasattr(self.decoder.config, "use_cache"):
            self.decoder.config.use_cache = False
    if bool(getattr(args, "freeze_decoder", False)):
        for param in self.decoder.parameters():
            param.requires_grad = False
    if self.vision is not None and bool(getattr(args, "vision_gradient_checkpointing", False)) and hasattr(self.vision, "gradient_checkpointing_enable"):
        self.vision.gradient_checkpointing_enable()
    # All fused tokens live in the decoder's model width (d_model).
    self.hidden_size = self.decoder.config.d_model
    vision_hidden = self._vision_hidden_size() if self.vision is not None else self.hidden_size
    self.visual_proj = nn.Linear(vision_hidden, self.hidden_size)
    self.elem_text_proj = nn.Linear(self.hidden_size, self.hidden_size)
    # Embedding tables are larger than the collator's type/source/location vocabularies.
    self.type_emb = nn.Embedding(16, self.hidden_size)
    self.source_emb = nn.Embedding(8, self.hidden_size)
    self.loc_emb = nn.Embedding(16, self.hidden_size)
    # 6 box features: the 4 bbox coordinates plus width and height (see build_memory).
    self.bbox_proj = nn.Sequential(nn.Linear(6, self.hidden_size), nn.GELU(), nn.Linear(self.hidden_size, self.hidden_size))
    self.roi_proj = nn.Linear(vision_hidden, self.hidden_size)
    self.context_proj = nn.Linear(self.hidden_size, self.hidden_size)
    # Self-attention over UI elements only.
    enc_layer = nn.TransformerEncoderLayer(
        d_model=self.hidden_size,
        nhead=8,
        dim_feedforward=self.hidden_size * 4,
        dropout=0.1,
        batch_first=True,
        norm_first=False,
    )
    self.layout_encoder = nn.TransformerEncoder(enc_layer, num_layers=2)
    # Joint self-attention over context, visual and element tokens.
    fusion_layer = nn.TransformerEncoderLayer(
        d_model=self.hidden_size,
        nhead=8,
        dim_feedforward=self.hidden_size * 4,
        dropout=0.1,
        batch_first=True,
        norm_first=False,
    )
    self.fusion = nn.TransformerEncoder(fusion_layer, num_layers=2)
    self.pooler = BottleneckPooler(self.hidden_size, num_queries=args.bottleneck_queries, num_heads=8)
    self.evidence_head = nn.Linear(self.hidden_size, 1)
    self.ui_function_head = nn.Linear(self.hidden_size, 1)
    self.search_function_head = nn.Linear(self.hidden_size, 1)
    self.function_signal_proj = nn.Linear(1, self.hidden_size)
    self.search_signal_proj = nn.Linear(1, self.hidden_size)
    # Zero-initialized, so the optional head signals added to element tokens in
    # build_memory contribute nothing until these projections are trained.
    nn.init.zeros_(self.function_signal_proj.weight)
    nn.init.zeros_(self.function_signal_proj.bias)
    nn.init.zeros_(self.search_signal_proj.weight)
    nn.init.zeros_(self.search_signal_proj.bias)
    self.numeric_head = nn.Linear(self.hidden_size, 1)
    self.section_head = nn.Linear(self.hidden_size, len(SECTION_NAMES))
    if self.vision is not None and args.freeze_vision:
        self.freeze_vision(args.unfreeze_vision_last_ratio)
def _vision_hidden_size(self) -> int:
if self.vision is None:
return self.hidden_size
config = self.vision.config
if hasattr(config, "vision_config"):
return int(config.vision_config.hidden_size)
return int(getattr(config, "hidden_size"))
def freeze_vision(self, unfreeze_last_ratio: float) -> None:
    """Freeze the vision backbone, then re-enable gradients for its top layers.

    The last ``max(1, int(num_layers * unfreeze_last_ratio))`` encoder layers
    are made trainable again. A ratio <= 0 leaves the whole backbone frozen.
    Encoder layers are looked up as ``encoder.layers`` or ``encoder.layer``
    under ``vision_model`` when present. Nothing is unfrozen if neither is
    found.
    """
    if self.vision is None:
        return
    self.vision.requires_grad_(False)
    if unfreeze_last_ratio <= 0:
        return
    backbone = getattr(self.vision, "vision_model", self.vision)
    encoder = getattr(backbone, "encoder", None)
    blocks = None
    if encoder is not None:
        blocks = getattr(encoder, "layers", None) or getattr(encoder, "layer", None)
    if not blocks:
        return
    count = max(1, int(len(blocks) * unfreeze_last_ratio))
    for block in blocks[-count:]:
        block.requires_grad_(True)
def use_native_context_forward(self) -> bool:
    """Return True when ``args.native_context_forward`` asks to skip the fusion stack."""
    flag = getattr(self.args, "native_context_forward", False)
    return bool(flag)
def mean_embed(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average decoder input embeddings over the token axis (dim -2), ignoring masked positions.

    Rows with no unmasked tokens divide by 1 instead of 0, so they yield zeros.
    """
    embeddings = self.decoder.get_input_embeddings()(input_ids)
    weights = attention_mask.unsqueeze(-1).float()
    total = (embeddings * weights).sum(dim=-2)
    count = weights.sum(dim=-2).clamp_min(1.0)
    return total / count
def encode_vision(self, pixel_values: torch.Tensor) -> torch.Tensor:
    """Encode every view and return tokens shaped (bsz, views, tokens, hidden).

    ``pixel_values`` must be 5-D: (bsz, views, channels, height, width).
    Gradients are only tracked when some backbone parameter is trainable.
    """
    batch_size, num_views, channels, height, width = pixel_values.shape
    flat_views = pixel_values.view(batch_size * num_views, channels, height, width)
    backbone = getattr(self.vision, "vision_model", self.vision)
    needs_grad = any(param.requires_grad for param in self.vision.parameters())
    with torch.set_grad_enabled(needs_grad):
        try:
            outputs = backbone(pixel_values=flat_views, interpolate_pos_encoding=True)
        except TypeError:
            # Backbones without position-embedding interpolation reject the keyword.
            outputs = backbone(pixel_values=flat_views)
    hidden = outputs.last_hidden_state
    return hidden.view(batch_size, num_views, hidden.size(1), hidden.size(2))
def reduce_visual_tokens(self, visual: torch.Tensor) -> torch.Tensor:
    """Flatten per-view tokens to (bsz, N, hidden), optionally capped by ``max_visual_tokens``.

    With no cap (<= 0), or when all tokens fit, every token is projected.
    Otherwise each view gets ``cap // views`` slots:
    - square patch grids (optionally led by one CLS token) are adaptively
      average-pooled to the largest square that fits, keeping the CLS token
      when the budget allows;
    - any other layout is uniformly subsampled.
    """
    batch_size, num_views, tokens_per_view, width = visual.shape
    budget = int(getattr(self.args, "max_visual_tokens", 0) or 0)
    if budget <= 0 or num_views * tokens_per_view <= budget:
        return self.visual_proj(visual.flatten(1, 2))
    per_view = max(1, budget // max(1, num_views))
    # Every view shares the same token layout, so inspect it once.
    side = int(math.sqrt(tokens_per_view))
    has_cls = False
    is_grid = True
    if side * side != tokens_per_view:
        side_without_cls = int(math.sqrt(tokens_per_view - 1))
        if side_without_cls * side_without_cls == tokens_per_view - 1:
            has_cls = True
            side = side_without_cls
        else:
            is_grid = False
    pieces: List[torch.Tensor] = []
    for view_idx in range(num_views):
        view_tokens = visual[:, view_idx, :, :]
        if not is_grid:
            keep = min(per_view, tokens_per_view)
            picks = torch.linspace(0, tokens_per_view - 1, steps=keep, device=visual.device).round().long()
            pieces.append(view_tokens.index_select(1, picks))
            continue
        cls_slots = 1 if has_cls and per_view > 1 else 0
        out_side = max(1, int(math.sqrt(max(1, per_view - cls_slots))))
        patches = view_tokens[:, 1:, :] if has_cls else view_tokens
        grid_map = patches.view(batch_size, side, side, width).permute(0, 3, 1, 2)
        pooled = F.adaptive_avg_pool2d(grid_map, (out_side, out_side))
        pooled = pooled.permute(0, 2, 3, 1).reshape(batch_size, out_side * out_side, width)
        if cls_slots:
            pooled = torch.cat([view_tokens[:, :1, :], pooled], dim=1)
        pieces.append(pooled)
    return self.visual_proj(torch.cat(pieces, dim=1))
def roi_pool(self, full_tokens: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor:
    """Average full-screenshot patch tokens inside each element's bounding box.

    Args:
        full_tokens: (bsz, tokens, hidden) tokens of the uncropped view, either a
            square patch grid or one leading CLS token plus a square grid.
        bboxes: (bsz, elements, 4) boxes as (x1, y1, x2, y2). Coordinates are
            scaled by the grid side, so they are presumably normalized to
            [0, 1] (TODO confirm). Out-of-range cells are clamped.

    Returns:
        (bsz, elements, hidden) pooled features. Each box covers at least one
        cell. If the tokens do not form a grid, every element receives the
        global mean token.
    """
    bsz, num_tokens, hidden = full_tokens.shape
    if bboxes.size(1) == 0:
        # No elements: torch.stack([]) below would raise.
        return full_tokens.new_zeros(bsz, 0, hidden)
    grid = int(math.sqrt(num_tokens))
    if grid * grid != num_tokens:
        grid = int(math.sqrt(num_tokens - 1))
        if grid * grid != num_tokens - 1:
            return full_tokens.mean(dim=1, keepdim=True).expand(-1, bboxes.size(1), -1)
        patch_tokens = full_tokens[:, 1:, :]
    else:
        patch_tokens = full_tokens
    patch_tokens = patch_tokens.view(bsz, grid, grid, hidden)
    # Copy all boxes to host memory once. Calling .tolist() per box forces a
    # device synchronization for every element when bboxes live on the GPU.
    box_rows = bboxes.detach().cpu().tolist()
    outputs = []
    for b in range(bsz):
        elem_vecs = []
        for x1, y1, x2, y2 in box_rows[b]:
            ix1 = max(0, min(grid - 1, int(x1 * grid)))
            iy1 = max(0, min(grid - 1, int(y1 * grid)))
            ix2 = max(ix1 + 1, min(grid, int(math.ceil(x2 * grid))))
            iy2 = max(iy1 + 1, min(grid, int(math.ceil(y2 * grid))))
            region = patch_tokens[b, iy1:iy2, ix1:ix2, :]
            elem_vecs.append(region.reshape(-1, hidden).mean(dim=0))
        outputs.append(torch.stack(elem_vecs, dim=0))
    return torch.stack(outputs, dim=0)
def build_memory(self, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Fuse context, visual and UI-element tokens into decoder cross-attention memory.

    Returns:
        memory: (bsz, M, hidden) encoder states for the decoder. This is a
            concatenation of optional direct context/visual/element slices and
            the pooled bottleneck tokens, scaled by ``decoder_memory_scale``.
        memory_attention_mask: (bsz, M) long mask, 1 for real positions.
        head_elem_tokens: Per-element states for the auxiliary heads. These
            are the post-fusion element slice when elements take part in
            fusion, otherwise the layout-encoder output.
        elem_key_padding: (bsz, max_elements) bool mask, True for padded elements.
    """
    use_activation_checkpointing = self.training and bool(getattr(self.args, "activation_checkpointing", False))
    bsz = batch["element_mask"].size(0)
    if self.vision is not None:
        pixel_values = batch["pixel_values"]
        visual = self.encode_vision(pixel_values)
        # View 0 is the uncropped screenshot; it feeds per-element ROI pooling.
        full_visual = visual[:, 0, :, :]
        visual_tokens = self.reduce_visual_tokens(visual)
    else:
        full_visual = None
        visual_tokens = batch["element_mask"].new_zeros((bsz, 0, self.hidden_size), dtype=self.decoder.get_input_embeddings().weight.dtype)
    # UI elements: mean text embedding + categorical/box embeddings (+ ROI features).
    elem_ids = batch["element_input_ids"]
    elem_mask_tok = batch["element_attention_mask"]
    bsz, max_elements, elem_len = elem_ids.shape
    elem_text = self.mean_embed(elem_ids.view(bsz * max_elements, elem_len), elem_mask_tok.view(bsz * max_elements, elem_len))
    elem_text = self.elem_text_proj(elem_text.view(bsz, max_elements, -1))
    bbox = batch["bboxes"]
    # Box features: (x1, y1, x2, y2, width, height).
    wh = (bbox[..., 2:] - bbox[..., :2]).clamp_min(0)
    bbox_feat = torch.cat([bbox, wh], dim=-1)
    elem_tokens = (
        elem_text
        + self.type_emb(batch["type_ids"])
        + self.source_emb(batch["source_ids"])
        + self.loc_emb(batch["loc_ids"])
        + self.bbox_proj(bbox_feat)
        + batch["confs"] * 0.1
        + batch["action_flags"] * 0.1
    )
    if self.vision is not None and self.model_variant in {"full", "late_fusion"}:
        roi = self.roi_proj(self.roi_pool(full_visual, bbox))
        elem_tokens = elem_tokens + roi
    elem_key_padding = ~batch["element_mask"]
    if use_activation_checkpointing:
        elem_tokens = checkpoint(
            lambda tokens, padding: self.layout_encoder(tokens, src_key_padding_mask=padding),
            elem_tokens,
            elem_key_padding,
            use_reentrant=False,
        )
    else:
        elem_tokens = self.layout_encoder(elem_tokens, src_key_padding_mask=elem_key_padding)
    head_elem_tokens = elem_tokens
    decoder_elem_tokens = elem_tokens
    # Optionally inject detached head predictions into the tokens seen by fusion/decoder.
    if bool(getattr(self.args, "function_signal_to_decoder", False)):
        function_signal = torch.sigmoid(self.ui_function_head(head_elem_tokens)).detach()
        function_signal = function_signal.masked_fill(elem_key_padding.unsqueeze(-1), 0.0)
        signal_scale = float(getattr(self.args, "function_signal_scale", 1.0) or 1.0)
        decoder_elem_tokens = decoder_elem_tokens + signal_scale * self.function_signal_proj(function_signal)
    if bool(getattr(self.args, "search_signal_to_decoder", False)):
        search_signal = torch.sigmoid(self.search_function_head(head_elem_tokens)).detach()
        search_signal = search_signal.masked_fill(elem_key_padding.unsqueeze(-1), 0.0)
        search_signal_scale = float(getattr(self.args, "search_signal_scale", 1.0) or 1.0)
        decoder_elem_tokens = decoder_elem_tokens + search_signal_scale * self.search_signal_proj(search_signal)
    # Context: encoder states, raw token embeddings, or a single mean-embedded token.
    context_mode = str(getattr(self.args, "context_mode", "mean") or "mean").lower()
    direct_context_tokens = None
    if context_mode in {"tokens_encoder", "tokens_direct_encoder"}:
        context_encoded = self.decoder.get_encoder()(input_ids=batch["context_input_ids"], attention_mask=batch["context_attention_mask"], return_dict=True)
        direct_context_tokens = context_encoded.last_hidden_state
        context_tokens = self.context_proj(context_encoded.last_hidden_state)
        context_padding = ~batch["context_attention_mask"].bool()
    elif context_mode in {"tokens", "tokens_direct"}:
        context_emb = self.decoder.get_input_embeddings()(batch["context_input_ids"])
        direct_context_tokens = context_emb
        context_tokens = self.context_proj(context_emb)
        context_padding = ~batch["context_attention_mask"].bool()
    else:
        context = self.mean_embed(batch["context_input_ids"], batch["context_attention_mask"])
        context_tokens = self.context_proj(context).unsqueeze(1)
        context_padding = torch.zeros(bsz, 1, dtype=torch.bool, device=context_tokens.device)
    context_len = context_tokens.size(1)
    # Offsets of the visual / element spans inside the fusion sequence (-1 = absent).
    visual_start = -1
    visual_len = 0
    elem_start = -1
    elem_len_for_fusion = 0
    if self.model_variant == "annotation_only":
        elem_start = context_len
        elem_len_for_fusion = decoder_elem_tokens.size(1)
        fusion_input = torch.cat([context_tokens, decoder_elem_tokens], dim=1)
        fusion_padding = torch.cat([context_padding, elem_key_padding], dim=1)
    elif self.model_variant == "image_only":
        if self.vision is None:
            raise ValueError("image_only requires vision; do not set disable_vision=true")
        visual_start = 0
        visual_len = visual_tokens.size(1)
        fusion_input = visual_tokens
        fusion_padding = torch.zeros(bsz, fusion_input.size(1), dtype=torch.bool, device=fusion_input.device)
    elif self.model_variant == "late_fusion":
        if visual_tokens.size(1) > 0:
            visual_summary = visual_tokens.mean(dim=1, keepdim=True)
        else:
            # Without vision there are no visual tokens. mean() over an empty axis
            # yields NaN, which would poison the fusion encoder, so use a zero slot.
            visual_summary = visual_tokens.new_zeros((bsz, 1, visual_tokens.size(-1)))
        visual_start = context_len
        visual_len = 1
        elem_start = context_len + 1
        elem_len_for_fusion = decoder_elem_tokens.size(1)
        fusion_input = torch.cat([context_tokens, visual_summary, decoder_elem_tokens], dim=1)
        fusion_padding = torch.cat(
            [context_padding, torch.zeros(bsz, 1, dtype=torch.bool, device=elem_key_padding.device), elem_key_padding],
            dim=1,
        )
    else:
        visual_start = context_len
        visual_len = visual_tokens.size(1)
        elem_start = context_len + visual_len
        elem_len_for_fusion = decoder_elem_tokens.size(1)
        fusion_input = torch.cat([context_tokens, visual_tokens, decoder_elem_tokens], dim=1)
        fusion_padding = torch.cat(
            [
                context_padding,
                torch.zeros(bsz, visual_tokens.size(1), dtype=torch.bool, device=elem_key_padding.device),
                elem_key_padding,
            ],
            dim=1,
        )
    if use_activation_checkpointing:
        fused = checkpoint(
            lambda tokens, padding: self.fusion(tokens, src_key_padding_mask=padding),
            fusion_input,
            fusion_padding,
            use_reentrant=False,
        )
    else:
        fused = self.fusion(fusion_input, src_key_padding_mask=fusion_padding)
    if elem_start >= 0 and elem_len_for_fusion > 0:
        head_elem_tokens = fused[:, elem_start : elem_start + elem_len_for_fusion, :]
    if use_activation_checkpointing:
        pooled = checkpoint(
            lambda tokens, padding: self.pooler(tokens, key_padding_mask=padding),
            fused,
            fusion_padding,
            use_reentrant=False,
        )
    else:
        pooled = self.pooler(fused, key_padding_mask=fusion_padding)
    pooled_padding = torch.zeros(bsz, pooled.size(1), dtype=torch.bool, device=pooled.device)
    # Assemble decoder memory from the configured slices plus the pooled tokens.
    memory_parts = []
    memory_padding_parts = []
    if context_mode in {"tokens_direct", "tokens_direct_encoder"} and self.model_variant != "image_only":
        if bool(getattr(self.args, "direct_context_passthrough", False)) and direct_context_tokens is not None:
            memory_parts.append(direct_context_tokens)
        else:
            memory_parts.append(fused[:, :context_len, :])
        memory_padding_parts.append(context_padding)
    if bool(getattr(self.args, "direct_visual_tokens", False)) and visual_start >= 0 and visual_len > 0:
        visual_scale = float(getattr(self.args, "visual_memory_scale", 1.0) or 1.0)
        memory_parts.append(fused[:, visual_start : visual_start + visual_len, :] * visual_scale)
        memory_padding_parts.append(torch.zeros(bsz, visual_len, dtype=torch.bool, device=pooled.device))
    if bool(getattr(self.args, "direct_element_tokens", False)) and elem_start >= 0 and elem_len_for_fusion > 0:
        element_scale = float(getattr(self.args, "element_memory_scale", 1.0) or 1.0)
        memory_parts.append(fused[:, elem_start : elem_start + elem_len_for_fusion, :] * element_scale)
        memory_padding_parts.append(elem_key_padding)
    if bool(getattr(self.args, "include_pooled_memory", True)):
        pooled_scale = float(getattr(self.args, "pooled_memory_scale", 1.0) or 1.0)
        memory_parts.append(pooled * pooled_scale)
        memory_padding_parts.append(pooled_padding)
    if not memory_parts:
        # Never hand the decoder an empty memory: fall back to the pooled tokens.
        pooled_scale = float(getattr(self.args, "pooled_memory_scale", 1.0) or 1.0)
        memory_parts.append(pooled * pooled_scale)
        memory_padding_parts.append(pooled_padding)
    memory = torch.cat(memory_parts, dim=1)
    memory_scale = float(getattr(self.args, "decoder_memory_scale", 1.0) or 1.0)
    if memory_scale != 1.0:
        memory = memory * memory_scale
    memory_padding = torch.cat(memory_padding_parts, dim=1)
    memory_attention_mask = (~memory_padding).long()
    return memory, memory_attention_mask, head_elem_tokens, elem_key_padding
def forward(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
    """Compute the training loss and auxiliary head outputs for one batch.

    Two paths:

    * Native context path (``self.use_native_context_forward()`` is true): the
      decoder runs as a plain seq2seq model on ``context_input_ids`` with
      ``labels``. All auxiliary losses are reported as zero. The auxiliary logits
      are constant placeholders: element-level logits are filled with -20.0 and
      section logits with 0.0.
    * Memory path: ``self.build_memory(batch)`` produces an encoder memory that
      is handed to the decoder as ``encoder_outputs``. The generation loss is a
      token-mean cross-entropy, computed chunk-wise in float32. Masked BCE losses
      for the evidence / UI-function / search-function / numeric element heads
      and a BCE section loss are then added with weights taken from ``self.args``.

    Args:
        **batch: collated tensors. Keys read directly here: ``labels``,
            ``section_labels``, and on the native path ``context_input_ids``,
            ``context_attention_mask``, ``element_mask``. On the memory path they
            also include ``evidence_labels``, ``ui_function_labels``,
            ``search_function_labels`` and ``numeric_labels``.
            ``build_memory`` may read further keys.

    Returns:
        ``loss`` (the differentiable total), detached per-component losses, and
        detached evidence / ui_function / search_function / section logits.
    """
    if self.use_native_context_forward():
        labels = batch["labels"]
        decoder_out = self.decoder(
            input_ids=batch["context_input_ids"],
            attention_mask=batch["context_attention_mask"],
            labels=labels,
            use_cache=False,
        )
        gen_loss = decoder_out.loss
        # Placeholder outputs shaped like the element mask / section labels so
        # downstream metric code can treat both paths uniformly.
        element_shape = batch["element_mask"].shape
        empty_logits = labels.new_full(element_shape, -20.0, dtype=torch.float32)
        section_logits = labels.new_zeros((labels.size(0), int(batch["section_labels"].shape[-1])), dtype=torch.float32)
        # Scalar zero on the same device/dtype as the generation loss.
        zero_loss = gen_loss.detach().new_zeros(())
        return {
            "loss": gen_loss,
            "generation_loss": gen_loss.detach(),
            "evidence_loss": zero_loss,
            "ui_function_loss": zero_loss,
            "search_function_loss": zero_loss,
            "section_loss": zero_loss,
            "numeric_loss": zero_loss,
            "evidence_logits": empty_logits,
            "ui_function_logits": empty_logits,
            "search_function_logits": empty_logits,
            "section_logits": section_logits,
        }
    memory, memory_attention_mask, elem_tokens, elem_key_padding = self.build_memory(batch)
    encoder_outputs = BaseModelOutput(last_hidden_state=memory)
    labels = batch["labels"]
    # Teacher-forcing inputs from the labels. Prefer the model's own helper, then
    # the T5-style _shift_right; otherwise shift right manually, placing the start
    # token first and replacing ignored (-100) positions with the pad id.
    if hasattr(self.decoder, "prepare_decoder_input_ids_from_labels"):
        decoder_input_ids = self.decoder.prepare_decoder_input_ids_from_labels(labels=labels)
    elif hasattr(self.decoder, "_shift_right"):
        decoder_input_ids = self.decoder._shift_right(labels)
    else:
        pad_token_id = int(getattr(self.decoder.config, "pad_token_id", 0) or 0)
        start_token_id = int(getattr(self.decoder.config, "decoder_start_token_id", pad_token_id) or pad_token_id)
        decoder_input_ids = labels.new_full(labels.shape, pad_token_id)
        decoder_input_ids[:, 0] = start_token_id
        decoder_input_ids[:, 1:] = labels[:, :-1].masked_fill(labels[:, :-1] == -100, pad_token_id)
    decoder_out = self.decoder(
        encoder_outputs=encoder_outputs,
        attention_mask=memory_attention_mask,
        decoder_input_ids=decoder_input_ids,
        use_cache=False,
    )
    logits = decoder_out.logits
    # Token-mean cross-entropy: sum the per-token losses over float32-cast slices of
    # `chunk_size` flattened positions, then divide by the number of non-ignored
    # tokens (clamped to at least 1). Chunking presumably bounds the memory needed
    # to upcast the full vocabulary logits at once.
    flat_logits = logits.reshape(-1, logits.size(-1))
    flat_labels = labels.reshape(-1)
    valid_token_count = flat_labels.ne(-100).sum().clamp_min(1)
    chunk_size = max(1, int(getattr(self.args, "generation_loss_chunk_size", 32) or 32))
    gen_loss_sum = logits.new_zeros((), dtype=torch.float32)
    for start in range(0, flat_logits.size(0), chunk_size):
        end = min(start + chunk_size, flat_logits.size(0))
        gen_loss_sum = gen_loss_sum + F.cross_entropy(
            flat_logits[start:end].float(),
            flat_labels[start:end],
            ignore_index=-100,
            reduction="sum",
        )
    gen_loss = gen_loss_sum / valid_token_count
    # Element-level heads: one logit per element token (trailing dim squeezed).
    evidence_logits = self.evidence_head(elem_tokens).squeeze(-1)
    ui_function_logits = self.ui_function_head(elem_tokens).squeeze(-1)
    search_function_logits = self.search_function_head(elem_tokens).squeeze(-1)
    numeric_logits = self.numeric_head(elem_tokens).squeeze(-1)
    # 1.0 for real elements, 0.0 for padding; every element BCE below is averaged
    # over real elements only.
    valid = (~elem_key_padding).float()
    evidence_loss = F.binary_cross_entropy_with_logits(
        evidence_logits, batch["evidence_labels"], reduction="none"
    )
    evidence_loss = (evidence_loss * valid).sum() / valid.sum().clamp_min(1.0)
    ui_function_loss = F.binary_cross_entropy_with_logits(
        ui_function_logits, batch["ui_function_labels"], reduction="none"
    )
    ui_function_loss = (ui_function_loss * valid).sum() / valid.sum().clamp_min(1.0)
    # Optional pos_weight re-weights the positive term of the search-function BCE.
    search_pos_weight = float(getattr(self.args, "search_function_pos_weight", 1.0) or 1.0)
    search_loss_kwargs: Dict[str, Any] = {"reduction": "none"}
    if search_pos_weight != 1.0:
        search_loss_kwargs["pos_weight"] = torch.tensor(search_pos_weight, device=search_function_logits.device)
    search_function_loss = F.binary_cross_entropy_with_logits(
        search_function_logits, batch["search_function_labels"], **search_loss_kwargs
    )
    search_function_loss = (search_function_loss * valid).sum() / valid.sum().clamp_min(1.0)
    numeric_loss = F.binary_cross_entropy_with_logits(
        numeric_logits, batch["numeric_labels"], reduction="none"
    )
    numeric_loss = (numeric_loss * valid).sum() / valid.sum().clamp_min(1.0)
    # Section head reads the mean of the decoder memory over unmasked positions.
    memory_mask = memory_attention_mask.unsqueeze(-1).float()
    memory_summary = (memory * memory_mask).sum(dim=1) / memory_mask.sum(dim=1).clamp_min(1.0)
    section_logits = self.section_head(memory_summary)
    section_loss = F.binary_cross_entropy_with_logits(section_logits, batch["section_labels"])
    # The ui/search function weights default to 0 when unset; the evidence,
    # section and numeric weights are required attributes of args.
    total = (
        gen_loss
        + self.args.evidence_loss_weight * evidence_loss
        + float(getattr(self.args, "ui_function_loss_weight", 0.0) or 0.0) * ui_function_loss
        + float(getattr(self.args, "search_function_loss_weight", 0.0) or 0.0) * search_function_loss
        + self.args.section_loss_weight * section_loss
        + self.args.numeric_loss_weight * numeric_loss
    )
    return {
        "loss": total,
        "generation_loss": gen_loss.detach(),
        "evidence_loss": evidence_loss.detach(),
        "ui_function_loss": ui_function_loss.detach(),
        "search_function_loss": search_function_loss.detach(),
        "section_loss": section_loss.detach(),
        "numeric_loss": numeric_loss.detach(),
        "evidence_logits": evidence_logits.detach(),
        "ui_function_logits": ui_function_logits.detach(),
        "search_function_logits": search_function_logits.detach(),
        "section_logits": section_logits.detach(),
    }
@torch.no_grad()
def generate_text(self, batch: Dict[str, torch.Tensor], tokenizer, num_beams: int = 4, max_new_tokens: int = 384) -> List[str]:
    """Decode one batch into text with beam search.

    Generation runs under CUDA autocast at the training precision so that
    evaluating a large decoder (e.g. mt5-large) does not peak far above the
    training memory footprint, and so that it runs faster. The precision comes
    from ``args.amp_dtype``; ``"auto"`` resolves to fp16 when ``args.fp16`` is
    set and fp32 otherwise. Autocast is only enabled when
    ``batch["pixel_values"]`` is a CUDA tensor and the precision is fp16/bf16.

    Returns:
        One decoded string per batch row, with special tokens removed.
    """
    precision = str(getattr(self.args, "amp_dtype", "auto") or "auto").lower()
    if precision == "auto":
        precision = "fp16" if bool(getattr(self.args, "fp16", False)) else "fp32"
    pixel_values = batch.get("pixel_values")
    on_cuda = pixel_values.is_cuda if torch.is_tensor(pixel_values) else False
    use_autocast = on_cuda and precision in {"fp16", "bf16"}
    autocast_dtype = torch.float16 if precision == "fp16" else torch.bfloat16
    with torch.amp.autocast("cuda", enabled=use_autocast, dtype=autocast_dtype):
        extra_generation_kwargs = build_generation_kwargs(self.args, tokenizer)
        if self.use_native_context_forward():
            # Plain seq2seq: the decoder model encodes the context tokens itself.
            decoder_inputs = {
                "input_ids": batch["context_input_ids"],
                "attention_mask": batch["context_attention_mask"],
            }
        else:
            # Multimodal path: feed the fused memory in place of the encoder.
            memory, memory_attention_mask, _, _ = self.build_memory(batch)
            decoder_inputs = {
                "encoder_outputs": BaseModelOutput(last_hidden_state=memory),
                "attention_mask": memory_attention_mask,
            }
        generated = self.decoder.generate(
            **decoder_inputs,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            **extra_generation_kwargs,
        )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)
def move_batch(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
    """Return a shallow copy of ``batch`` with its tensors moved to ``device``.

    The ``"rows"`` entry (raw per-sample records) is always passed through
    untouched, as is every non-tensor value. Tensor copies use
    ``non_blocking=True``. Key order is preserved.
    """

    def _relocate(key: str, value: Any) -> Any:
        if key != "rows" and torch.is_tensor(value):
            return value.to(device, non_blocking=True)
        return value

    return {key: _relocate(key, value) for key, value in batch.items()}
def trainable_parameter_stats(model: nn.Module) -> Dict[str, Any]:
    """Count parameters with ``requires_grad`` set, bucketed by sub-module.

    Parameter names starting with ``vision.`` or ``decoder.`` go to those buckets;
    everything else is counted as ``other``. ``vision_trainable`` reports whether
    any vision parameter is trainable.
    """
    prefix_to_bucket = {
        "vision.": "trainable_params_vision",
        "decoder.": "trainable_params_decoder",
    }
    stats: Dict[str, Any] = {
        "trainable_params_total": 0,
        "trainable_params_vision": 0,
        "trainable_params_decoder": 0,
        "trainable_params_other": 0,
    }
    trainable = ((pname, p) for pname, p in model.named_parameters() if p.requires_grad)
    for param_name, param in trainable:
        numel = int(param.numel())
        bucket = next(
            (key for prefix, key in prefix_to_bucket.items() if param_name.startswith(prefix)),
            "trainable_params_other",
        )
        stats["trainable_params_total"] += numel
        stats[bucket] += numel
    stats["vision_trainable"] = stats["trainable_params_vision"] > 0
    return stats
def reduce_parallel_losses(output: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Collapse non-scalar loss entries of ``output`` to scalars, in place.

    Each listed loss key whose value is a tensor with ``ndim > 0`` is replaced by
    its mean (presumably one value per replica after a data-parallel gather —
    confirm with the caller). Missing keys, scalar losses and non-loss entries
    such as logits are left alone. Returns the same dict object.
    """
    loss_keys = (
        "loss",
        "generation_loss",
        "evidence_loss",
        "ui_function_loss",
        "search_function_loss",
        "section_loss",
        "numeric_loss",
    )
    for loss_key in loss_keys:
        candidate = output.get(loss_key)
        if not torch.is_tensor(candidate) or candidate.ndim == 0:
            continue
        output[loss_key] = candidate.mean()
    return output
def build_optimizer(model: nn.Module, args: argparse.Namespace) -> torch.optim.Optimizer:
    """Create the optimizer with per-module learning rates.

    Parameter groups, in order:
        0. everything not listed below -> ``args.lr_new``
        1. ``vision.*``                -> ``args.lr_fusion``
        2. ``decoder.*``               -> ``args.lr_decoder``
        3. (only when ``args.lr_ui_function_head > 0`` and any match) the
           ui/search function heads and signal projections -> that LR.
    Frozen parameters are skipped. ``args.optimizer_name`` selects AdamW
    (default) or Adafactor with a fixed LR schedule.

    Raises:
        ValueError: for any other optimizer name.
    """
    head_prefixes = (
        "ui_function_head.",
        "search_function_head.",
        "function_signal_proj.",
        "search_signal_proj.",
    )
    head_lr = float(getattr(args, "lr_ui_function_head", 0.0) or 0.0)
    buckets: Dict[str, List[nn.Parameter]] = {"head": [], "decoder": [], "vision": [], "other": []}
    for param_name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if head_lr > 0 and param_name.startswith(head_prefixes):
            bucket = "head"
        elif param_name.startswith("decoder."):
            bucket = "decoder"
        elif param_name.startswith("vision."):
            bucket = "vision"
        else:
            bucket = "other"
        buckets[bucket].append(param)
    param_groups = [
        {"params": buckets["other"], "lr": args.lr_new},
        {"params": buckets["vision"], "lr": args.lr_fusion},
        {"params": buckets["decoder"], "lr": args.lr_decoder},
    ]
    if buckets["head"]:
        param_groups.append({"params": buckets["head"], "lr": head_lr})
    optimizer_name = str(getattr(args, "optimizer_name", "adamw") or "adamw").lower()
    if optimizer_name == "adamw":
        return torch.optim.AdamW(param_groups, weight_decay=args.weight_decay)
    if optimizer_name == "adafactor":
        # Explicit LRs per group, so Adafactor's own relative-step schedule is disabled.
        return Adafactor(
            param_groups,
            scale_parameter=False,
            relative_step=False,
            warmup_init=False,
            weight_decay=args.weight_decay,
        )
    raise ValueError("optimizer_name must be one of: adamw, adafactor")
def clip_gradients(model: nn.Module, optimizer: torch.optim.Optimizer, args: argparse.Namespace) -> torch.Tensor:
    """Clip gradients in place and return the pre-clipping norm.

    ``args.max_grad_norm <= 0`` (or unset to a falsy value) disables clipping and
    returns a zero tensor. ``args.grad_clip_strategy`` chooses between:
        * ``global``    — one norm over all model parameters (default);
        * ``per_group`` — clip each optimizer param group independently and
          return the largest group norm (float32, on the first parameter's device).

    Raises:
        ValueError: for an unknown strategy.
    """
    max_grad_norm = float(getattr(args, "max_grad_norm", 1.0) or 0.0)
    anchor = next(model.parameters())
    if max_grad_norm <= 0:
        return anchor.detach().new_tensor(0.0)
    strategy = str(getattr(args, "grad_clip_strategy", "global") or "global").lower()
    if strategy == "global":
        return torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm, foreach=False)
    if strategy != "per_group":
        raise ValueError("grad_clip_strategy must be one of: global, per_group")
    group_norms: List[torch.Tensor] = []
    for group in optimizer.param_groups:
        with_grad = [p for p in group["params"] if p.grad is not None]
        if not with_grad:
            continue
        norm = torch.nn.utils.clip_grad_norm_(with_grad, max_grad_norm, foreach=False)
        group_norms.append(norm.detach().to(device=anchor.device, dtype=torch.float32))
    if not group_norms:
        return anchor.detach().new_tensor(0.0)
    return torch.stack(group_norms).max()
def get_lr_schedule(optimizer, total_steps: int, warmup_ratio: float, scheduler_type: str = "linear"):
    """Return a ``LambdaLR`` with linear warm-up and then linear or cosine decay.

    The first ``int(total_steps * warmup_ratio)`` steps ramp the multiplier from
    1e-8 toward 1. After that it decays to 0 at ``total_steps``: cosine when
    ``scheduler_type`` is ``"cosine"`` (case-insensitive), otherwise linear.
    The multiplier never goes below 0.
    """
    warmup_steps = int(total_steps * warmup_ratio)
    use_cosine = str(scheduler_type or "linear").lower() == "cosine"
    decay_span = max(1, total_steps - warmup_steps)

    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return max(1e-8, step / max(1, warmup_steps))
        progress = (step - warmup_steps) / decay_span
        if use_cosine:
            return max(0.0, 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress))))
        return max(0.0, 1.0 - progress)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
def save_checkpoint(
    output_dir: Path,
    name: str,
    model: nn.Module,
    tokenizer,
    image_processor,
    args: argparse.Namespace,
    metrics: Dict[str, Any],
) -> None:
    """Write a reloadable checkpoint directory at ``output_dir / name``.

    Contents:
        pytorch_model.bin   state dict of the model after ``unwrap_parallel_model``
        rich_config.json    ``vars(args)``
        metrics.json        ``metrics``
        decoder_tokenizer/  ``tokenizer.save_pretrained`` output, topped up with
                            ``spiece.model`` / ``config.json`` when they are missing
        image_processor/    ``image_processor.save_pretrained`` output

    The directory (and parents) is created if needed.
    """
    ckpt_dir = output_dir / name
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    raw_model = unwrap_parallel_model(model)
    torch.save(raw_model.state_dict(), ckpt_dir / "pytorch_model.bin")
    write_json(ckpt_dir / "rich_config.json", vars(args))
    write_json(ckpt_dir / "metrics.json", metrics)
    tokenizer_dir = ckpt_dir / "decoder_tokenizer"
    tokenizer.save_pretrained(tokenizer_dir)
    # If save_pretrained did not write spiece.model, copy the tokenizer's own vocab
    # file there so the directory carries the sentencepiece model.
    vocab_file = getattr(tokenizer, "vocab_file", None)
    if vocab_file and Path(vocab_file).exists() and not (tokenizer_dir / "spiece.model").exists():
        shutil.copyfile(vocab_file, tokenizer_dir / "spiece.model")
    # Top up missing tokenizer assets from a local decoder model directory.
    # An empty setting must be rejected explicitly: Path("") is Path("."), the
    # current working directory, which always passes is_dir() and would leak an
    # unrelated ./config.json or ./spiece.model into the checkpoint.
    decoder_model = str(getattr(args, "decoder_model", "") or "")
    if decoder_model and Path(decoder_model).is_dir():
        decoder_model_dir = Path(decoder_model)
        for tokenizer_asset in ("spiece.model", "config.json"):
            source_asset = decoder_model_dir / tokenizer_asset
            target_asset = tokenizer_dir / tokenizer_asset
            if source_asset.exists() and not target_asset.exists():
                shutil.copyfile(source_asset, target_asset)
    image_processor.save_pretrained(ckpt_dir / "image_processor")
def resize_checkpoint_tensor(value: torch.Tensor, target: torch.Tensor) -> Optional[torch.Tensor]:
    """Best-effort adaptation of a checkpoint tensor to ``target``'s shape and dtype.

    Dimensions are processed one at a time. A dimension that is too large keeps
    its leading slice. A dimension that is too small is tiled by repetition and
    then truncated. Returns ``None`` when the ranks differ, for 0-d tensors, or
    when a size involved in a mismatch is not positive. The result is on CPU.
    """
    if value.ndim == 0 or value.ndim != target.ndim:
        return None
    adapted = value.detach().cpu()
    for axis, wanted in enumerate(target.shape):
        have = adapted.shape[axis]
        if have == wanted:
            continue
        if min(have, wanted) <= 0:
            return None
        if have < wanted:
            tiling = [1] * adapted.ndim
            tiling[axis] = math.ceil(wanted / have)
            adapted = adapted.repeat(*tiling)
        adapted = adapted.narrow(axis, 0, wanted)
    return adapted.to(dtype=target.dtype)
def load_compatible_model_state(
    model: nn.Module,
    state: Dict[str, torch.Tensor],
    allow_missing_prefixes: Tuple[str, ...] = (),
    resize_mismatched_non_decoder: bool = False,
) -> Tuple[List[str], List[str], List[str], List[str]]:
    """Load the shape-compatible subset of ``state`` into ``model`` (non-strict).

    * Keys unknown to the model are collected and cause an error.
    * Shape-mismatched keys are skipped. When ``resize_mismatched_non_decoder`` is
      set, non-``decoder.`` keys are first adapted with
      ``resize_checkpoint_tensor``; a successful adaptation loads that tensor instead.
    * Missing keys are tolerated for the optional function/search heads and
      signal projections, for skipped keys, and for ``allow_missing_prefixes``.

    Returns:
        ``(missing_keys, unexpected_keys, skipped_incompatible, resized_incompatible)``.

    Raises:
        RuntimeError: if there are non-tolerated missing keys or any unexpected keys.
    """
    own_state = model.state_dict()
    loadable: Dict[str, torch.Tensor] = {}
    resized_keys: List[str] = []
    shape_skipped: List[str] = []
    unknown_keys: List[str] = []
    for key, tensor in state.items():
        if key not in own_state:
            unknown_keys.append(key)
            continue
        expected = own_state[key]
        if tuple(expected.shape) == tuple(tensor.shape):
            loadable[key] = tensor
            continue
        if resize_mismatched_non_decoder and not key.startswith("decoder."):
            adapted = resize_checkpoint_tensor(tensor, expected)
            if adapted is not None and tuple(adapted.shape) == tuple(expected.shape):
                loadable[key] = adapted
                resized_keys.append(key)
                continue
        shape_skipped.append(key)
    result = model.load_state_dict(loadable, strict=False)
    optional_prefixes = (
        "ui_function_head.",
        "search_function_head.",
        "function_signal_proj.",
        "search_signal_proj.",
    )
    shape_skipped_set = set(shape_skipped)
    tolerated = {
        key
        for key in result.missing_keys
        if key.startswith(optional_prefixes)
        or key in shape_skipped_set
        or key.startswith(allow_missing_prefixes)
    }
    bad_missing = [key for key in result.missing_keys if key not in tolerated]
    unexpected_report = list(result.unexpected_keys) + unknown_keys
    if bad_missing or unexpected_report:
        raise RuntimeError(
            f"Checkpoint mismatch. missing={bad_missing}, unexpected={unexpected_report}"
        )
    return list(result.missing_keys), list(result.unexpected_keys), shape_skipped, resized_keys
def load_rich_checkpoint(checkpoint: str, device: torch.device) -> Tuple[RichGroundedModel, Any, Any, argparse.Namespace]:
    """Rebuild ``(model, tokenizer, image_processor, args)`` from a checkpoint directory.

    ``args`` is ``DEFAULT_CONFIG`` overlaid with the saved ``rich_config.json``, so
    options added after the checkpoint was written fall back to their defaults.
    Weights are loaded via ``load_compatible_model_state``, which raises on
    unexpected or non-tolerated missing keys. The model is then moved to
    ``device`` and switched to eval mode.

    Any parameters the checkpoint did not supply are reported with a
    ``UserWarning`` instead of being silently used untrained at inference time.
    This covers tolerated missing heads and keys skipped because of a shape mismatch.
    """
    import warnings

    ckpt_dir = Path(checkpoint)
    config = json.loads((ckpt_dir / "rich_config.json").read_text(encoding="utf-8"))
    merged_config = dict(DEFAULT_CONFIG)
    merged_config.update(config)
    args = argparse.Namespace(**merged_config)
    tokenizer = load_seq2seq_tokenizer(str(ckpt_dir / "decoder_tokenizer"), str(args.decoder_model))
    image_processor = AutoImageProcessor.from_pretrained(ckpt_dir / "image_processor")
    model = RichGroundedModel(args)
    state = torch.load(ckpt_dir / "pytorch_model.bin", map_location="cpu")
    missing_keys, _, skipped_incompatible, _ = load_compatible_model_state(model, state)
    if missing_keys:
        # Tolerated-missing and shape-skipped parameters keep their fresh
        # initialization; surface them rather than discarding the load report.
        warnings.warn(
            f"Checkpoint {ckpt_dir} left {len(missing_keys)} parameter(s) at their initial values: "
            f"{missing_keys} (skipped due to shape mismatch: {skipped_incompatible})",
            stacklevel=2,
        )
    model.to(device)
    model.eval()
    return model, tokenizer, image_processor, args
def get_eval_max_new_tokens(args: argparse.Namespace) -> int:
    """Return the ``max_new_tokens`` budget used for evaluation generation.

    Resolution order: a positive ``eval_max_new_tokens``, then a positive
    ``max_target_tokens``, then 384. Missing, ``None`` or non-positive values
    count as "not configured", matching the file's ``getattr(...) or default``
    convention. The previous form raised ``TypeError`` on ``int(None)`` and could
    return a zero-token budget.
    """
    for attr_name in ("eval_max_new_tokens", "max_target_tokens"):
        value = getattr(args, attr_name, None)
        if value is not None and int(value) > 0:
            return int(value)
    return 384
def build_generation_kwargs(args: argparse.Namespace, tokenizer) -> Dict[str, Any]:
    """Translate ``generation_*`` options on ``args`` into ``generate()`` kwargs.

    Only non-default settings are emitted:
        no_repeat_ngram_size  when ``generation_no_repeat_ngram_size > 0``
        repetition_penalty    when ``generation_repetition_penalty != 1.0``
        min_new_tokens        when ``generation_min_new_tokens > 0``
        bad_words_ids         ``<extra_id_*>`` tokens found among the last 256 ids
                              and/or the token sequences of "Title"-style prefixes,
                              de-duplicated in first-seen order
        forced_bos_token_id   "{" when it encodes to exactly one token
    """
    kwargs: Dict[str, Any] = {}
    ngram = int(getattr(args, "generation_no_repeat_ngram_size", 0) or 0)
    if ngram > 0:
        kwargs["no_repeat_ngram_size"] = ngram
    penalty = float(getattr(args, "generation_repetition_penalty", 1.0) or 1.0)
    if penalty != 1.0:
        kwargs["repetition_penalty"] = penalty
    min_tokens = int(getattr(args, "generation_min_new_tokens", 0) or 0)
    if min_tokens > 0:
        kwargs["min_new_tokens"] = min_tokens

    # Insertion-ordered dict used as an ordered set of banned id sequences.
    banned: Dict[Tuple[int, ...], None] = {}
    if bool(getattr(args, "generation_block_extra_ids", False)):
        vocab_size = len(tokenizer)
        for token_id in range(max(0, vocab_size - 256), vocab_size):
            piece = tokenizer.convert_ids_to_tokens(token_id)
            if isinstance(piece, str) and "<extra_id_" in piece:
                banned.setdefault((int(token_id),), None)
    if bool(getattr(args, "generation_block_title_prefix", False)):
        for phrase in ("Title", "title", "Title:", "title:", "Title :", "title :"):
            phrase_ids = tokenizer.encode(phrase, add_special_tokens=False)
            if phrase_ids:
                banned.setdefault(tuple(int(i) for i in phrase_ids), None)
    if banned:
        kwargs["bad_words_ids"] = [list(sequence) for sequence in banned]

    if bool(getattr(args, "generation_force_json_start", False)):
        brace_ids = tokenizer.encode("{", add_special_tokens=False)
        if len(brace_ids) == 1:
            kwargs["forced_bos_token_id"] = brace_ids[0]
    return kwargs
@torch.no_grad()
def evaluate(
    model: nn.Module,
    loader: DataLoader,
    tokenizer,
    device: torch.device,
    args: argparse.Namespace,
    rank: int = 0,
) -> Dict[str, Any]:
    """Run validation: teacher-forced losses, head metrics and generation metrics.

    For every batch from ``loader`` this function:

    * computes the model's losses;
    * computes thresholded confusion counts for the UI-function and
      search-function element heads;
    * decodes text with ``generate_text``.

    Each decoded text is parsed according to ``args.target_schema`` (summary,
    natural text, or JSON) and optionally repaired from the row context. It is
    then passed through the structured function/evidence helpers. Finally it is
    compared with the row's reference target using char-level ROUGE-L, evidence
    precision, function/search counts and template-copying diagnostics.

    Args:
        model: Model, possibly wrapped; ``unwrap_parallel_model`` is applied.
        loader: Validation batches. Each batch carries a ``"rows"`` list of raw
            records alongside its tensors.
        tokenizer: Decoder tokenizer used for generation.
        device: Device the batch tensors are moved to.
        args: Run configuration (generation, schema and structured-head options).
        rank: Process rank. Only the main rank (per ``is_main``) shows the progress
            bar and writes ``val_preview.json``.

    Returns:
        A flat dict of metrics plus echoed configuration values. The composite
        scores ``summary_quality_score``, ``rich_quality_score``,
        ``rich_function_score`` and ``grounded_quality_score`` are derived at the end.

    Side effects:
        Switches the unwrapped model to eval mode and back to train mode before
        returning. On the main rank, writes ``<output_dir>/val_preview.json``
        (metrics plus up to the first 50 predictions).

    NOTE(review): metrics cover only the batches this process iterates; no
    cross-rank reduction happens here. Confirm how distributed callers handle it.
    """
    raw_model = unwrap_parallel_model(model)
    raw_model.eval()
    # Accumulators: per-sample lists and running counts over the whole loader.
    losses: List[float] = []
    rouges: List[float] = []
    json_valid = 0
    generation_json_valid = 0
    generation_json_strict_valid = 0
    generation_json_repair_applied = 0
    total = 0
    evidence_precisions: List[float] = []
    # Element-level confusion counts for the UI-function head.
    ui_function_tp = 0
    ui_function_fp = 0
    ui_function_fn = 0
    ui_function_pred_positive = 0
    ui_function_ref_positive = 0
    ui_function_valid_elements = 0
    generation_char_lengths: List[int] = []
    pred_summary_char_lengths: List[int] = []
    ref_summary_char_lengths: List[int] = []
    empty_summary_count = 0
    extra_id_count = 0
    title_prefix_count = 0
    natural_title_prefix_stripped_count = 0
    whitespace_only_count = 0
    # Element-level confusion counts for the search-function head.
    search_function_tp = 0
    search_function_fp = 0
    search_function_fn = 0
    search_function_pred_positive = 0
    search_function_ref_positive = 0
    # Sample-level "has a search function" confusion counts on parsed outputs.
    search_tp = 0
    search_fp = 0
    search_fn = 0
    pred_search_count = 0
    ref_search_count = 0
    pred_function_count = 0
    ref_function_count = 0
    pred_bare_search_count = 0
    predictions: List[Dict[str, Any]] = []
    eval_max_new_tokens = get_eval_max_new_tokens(args)
    context_summary_repair = bool(getattr(args, "context_summary_repair", False))
    structured_mode = str(getattr(args, "structured_function_mode", "decoder") or "decoder").lower()
    target_schema = str(getattr(args, "target_schema", "zh") or "zh")
    summary_output_mode = target_schema_is_summary(target_schema)
    natural_output_mode = target_schema_is_natural_text(target_schema)
    # Head-metric thresholds: 0.5 unless the structured heads are in "heads"
    # mode, in which case the configured structured thresholds are reused.
    function_metric_threshold = 0.5
    search_metric_threshold = 0.5
    if structured_mode == "heads":
        function_metric_threshold = float(getattr(args, "structured_function_threshold", 0.5) or 0.5)
        search_metric_threshold = float(getattr(args, "structured_search_threshold", function_metric_threshold) or function_metric_threshold)
    repair_count = 0
    # Diagnostic metrics: record whether the predicted summary exactly equals the
    # context template "This is a {app} screen ... the current task context is:
    # {instruction}." (the text produced by build_context_summary(row)).
    # With context_mode=tokens_direct the decoder sees the app and instruction
    # tokens directly. A model that merely copies them into the template slots
    # can reach a high ROUGE without looking at the screen. A high
    # template_match_rate therefore means the training signal is effectively
    # "fill the template" rather than screen understanding.
    template_exact_match = 0
    template_app_in_pred = 0
    template_instruction_in_pred = 0
    for batch in tqdm(loader, desc="valid", disable=not is_main(rank)):
        rows = batch["rows"]
        batch = move_batch(batch, device)
        # Teacher-forced forward pass: losses plus per-element head logits.
        out = raw_model(**{k: v for k, v in batch.items() if k != "rows"})
        losses.append(float(out["loss"].detach().cpu()))
        # Confusion counts below are restricted to real (unpadded) elements.
        valid_elements = batch["element_mask"].bool()
        function_labels = batch["ui_function_labels"] > 0.5
        function_preds = torch.sigmoid(out["ui_function_logits"]) >= function_metric_threshold
        search_function_labels = batch["search_function_labels"] > 0.5
        search_function_preds = torch.sigmoid(out["search_function_logits"]) >= search_metric_threshold
        # Per-row head probabilities on CPU, consumed by the structured helpers below.
        evidence_scores_batch = torch.sigmoid(out["evidence_logits"]).detach().cpu()
        function_scores_batch = torch.sigmoid(out["ui_function_logits"]).detach().cpu()
        search_scores_batch = torch.sigmoid(out["search_function_logits"]).detach().cpu()
        ui_function_tp += int((function_preds & function_labels & valid_elements).sum().detach().cpu())
        ui_function_fp += int((function_preds & ~function_labels & valid_elements).sum().detach().cpu())
        ui_function_fn += int((~function_preds & function_labels & valid_elements).sum().detach().cpu())
        ui_function_pred_positive += int((function_preds & valid_elements).sum().detach().cpu())
        ui_function_ref_positive += int((function_labels & valid_elements).sum().detach().cpu())
        search_function_tp += int((search_function_preds & search_function_labels & valid_elements).sum().detach().cpu())
        search_function_fp += int((search_function_preds & ~search_function_labels & valid_elements).sum().detach().cpu())
        search_function_fn += int((~search_function_preds & search_function_labels & valid_elements).sum().detach().cpu())
        search_function_pred_positive += int((search_function_preds & valid_elements).sum().detach().cpu())
        search_function_ref_positive += int((search_function_labels & valid_elements).sum().detach().cpu())
        ui_function_valid_elements += int(valid_elements.sum().detach().cpu())
        texts = raw_model.generate_text(batch, tokenizer, num_beams=args.num_beams, max_new_tokens=eval_max_new_tokens)
        for row_idx, (row, text) in enumerate(zip(rows, texts)):
            # Raw-output diagnostics, measured before any parsing.
            generation_char_lengths.append(len(text))
            if text and not text.strip():
                whitespace_only_count += 1
            if "<extra_id_" in text:
                extra_id_count += 1
            if has_natural_title_prefix(text):
                title_prefix_count += 1
            if summary_output_mode:
                # Summary schema: the decoded text is the summary; always counted valid.
                pred_obj = prediction_from_summary(row, text)
                ok = True
            elif natural_output_mode:
                # Natural-text schema: strip a leading title prefix, then parse;
                # valid iff a non-empty summary could be extracted.
                parse_text, stripped_title_prefix = strip_natural_title_prefix(text)
                natural_title_prefix_stripped_count += int(stripped_title_prefix)
                pred_obj = natural_prediction_from_text(parse_text)
                ok = bool(extract_summary(pred_obj))
                generation_json_valid += int(ok)
            else:
                # JSON schema: parse with repair, tracking strict vs repaired parses.
                pred_obj, ok, repair_applied, strict_ok = safe_json_loads_with_repair(text)
                generation_json_valid += int(ok)
                generation_json_strict_valid += int(strict_ok)
                generation_json_repair_applied += int(repair_applied)
            if context_summary_repair and not summary_output_mode:
                # A context-repaired prediction is counted as valid for json_valid_rate.
                pred_obj, repaired = repair_prediction_with_context(row, pred_obj)
                ok = True
                repair_count += int(repaired)
            # Let the structured helpers update the prediction using this row's head scores.
            pred_obj = apply_structured_function_predictions(
                row,
                pred_obj,
                function_scores_batch[row_idx],
                search_scores_batch[row_idx],
                args,
            )
            pred_obj = apply_structured_evidence_predictions(row, pred_obj, evidence_scores_batch[row_idx], args)
            # Reference: the row target, optionally canonicalized against the row
            # context, round-tripped through the "zh" text format.
            ref_target = row.get("target") or {}
            if bool(getattr(args, "canonicalize_targets", False)):
                ref_target = canonicalize_target_with_context(
                    row,
                    ref_target,
                    drop_bare_search_functions=bool(getattr(args, "drop_bare_search_functions", False)),
                )
            ref_obj = json.loads(target_to_text(ref_target, "zh"))
            pred_has_search = has_search_function(pred_obj)
            ref_has_search = has_search_function(ref_obj)
            search_tp += int(pred_has_search and ref_has_search)
            search_fp += int(pred_has_search and not ref_has_search)
            search_fn += int((not pred_has_search) and ref_has_search)
            pred_search_count += count_search_functions(pred_obj)
            ref_search_count += count_search_functions(ref_obj)
            pred_function_count += len(extract_function_entries(pred_obj))
            ref_function_count += len(extract_function_entries(ref_obj))
            pred_bare_search_count += count_bare_search_functions(pred_obj)
            pred_summary = extract_summary(pred_obj)
            ref_summary = extract_summary(ref_obj)
            pred_summary_char_lengths.append(len(pred_summary))
            ref_summary_char_lengths.append(len(ref_summary))
            if not pred_summary:
                empty_summary_count += 1
            rouges.append(rouge_l_char(pred_summary, ref_summary))
            # Diagnostic: is the prediction exactly build_context_summary(row),
            # i.e. template filling?
            template_summary = build_context_summary(row)
            if pred_summary and pred_summary == template_summary:
                template_exact_match += 1
            row_app = safe_text(row.get("app"))
            row_instruction = safe_text(row.get("instruction"))
            if row_app and pred_summary and row_app in pred_summary:
                template_app_in_pred += 1
            if row_instruction and pred_summary and row_instruction in pred_summary:
                template_instruction_in_pred += 1
            json_valid += int(ok)
            total += 1
            # Evidence precision: share of predicted evidence ids present in the
            # reference key_ui_clues (or the row's weak_evidence_ids when those are
            # empty); an empty prediction scores 0.
            pred_evidence = set(extract_evidence_ids(pred_obj))
            ref_evidence = set(ref_target.get("key_ui_clues", []) or row.get("weak_evidence_ids", []))
            if pred_evidence:
                evidence_precisions.append(len(pred_evidence & ref_evidence) / len(pred_evidence))
            else:
                evidence_precisions.append(0.0)
            # Keep the first 50 samples for val_preview.json.
            if len(predictions) < 50:
                predictions.append(
                    {
                        "screen_id": row.get("screen_id"),
                        "prediction_raw": text,
                        "prediction": pred_obj,
                        "reference": ref_obj,
                    }
                )
    metrics = {
        # Generation / parsing quality.
        "loss": float(np.mean(losses)) if losses else 0.0,
        "rouge_l_char": float(np.mean(rouges)) if rouges else 0.0,
        "json_valid_rate": json_valid / max(1, total),
        "generation_json_valid_rate": generation_json_valid / max(1, total),
        "generation_json_strict_valid_rate": generation_json_strict_valid / max(1, total),
        "generation_json_repair_rate": generation_json_repair_applied / max(1, total),
        "generation_char_len_mean": float(np.mean(generation_char_lengths)) if generation_char_lengths else 0.0,
        "generation_char_len_max": int(max(generation_char_lengths)) if generation_char_lengths else 0,
        "pred_summary_char_len_mean": float(np.mean(pred_summary_char_lengths)) if pred_summary_char_lengths else 0.0,
        "ref_summary_char_len_mean": float(np.mean(ref_summary_char_lengths)) if ref_summary_char_lengths else 0.0,
        "empty_summary_rate": empty_summary_count / max(1, total),
        "whitespace_only_rate": whitespace_only_count / max(1, total),
        "extra_id_rate": extra_id_count / max(1, total),
        "title_prefix_rate": title_prefix_count / max(1, total),
        "natural_title_prefix_stripped_rate": natural_title_prefix_stripped_count / max(1, total),
        "evidence_precision": float(np.mean(evidence_precisions)) if evidence_precisions else 0.0,
        # Element-head metrics (thresholds chosen above).
        "ui_function_precision": ui_function_tp / max(1, ui_function_tp + ui_function_fp),
        "ui_function_recall": ui_function_tp / max(1, ui_function_tp + ui_function_fn),
        "ui_function_f1": (2 * ui_function_tp) / max(1, 2 * ui_function_tp + ui_function_fp + ui_function_fn),
        "ui_function_pred_positive_rate": ui_function_pred_positive / max(1, ui_function_valid_elements),
        "ui_function_ref_positive_rate": ui_function_ref_positive / max(1, ui_function_valid_elements),
        "search_function_precision": search_function_tp / max(1, search_function_tp + search_function_fp),
        "search_function_recall": search_function_tp / max(1, search_function_tp + search_function_fn),
        "search_function_f1": (2 * search_function_tp) / max(1, 2 * search_function_tp + search_function_fp + search_function_fn),
        "search_function_pred_positive_rate": search_function_pred_positive / max(1, ui_function_valid_elements),
        "search_function_ref_positive_rate": search_function_ref_positive / max(1, ui_function_valid_elements),
        # Sample-level search/function metrics on the parsed predictions.
        "search_precision": search_tp / max(1, search_tp + search_fp),
        "search_recall": search_tp / max(1, search_tp + search_fn),
        "search_f1": (2 * search_tp) / max(1, 2 * search_tp + search_fp + search_fn),
        "search_tp": search_tp,
        "search_fp": search_fp,
        "search_fn": search_fn,
        "pred_search_count": pred_search_count,
        "ref_search_count": ref_search_count,
        "pred_function_count": pred_function_count,
        "ref_function_count": ref_function_count,
        "function_overgen_rate": max(0, pred_function_count - ref_function_count) / max(1, ref_function_count),
        "function_count_ratio": pred_function_count / max(1, ref_function_count),
        "search_overgen_rate": max(0, pred_search_count - ref_search_count) / max(1, total),
        "pred_bare_search_count": pred_bare_search_count,
        "bare_search_rate": pred_bare_search_count / max(1, pred_function_count),
        # Echo of the configuration, for logging alongside the scores.
        "max_target_tokens": int(args.max_target_tokens),
        "eval_max_new_tokens": eval_max_new_tokens,
        "num_beams": int(args.num_beams),
        "scheduler_epochs": int(getattr(args, "scheduler_epochs", 0) or 0),
        "lr_scheduler_type": str(getattr(args, "lr_scheduler_type", "linear") or "linear"),
        "context_mode": str(getattr(args, "context_mode", "mean") or "mean"),
        "context_text_format": str(getattr(args, "context_text_format", "rich") or "rich"),
        "context_include_screen_text": bool(getattr(args, "context_include_screen_text", False)),
        "context_screen_text_items": int(getattr(args, "context_screen_text_items", 32) or 32),
        "context_screen_text_dropout_rate": float(getattr(args, "context_screen_text_dropout_rate", 0.0) or 0.0),
        "generation_block_extra_ids": bool(getattr(args, "generation_block_extra_ids", False)),
        "generation_block_title_prefix": bool(getattr(args, "generation_block_title_prefix", False)),
        "generation_no_repeat_ngram_size": int(getattr(args, "generation_no_repeat_ngram_size", 0) or 0),
        "generation_repetition_penalty": float(getattr(args, "generation_repetition_penalty", 1.0) or 1.0),
        "generation_min_new_tokens": int(getattr(args, "generation_min_new_tokens", 0) or 0),
        "generation_force_json_start": bool(getattr(args, "generation_force_json_start", False)),
        "context_summary_repair": context_summary_repair,
        "context_repair_applied_rate": repair_count / max(1, total),
        "canonicalize_targets": bool(getattr(args, "canonicalize_targets", False)),
        "target_schema": str(getattr(args, "target_schema", "zh") or "zh"),
        "decoder_output_mode": "summary" if summary_output_mode else ("natural_text" if natural_output_mode else "json"),
        "task_intent_context": bool(getattr(args, "task_intent_context", False)),
        "drop_bare_search_functions": bool(getattr(args, "drop_bare_search_functions", False)),
        "structured_function_mode": str(getattr(args, "structured_function_mode", "decoder") or "decoder"),
        "structured_function_threshold": float(getattr(args, "structured_function_threshold", 0.5) or 0.5),
        "structured_search_threshold": float(getattr(args, "structured_search_threshold", 0.5) or 0.5),
        "structured_max_functions": int(getattr(args, "structured_max_functions", 12) or 12),
        "structured_strict_search_candidates": bool(getattr(args, "structured_strict_search_candidates", False)),
        "structured_evidence_mode": str(getattr(args, "structured_evidence_mode", "decoder") or "decoder"),
        "structured_evidence_threshold": float(getattr(args, "structured_evidence_threshold", 0.5) or 0.5),
        "structured_max_evidence": int(getattr(args, "structured_max_evidence", 8) or 8),
        "structured_evidence_fallback_top1": bool(getattr(args, "structured_evidence_fallback_top1", True)),
        "ui_function_loss_weight": float(getattr(args, "ui_function_loss_weight", 0.0) or 0.0),
        "search_function_loss_weight": float(getattr(args, "search_function_loss_weight", 0.0) or 0.0),
        "search_function_pos_weight": float(getattr(args, "search_function_pos_weight", 1.0) or 1.0),
        "lr_ui_function_head": float(getattr(args, "lr_ui_function_head", 0.0) or 0.0),
        "max_visual_tokens": int(getattr(args, "max_visual_tokens", 0) or 0),
        "model_variant": str(getattr(args, "model_variant", "")),
        "vision_enabled": not bool(getattr(args, "disable_vision", False)) and str(getattr(args, "model_variant", "")) != "annotation_only",
        "image_size": int(getattr(args, "image_size", 0) or 0),
        "direct_visual_tokens": bool(getattr(args, "direct_visual_tokens", False)),
        "direct_element_tokens": bool(getattr(args, "direct_element_tokens", False)),
        "direct_context_passthrough": bool(getattr(args, "direct_context_passthrough", False)),
        "include_pooled_memory": bool(getattr(args, "include_pooled_memory", True)),
        "native_context_forward": bool(getattr(args, "native_context_forward", False)),
        "disable_vision": bool(getattr(args, "disable_vision", False)),
        "freeze_decoder": bool(getattr(args, "freeze_decoder", False)),
        "init_resize_mismatched_non_decoder": bool(getattr(args, "init_resize_mismatched_non_decoder", False)),
        "grad_clip_strategy": str(getattr(args, "grad_clip_strategy", "global") or "global"),
        "max_grad_norm": float(getattr(args, "max_grad_norm", 1.0) or 0.0),
        "function_signal_to_decoder": bool(getattr(args, "function_signal_to_decoder", False)),
        "function_signal_scale": float(getattr(args, "function_signal_scale", 1.0) or 1.0),
        "search_signal_to_decoder": bool(getattr(args, "search_signal_to_decoder", False)),
        "search_signal_scale": float(getattr(args, "search_signal_scale", 1.0) or 1.0),
        "visual_memory_scale": float(getattr(args, "visual_memory_scale", 1.0) or 1.0),
        "element_memory_scale": float(getattr(args, "element_memory_scale", 1.0) or 1.0),
        "pooled_memory_scale": float(getattr(args, "pooled_memory_scale", 1.0) or 1.0),
        "decoder_gradient_checkpointing": bool(getattr(args, "decoder_gradient_checkpointing", False)),
        "vision_gradient_checkpointing": bool(getattr(args, "vision_gradient_checkpointing", False)),
        "cuda_memory_fraction": float(getattr(args, "cuda_memory_fraction", 0.0) or 0.0),
        "max_train_samples": int(getattr(args, "max_train_samples", 0) or 0),
    }
    # In summary mode the decoder outputs only the summary, so:
    # - json_valid_rate is always 1.0, because `ok` is forced to True in the
    #   summary branch above. It would hand rich_quality_score a free 0.25 and
    #   mask the real ROUGE.
    # - evidence_precision comes from the auxiliary head (via
    #   apply_structured_evidence_predictions), not from the decoder output being
    #   selected. Mixing it in would bias checkpoint selection toward the head.
    # Hence in summary mode rich_quality_score collapses to pure rouge_l_char, and
    # summary_quality_score is exposed explicitly. JSON mode keeps the original
    # weighting, since there all three terms come from the decoder output itself.
    metrics["summary_quality_score"] = float(metrics["rouge_l_char"])
    # Diagnostic outputs:
    # - summary_template_match_rate: the higher it is, the more predictions are
    #   just the "This is a X screen ... the current task context is ..." template.
    #   A value near 1 together with a high ROUGE should be read as "the model did
    #   not learn screen understanding, it only copies context tokens".
    # - summary_app_recall_rate / summary_instruction_recall_rate: how often the
    #   app / instruction text appears verbatim in the prediction. Both near 1
    #   point to the same copying behavior.
    metrics["summary_template_match_rate"] = template_exact_match / max(1, total)
    metrics["summary_app_recall_rate"] = template_app_in_pred / max(1, total)
    metrics["summary_instruction_recall_rate"] = template_instruction_in_pred / max(1, total)
    if summary_output_mode:
        metrics["rich_quality_score"] = float(metrics["rouge_l_char"])
    else:
        metrics["rich_quality_score"] = (
            0.45 * metrics["rouge_l_char"]
            + 0.25 * metrics["json_valid_rate"]
            + 0.30 * metrics["evidence_precision"]
        )
    metrics["rich_function_score"] = (
        0.70 * metrics["rich_quality_score"]
        + 0.20 * metrics["search_f1"]
        + 0.10 * metrics["ui_function_f1"]
    )
    # Capped penalties for over-generating functions / searches and for bare
    # search functions; the grounded score is clamped at 0.
    overgen_penalty = min(0.12, 0.04 * float(metrics["function_overgen_rate"]))
    search_penalty = min(0.06, 0.02 * float(metrics["search_overgen_rate"]))
    bare_search_penalty = min(0.03, 0.03 * float(metrics["bare_search_rate"]))
    metrics["grounded_quality_score"] = max(
        0.0,
        float(metrics["summary_quality_score"]) + 0.05 * float(metrics["evidence_precision"]) - overgen_penalty - search_penalty - bare_search_penalty,
    )
    metrics["grounded_overgen_penalty"] = overgen_penalty + search_penalty + bare_search_penalty
    if is_main(rank):
        write_json(Path(args.output_dir) / "val_preview.json", {"metrics": metrics, "predictions": predictions})
    raw_model.train()
    return metrics
def selection_metric_name(args: argparse.Namespace) -> str:
    """Return the name of the metric used to pick the best checkpoint.

    Reads ``args.model_selection_metric``; when the attribute is absent or
    falsy (``None``, ``""``), ``"rich_quality_score"`` is used instead.
    """
    default_name = "rich_quality_score"
    configured = getattr(args, "model_selection_metric", default_name)
    return str(configured or default_name)
def selection_metric_value(metrics: Dict[str, Any], args: argparse.Namespace) -> float:
    """Return the checkpoint-selection score from an evaluation ``metrics`` dict.

    The key is chosen by ``selection_metric_name(args)``. If that key is missing
    or maps to ``None``, the value of ``"rich_quality_score"`` is used instead,
    defaulting to ``0.0`` when that key is missing as well.
    """
    chosen_key = selection_metric_name(args)
    raw_score = metrics.get(chosen_key)
    return float(metrics.get("rich_quality_score", 0.0) if raw_score is None else raw_score)
def metric_is_better(current: float, best: float, args: argparse.Namespace) -> bool:
    """Return True when ``current`` beats ``best`` by more than the configured margin.

    ``args.model_selection_mode`` set to ``"min"`` (case-insensitive) means lower
    is better; any other value (or absence) means higher is better.
    ``args.early_stopping_min_delta`` (default 0.0) is the strict margin the new
    score must clear.
    """
    margin = float(getattr(args, "early_stopping_min_delta", 0.0) or 0.0)
    minimize = str(getattr(args, "model_selection_mode", "max") or "max").lower() == "min"
    if minimize:
        return current < best - margin
    return current > best + margin
def train(args: argparse.Namespace) -> None:
    """Train the rich screenshot summarizer end to end.

    Steps: seed and (optionally) initialise distributed state; apply CUDA
    options; build datasets, diagnostics, tokenizer, image processor, collators
    and loaders; construct ``RichGroundedModel`` (optionally warm-started from
    ``args.init_checkpoint``); then train for ``args.epochs`` epochs with
    gradient accumulation, AMP, skipping of non-finite losses/gradients,
    periodic step evaluation, per-epoch evaluation, checkpointing and optional
    early stopping.

    Side effects (main rank only): writes ``config.json``,
    ``data_diagnostics.json``, ``metrics.jsonl`` and checkpoints under
    ``args.output_dir``.

    Raises:
        ValueError: if ``cuda_memory_fraction`` is outside [0, 1] or
            ``amp_dtype`` is not one of auto/fp32/fp16/bf16. The dataset/token
            validation helpers called here may raise as well.
    """
    set_seed(args.seed)
    distributed, rank, world_size, local_rank = init_distributed()
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        cuda_memory_fraction = float(getattr(args, "cuda_memory_fraction", 0.0) or 0.0)
        if cuda_memory_fraction < 0.0 or cuda_memory_fraction > 1.0:
            raise ValueError("cuda_memory_fraction must be in [0, 1]. Use 0 to disable the limit.")
        if cuda_memory_fraction > 0.0:
            torch.cuda.set_per_process_memory_fraction(cuda_memory_fraction, device=device)
            if is_main(rank):
                total_gb = torch.cuda.get_device_properties(device).total_memory / 1024**3
                print(f"CUDA memory fraction limit: {cuda_memory_fraction:.3f} (~{total_gb * cuda_memory_fraction:.2f}GB of {total_gb:.2f}GB)")
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        if hasattr(torch, "set_float32_matmul_precision"):
            torch.set_float32_matmul_precision("high")

    output_dir = Path(args.output_dir)
    if is_main(rank):
        output_dir.mkdir(parents=True, exist_ok=True)
        write_json(output_dir / "config.json", vars(args))

    # ---- Data loading and pre-flight diagnostics ----
    max_train_samples = int(getattr(args, "max_train_samples", 0) or 0)
    train_dataset = RichScreenshotDataset(args.train_file, max_samples=max_train_samples, sample_seed=args.seed if max_train_samples else None)
    valid_dataset = RichScreenshotDataset(args.valid_file, max_samples=args.max_valid_samples)
    train_data_diagnostics = dataset_diagnostics(train_dataset.rows)
    valid_data_diagnostics = dataset_diagnostics(valid_dataset.rows)
    validate_dataset_for_training("train", train_data_diagnostics, args)
    validate_dataset_for_training("valid", valid_data_diagnostics, args)
    tokenizer = load_seq2seq_tokenizer(args.decoder_model)
    train_token_diagnostics = tokenizer_diagnostics(train_dataset.rows, tokenizer, args)
    valid_token_diagnostics = tokenizer_diagnostics(valid_dataset.rows, tokenizer, args)
    validate_token_lengths("train", train_token_diagnostics, args)
    validate_token_lengths("valid", valid_token_diagnostics, args)
    if is_main(rank):
        diagnostics_payload = {
            "train_file": args.train_file,
            "valid_file": args.valid_file,
            "strict_data_checks": bool(getattr(args, "strict_data_checks", True)),
            "train": train_data_diagnostics,
            "valid": valid_data_diagnostics,
            "tokenizer": args.decoder_model,
            "max_target_tokens": int(getattr(args, "max_target_tokens", 0) or 0),
            "eval_max_new_tokens": int(getattr(args, "eval_max_new_tokens", 0) or 0),
            "max_target_truncation_rate": float(getattr(args, "max_target_truncation_rate", 0.01) or 0.0),
            "train_token_lengths": train_token_diagnostics,
            "valid_token_lengths": valid_token_diagnostics,
        }
        write_json(output_dir / "data_diagnostics.json", diagnostics_payload)
        print("Data diagnostics:")
        print(json.dumps(diagnostics_payload, ensure_ascii=False, indent=2))

    image_processor = AutoImageProcessor.from_pretrained(args.vision_model)
    train_collator = RichCollator(tokenizer, image_processor, args, is_training=True)
    valid_collator = RichCollator(tokenizer, image_processor, args, is_training=False)
    train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=args.seed) if distributed else None
    valid_sampler = DistributedSampler(valid_dataset, shuffle=False) if distributed else None
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=train_sampler is None,
        sampler=train_sampler,
        collate_fn=train_collator,
        num_workers=args.num_workers,
        pin_memory=torch.cuda.is_available(),
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=args.eval_batch_size if getattr(args, "eval_batch_size", 0) else max(1, args.batch_size // 2),
        shuffle=False,
        sampler=valid_sampler,
        collate_fn=valid_collator,
        num_workers=args.num_workers,
        pin_memory=torch.cuda.is_available(),
    )

    # ---- Model construction and optional warm start ----
    model = RichGroundedModel(args)
    init_checkpoint = str(getattr(args, "init_checkpoint", "") or "")
    if init_checkpoint:
        checkpoint_decoder_model = ""
        init_config_path = Path(init_checkpoint) / "rich_config.json"
        init_config: Dict[str, Any] = {}
        if init_config_path.exists():
            init_config = json.loads(init_config_path.read_text(encoding="utf-8"))
            checkpoint_decoder_model = str(init_config.get("decoder_model", "") or "")
        # Submodules whose weights may legitimately be absent from the checkpoint:
        # a different decoder backbone, or a checkpoint trained with vision disabled.
        allow_missing_prefixes_list: List[str] = []
        if checkpoint_decoder_model and checkpoint_decoder_model != str(getattr(args, "decoder_model", "") or ""):
            allow_missing_prefixes_list.append("decoder.")
        if bool(init_config.get("disable_vision", False)) and not bool(getattr(args, "disable_vision", False)):
            allow_missing_prefixes_list.append("vision.")
        state = torch.load(Path(init_checkpoint) / "pytorch_model.bin", map_location="cpu")
        missing_keys, _, skipped_incompatible, resized_incompatible = load_compatible_model_state(
            model,
            state,
            allow_missing_prefixes=tuple(allow_missing_prefixes_list),
            resize_mismatched_non_decoder=bool(getattr(args, "init_resize_mismatched_non_decoder", False)),
        )
        if is_main(rank):
            missing_preview = missing_keys[:12]
            skipped_preview = skipped_incompatible[:12]
            resized_preview = resized_incompatible[:12]
            print(
                f"Loaded init checkpoint: {init_checkpoint}; "
                f"missing_count={len(missing_keys)} preview={missing_preview}; "
                f"skipped_incompatible_count={len(skipped_incompatible)} preview={skipped_preview}; "
                f"resized_incompatible_count={len(resized_incompatible)} preview={resized_preview}"
            )
    model = model.to(device)
    trainable_stats = trainable_parameter_stats(model)
    if is_main(rank):
        print(
            "Trainable parameters: "
            f"total={trainable_stats['trainable_params_total']:,}, "
            f"vision={trainable_stats['trainable_params_vision']:,}, "
            f"decoder={trainable_stats['trainable_params_decoder']:,}, "
            f"other={trainable_stats['trainable_params_other']:,}",
            flush=True,
        )
    if distributed:
        model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=True)
    elif bool(getattr(args, "data_parallel", False)) and device.type == "cuda" and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        if is_main(rank):
            print(f"Using nn.DataParallel on {torch.cuda.device_count()} CUDA devices.", flush=True)

    # ---- Optimizer, LR schedule and mixed precision ----
    optimizer = build_optimizer(unwrap_parallel_model(model), args)
    scheduler_epochs = int(getattr(args, "scheduler_epochs", 0) or args.epochs)
    scheduler_epochs = max(int(args.epochs), scheduler_epochs)
    # Number of micro-batches per epoch on this rank. The training loop also
    # steps on a trailing partial accumulation window, which is what the ceil()
    # below accounts for.
    num_micro_batches = len(train_loader)
    total_update_steps = math.ceil(num_micro_batches / args.grad_accum) * scheduler_epochs
    scheduler = get_lr_schedule(optimizer, total_update_steps, args.warmup_ratio, getattr(args, "lr_scheduler_type", "linear"))
    amp_dtype_name = str(getattr(args, "amp_dtype", "auto") or "auto").lower()
    if amp_dtype_name == "auto":
        amp_dtype_name = "fp16" if bool(getattr(args, "fp16", False)) else "fp32"
    if amp_dtype_name not in {"fp32", "fp16", "bf16"}:
        raise ValueError("amp_dtype must be one of: auto, fp32, fp16, bf16")
    amp_enabled = device.type == "cuda" and amp_dtype_name in {"fp16", "bf16"}
    amp_dtype = torch.float16 if amp_dtype_name == "fp16" else torch.bfloat16
    # Loss scaling is only needed for fp16; bf16 has fp32's exponent range.
    scaler = torch.amp.GradScaler("cuda", enabled=amp_enabled and amp_dtype_name == "fp16")

    # ---- Training state ----
    model_selection_mode = str(getattr(args, "model_selection_mode", "max") or "max").lower()
    best_score = math.inf if model_selection_mode == "min" else -math.inf
    epochs_without_improvement = 0
    # Set when a mid-epoch step evaluation beats best_score, so that epoch still
    # counts as improving for early stopping even if its end-of-epoch eval does not.
    improved_since_epoch_eval = False
    patience = int(getattr(args, "early_stopping_patience", 0) or 0)
    empty_cache_steps = int(getattr(args, "cuda_empty_cache_steps", 0) or 0)
    global_step = 0
    train_start_time = time.time()
    recent_losses: deque[float] = deque(maxlen=100)
    optimizer.zero_grad(set_to_none=True)
    if device.type == "cuda":
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
    # True once any micro-batch in the current accumulation window had a non-finite
    # loss. The whole window's update is then discarded.
    skip_optimizer_window = False
    skipped_nonfinite_windows = 0

    for epoch in range(args.epochs):
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)
        progress = tqdm(train_loader, desc=f"epoch {epoch + 1}/{args.epochs}", disable=not is_main(rank))
        for step, batch in enumerate(progress, start=1):
            # Close the window every grad_accum micro-batches AND on the final
            # micro-batch of the epoch. Otherwise a trailing partial window's
            # gradients (and skip flag) would leak into the next epoch. A partial
            # window is still divided by grad_accum, so its update is
            # proportionally smaller.
            is_update_step = step % args.grad_accum == 0 or step == num_micro_batches
            batch = move_batch(batch, device)
            with torch.amp.autocast("cuda", enabled=amp_enabled, dtype=amp_dtype):
                out = model(**{k: v for k, v in batch.items() if k != "rows"})
            out = reduce_parallel_losses(out)
            loss = out["loss"] / args.grad_accum
            loss_is_finite = bool(torch.isfinite(loss.detach()).all().item())
            if distributed:
                # Every rank must make the same skip decision. If only some ranks
                # call backward(), they block forever in DDP's gradient all-reduce.
                finite_flag = torch.tensor(1.0 if loss_is_finite else 0.0, device=device)
                dist.all_reduce(finite_flag, op=dist.ReduceOp.MIN)
                loss_is_finite = bool(finite_flag.item() > 0.0)
            if not loss_is_finite:
                skip_optimizer_window = True
                optimizer.zero_grad(set_to_none=True)
                if is_main(rank):
                    append_jsonl(
                        output_dir / "metrics.jsonl",
                        {
                            "event": "skip_nonfinite_microbatch",
                            "epoch": epoch + 1,
                            "micro_step": step,
                            "loss": float(out["loss"].detach().cpu()),
                            "generation_loss": float(out["generation_loss"].detach().cpu()),
                            "evidence_loss": float(out["evidence_loss"].detach().cpu()),
                            "ui_function_loss": float(out["ui_function_loss"].detach().cpu()),
                            "search_function_loss": float(out["search_function_loss"].detach().cpu()),
                            "section_loss": float(out["section_loss"].detach().cpu()),
                            "numeric_loss": float(out["numeric_loss"].detach().cpu()),
                        },
                    )
            elif not skip_optimizer_window:
                scaler.scale(loss).backward()

            step_log = None
            # Rank-independent marker that an optimizer step happened on this
            # micro-batch. step_log is only ever populated on the main rank, so
            # it must not gate collective work such as evaluate().
            optimizer_stepped = False
            if is_update_step:
                if skip_optimizer_window:
                    skipped_nonfinite_windows += 1
                    optimizer.zero_grad(set_to_none=True)
                    skip_optimizer_window = False
                    if is_main(rank):
                        append_jsonl(
                            output_dir / "metrics.jsonl",
                            {
                                "event": "skip_nonfinite_optimizer_window",
                                "epoch": epoch + 1,
                                "micro_step": step,
                                "skipped_nonfinite_windows": skipped_nonfinite_windows,
                                "lr": max(scheduler.get_last_lr()),
                            },
                        )
                else:
                    # LRs actually applied by this step (read before scheduler.step()).
                    lrs_used = list(scheduler.get_last_lr())
                    scaler.unscale_(optimizer)
                    grad_norm = clip_gradients(model, optimizer, args)
                    grad_is_finite = bool(torch.isfinite(grad_norm.detach()).all().item()) if torch.is_tensor(grad_norm) else math.isfinite(float(grad_norm))
                    if not grad_is_finite:
                        skipped_nonfinite_windows += 1
                        optimizer.zero_grad(set_to_none=True)
                        if scaler.is_enabled():
                            # Let the scaler register the overflow and back off its scale.
                            scaler.update()
                        if is_main(rank):
                            append_jsonl(
                                output_dir / "metrics.jsonl",
                                {
                                    "event": "skip_nonfinite_grad",
                                    "epoch": epoch + 1,
                                    "micro_step": step,
                                    "grad_norm": float(grad_norm.detach().cpu()) if torch.is_tensor(grad_norm) else float(grad_norm),
                                    "skipped_nonfinite_windows": skipped_nonfinite_windows,
                                    "lr": max(scheduler.get_last_lr()),
                                },
                            )
                    else:
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad(set_to_none=True)
                        scheduler.step()
                        global_step += 1
                        optimizer_stepped = True
                        if is_main(rank):
                            loss_value = float(out["loss"].detach().cpu())
                            recent_losses.append(loss_value)
                            label_counts = (batch["labels"] != -100).sum(dim=1).detach().float().cpu()
                            context_counts = batch["context_attention_mask"].sum(dim=1).detach().float().cpu()
                            element_counts = batch["element_mask"].sum(dim=1).detach().float().cpu()
                            loss_window = list(recent_losses)
                            loss_window_20 = loss_window[-20:]
                            elapsed_seconds = max(time.time() - train_start_time, 1e-6)
                            next_lrs = list(scheduler.get_last_lr())
                            step_log = {
                                "step": global_step,
                                "epoch": epoch + 1,
                                "loss": loss_value,
                                "loss_ma20": float(np.mean(loss_window_20)) if loss_window_20 else loss_value,
                                "loss_ma100": float(np.mean(loss_window)) if loss_window else loss_value,
                                "generation_loss": float(out["generation_loss"].cpu()),
                                "evidence_loss": float(out["evidence_loss"].cpu()),
                                "ui_function_loss": float(out["ui_function_loss"].cpu()),
                                "search_function_loss": float(out["search_function_loss"].cpu()),
                                "section_loss": float(out["section_loss"].cpu()),
                                "numeric_loss": float(out["numeric_loss"].cpu()),
                                "lr": max(lrs_used) if lrs_used else 0.0,
                                "lr_other": lrs_used[0] if len(lrs_used) > 0 else 0.0,
                                "lr_vision": lrs_used[1] if len(lrs_used) > 1 else 0.0,
                                "lr_decoder": lrs_used[2] if len(lrs_used) > 2 else 0.0,
                                "lr_ui_function_head": lrs_used[3] if len(lrs_used) > 3 else 0.0,
                                "lr_next": max(next_lrs) if next_lrs else 0.0,
                                "lr_scheduler_type": str(getattr(args, "lr_scheduler_type", "linear") or "linear"),
                                "grad_norm": float(grad_norm.detach().cpu()) if torch.is_tensor(grad_norm) else float(grad_norm),
                                "skipped_nonfinite_windows": skipped_nonfinite_windows,
                                "optimizer_steps_total": total_update_steps,
                                "train_progress_pct": 100.0 * global_step / max(1, total_update_steps),
                                "elapsed_seconds": elapsed_seconds,
                                "optimizer_steps_per_sec": global_step / elapsed_seconds,
                                "target_tokens_mean": float(label_counts.mean().item()) if label_counts.numel() else 0.0,
                                "target_tokens_max": int(label_counts.max().item()) if label_counts.numel() else 0,
                                "target_at_max_rate": float((label_counts >= int(args.max_target_tokens)).float().mean().item()) if label_counts.numel() else 0.0,
                                "context_tokens_mean": float(context_counts.mean().item()) if context_counts.numel() else 0.0,
                                "context_tokens_max": int(context_counts.max().item()) if context_counts.numel() else 0,
                                "context_at_max_rate": float((context_counts >= int(args.max_context_tokens)).float().mean().item()) if context_counts.numel() else 0.0,
                                "element_count_mean": float(element_counts.mean().item()) if element_counts.numel() else 0.0,
                                "element_count_max": int(element_counts.max().item()) if element_counts.numel() else 0,
                                "model_variant": str(getattr(args, "model_variant", "")),
                                "vision_enabled": not bool(getattr(args, "disable_vision", False)) and str(getattr(args, "model_variant", "")) != "annotation_only",
                                "native_context_forward": bool(getattr(args, "native_context_forward", False)),
                                "freeze_decoder": bool(getattr(args, "freeze_decoder", False)),
                                "image_size": int(getattr(args, "image_size", 0) or 0),
                                "image_crops": int(batch["pixel_values"].shape[1]) if torch.is_tensor(batch.get("pixel_values")) and batch["pixel_values"].ndim >= 5 else 0,
                                "max_visual_tokens": int(getattr(args, "max_visual_tokens", 0) or 0),
                                "direct_visual_tokens": bool(getattr(args, "direct_visual_tokens", False)),
                                "direct_element_tokens": bool(getattr(args, "direct_element_tokens", False)),
                                "visual_memory_scale": float(getattr(args, "visual_memory_scale", 1.0) or 1.0),
                                "element_memory_scale": float(getattr(args, "element_memory_scale", 1.0) or 1.0),
                                "pooled_memory_scale": float(getattr(args, "pooled_memory_scale", 1.0) or 1.0),
                                "cuda_memory_fraction": float(getattr(args, "cuda_memory_fraction", 0.0) or 0.0),
                                "data_parallel": isinstance(model, nn.DataParallel),
                                "distributed": bool(distributed),
                                **trainable_stats,
                            }
            # Drop references to this micro-batch's tensors before the next one.
            del loss
            del out
            del batch
            if (
                device.type == "cuda"
                and empty_cache_steps > 0
                and is_update_step
                and global_step % empty_cache_steps == 0
            ):
                gc.collect()
                torch.cuda.empty_cache()
            if step_log is not None and is_main(rank):
                mem = {}
                if torch.cuda.is_available():
                    mem = {
                        "gpu_allocated_gb": round(torch.cuda.memory_allocated(device) / 1024**3, 3),
                        "gpu_reserved_gb": round(torch.cuda.memory_reserved(device) / 1024**3, 3),
                        "gpu_peak_allocated_gb": round(torch.cuda.max_memory_allocated(device) / 1024**3, 3),
                        "gpu_peak_reserved_gb": round(torch.cuda.max_memory_reserved(device) / 1024**3, 3),
                    }
                log = {
                    **step_log,
                    **mem,
                }
                append_jsonl(output_dir / "metrics.jsonl", log)
                if device.type == "cuda":
                    # Peak stats in the next log entry then cover only the next window.
                    torch.cuda.reset_peak_memory_stats(device)
                progress.set_postfix(loss=f"{log['loss']:.3f}", ma20=f"{log['loss_ma20']:.3f}", pct=f"{log['train_progress_pct']:.1f}%")
            if (
                optimizer_stepped
                and args.save_checkpoints
                and args.save_every_steps
                and global_step > 0
                and global_step % args.save_every_steps == 0
                and is_main(rank)
            ):
                save_checkpoint(output_dir, "checkpoint-last", model, tokenizer, image_processor, args, {"step": global_step})
            # All ranks must enter evaluate() together, so gate on optimizer_stepped,
            # which every rank agrees on, rather than on the main-rank-only step_log.
            if optimizer_stepped and args.eval_every_steps and global_step > 0 and global_step % args.eval_every_steps == 0:
                if is_main(rank) and args.save_checkpoints:
                    save_checkpoint(output_dir, "checkpoint-last", model, tokenizer, image_processor, args, {"step": global_step, "pre_eval": True})
                metrics = evaluate(model, valid_loader, tokenizer, device, args, rank=rank)
                current_score = selection_metric_value(metrics, args)
                metrics["selection_metric"] = selection_metric_name(args)
                metrics["selection_score"] = current_score
                if metric_is_better(current_score, best_score, args):
                    best_score = current_score
                    improved_since_epoch_eval = True
                    if is_main(rank) and args.save_checkpoints:
                        save_checkpoint(output_dir, "checkpoint-best", model, tokenizer, image_processor, args, metrics)
                if device.type == "cuda" and empty_cache_steps > 0:
                    gc.collect()
                    torch.cuda.empty_cache()
                if is_main(rank):
                    append_jsonl(output_dir / "metrics.jsonl", {"step_eval": global_step, **metrics})

        # ---- End-of-epoch evaluation, checkpointing and early stopping ----
        if is_main(rank) and args.save_checkpoints:
            save_checkpoint(output_dir, "checkpoint-last", model, tokenizer, image_processor, args, {"epoch": epoch + 1, "step": global_step, "pre_epoch_eval": True})
        metrics = evaluate(model, valid_loader, tokenizer, device, args, rank=rank)
        current_score = selection_metric_value(metrics, args)
        metrics["selection_metric"] = selection_metric_name(args)
        metrics["selection_score"] = current_score
        improved = metric_is_better(current_score, best_score, args)
        if improved:
            best_score = current_score
        # A new best found by a mid-epoch step eval also counts as progress for this epoch.
        if improved or improved_since_epoch_eval:
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        improved_since_epoch_eval = False
        if is_main(rank):
            append_jsonl(output_dir / "metrics.jsonl", {"epoch_eval": epoch + 1, **metrics})
            if args.save_checkpoints:
                save_checkpoint(output_dir, "checkpoint-last", model, tokenizer, image_processor, args, metrics)
            if args.save_checkpoints and improved:
                save_checkpoint(output_dir, "checkpoint-best", model, tokenizer, image_processor, args, metrics)
            if patience > 0:
                print(
                    f"selection_metric={metrics['selection_metric']} score={current_score:.6f} "
                    f"best={best_score:.6f} no_improve={epochs_without_improvement}/{patience}"
                )
        if patience > 0 and epochs_without_improvement >= patience:
            if is_main(rank):
                print(f"Early stopping at epoch {epoch + 1}: no improvement for {patience} evals.")
            break
    if distributed:
        dist.destroy_process_group()
def parse_bool_flag(value: str) -> bool:
    """Parse a command-line boolean flag value.

    Accepts, case-insensitively and ignoring surrounding whitespace,
    ``1/true/yes/y/t/on`` as True and ``0/false/no/n/f/off`` (or an empty
    string) as False.

    Raises:
        argparse.ArgumentTypeError: for any other value. A typo such as
            ``--fp16 ture`` then fails loudly instead of silently becoming False.
    """
    text = str(value).strip().lower()
    if text in {"1", "true", "yes", "y", "t", "on"}:
        return True
    if text in {"0", "false", "no", "n", "f", "off", ""}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value (true/false), got {value!r}")


def parse_args() -> argparse.Namespace:
    """Build the CLI from ``DEFAULT_CONFIG``, parse ``sys.argv`` and validate/normalize choices.

    Every ``DEFAULT_CONFIG`` key becomes a ``--key`` option whose type follows
    the default value's type. Booleans go through ``parse_bool_flag``.
    Categorical options are lower-cased and checked against their allowed sets.

    Raises:
        ValueError: if a categorical option has an unsupported value.
    """
    parser = argparse.ArgumentParser(description="Train rich CMGUI screenshot summarization models.")
    for key, value in DEFAULT_CONFIG.items():
        arg_type = type(value)
        if isinstance(value, bool):
            parser.add_argument(f"--{key}", type=parse_bool_flag, default=value)
        else:
            parser.add_argument(f"--{key}", type=arg_type, default=value)
    parser.add_argument("--bottleneck_queries", type=int, default=64)
    args = parser.parse_args()
    args.decoder_model = normalize_model_reference(args.decoder_model)
    args.vision_model = normalize_model_reference(args.vision_model)
    if args.model_variant not in {"annotation_only", "image_only", "late_fusion", "full"}:
        raise ValueError("model_variant must be one of annotation_only, image_only, late_fusion, full")
    args.context_mode = str(getattr(args, "context_mode", "mean") or "mean").lower()
    if args.context_mode not in {"mean", "tokens", "tokens_direct", "tokens_encoder", "tokens_direct_encoder"}:
        raise ValueError("context_mode must be one of: mean, tokens, tokens_direct, tokens_encoder, tokens_direct_encoder")
    args.lr_scheduler_type = str(getattr(args, "lr_scheduler_type", "linear") or "linear").lower()
    if args.lr_scheduler_type not in {"linear", "cosine"}:
        raise ValueError("lr_scheduler_type must be one of: linear, cosine")
    args.target_schema = str(getattr(args, "target_schema", "zh") or "zh").lower()
    # Accepted schema spellings, including aliases of the canonical names in the error message.
    if args.target_schema not in {
        "zh",
        "alias",
        "aliases",
        "en",
        "english",
        "summary",
        "summary_zh",
        "summary-only",
        "summary_only",
        "natural_zh",
        "rich_text_zh",
        "zh_text",
        "text_zh",
        "summary_visible_zh",
        "natural_summary_visible_zh",
    }:
        raise ValueError("target_schema must be one of: zh, aliases, summary_zh, natural_zh, summary_visible_zh")
    args.grad_clip_strategy = str(getattr(args, "grad_clip_strategy", "global") or "global").lower()
    if args.grad_clip_strategy not in {"global", "per_group"}:
        raise ValueError("grad_clip_strategy must be one of: global, per_group")
    args.structured_function_mode = str(getattr(args, "structured_function_mode", "decoder") or "decoder").lower()
    if args.structured_function_mode not in {"decoder", "heads"}:
        raise ValueError("structured_function_mode must be one of: decoder, heads")
    args.structured_evidence_mode = str(getattr(args, "structured_evidence_mode", "decoder") or "decoder").lower()
    if args.structured_evidence_mode not in {"decoder", "heads"}:
        raise ValueError("structured_evidence_mode must be one of: decoder, heads")
    return args
if __name__ == "__main__":
    # Script entry point: parse and validate CLI options, then run training.
    cli_args = parse_args()
    train(cli_args)