#!/usr/bin/env python
"""Train rich Chinese screenshot summarization models on CMGUI-style data."""

# NOTE(review): this section was recovered from a copy whose newlines had been
# collapsed into spaces; statement nesting below was reconstructed from syntax
# and should be diffed against version control before relying on it.

from __future__ import annotations

import argparse
import gc
import json
import math
import os
import random
import re
import shutil
import time
from collections import deque
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.nn.parallel import DistributedDataParallel
from torch.utils.checkpoint import checkpoint
from torch.utils.data import DataLoader, Dataset, DistributedSampler
from tqdm import tqdm
from transformers import Adafactor, AutoImageProcessor, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, T5Tokenizer
from transformers.modeling_outputs import BaseModelOutput


# Default run configuration; keys mirror the attributes read from `args` elsewhere.
DEFAULT_CONFIG = {
    # Data files. Full training uses train_rich/valid_rich; get the pipeline
    # working on the smoke files first.
    "train_file": "data/rich_cmgui/processed/smoke_train_rich.jsonl",
    "valid_file": "data/rich_cmgui/processed/smoke_valid_rich.jsonl",
    "output_dir": "runs/rich_cmgui_20260502/rich_grounded_siglip2_mt5_v1",
    "init_checkpoint": "",
    # Model. On 2x V100 32GB the default is mT5-base; switch to
    # google/mt5-small if GPU memory is tight.
    "model_variant": "full",
    "vision_model": "models/siglip2-base-patch16-224",
    "decoder_model": "google/mt5-base",
    # Input size. 384 plus vertical crops suits small Chinese text better; if
    # SigLIP2 position-embedding interpolation fails, drop to 224 first.
    "image_size": 384,
    "num_vertical_crops": 3,
    # 0 keeps all visual patch tokens. Set a cap for high-resolution/crop runs
    # to borrow Pix2Struct-style richer screenshots without quadratic fusion blowup.
    "max_visual_tokens": 0,
    "max_elements": 80,
    "max_element_tokens": 16,
    "max_context_tokens": 64,
    "context_text_format": "rich",
    "context_include_screen_text": False,
    "context_screen_text_items": 32,
    "context_screen_text_dropout_rate": 0.0,
    # mean preserves old checkpoints; direct modes expose task/app tokens directly to cross-attention.
    "context_mode": "mean",
    # max_target_tokens only truncates training labels; eval_max_new_tokens
    # only controls generation length during validation.
    "max_target_tokens": 384,
    "eval_max_new_tokens": 384,
    # Training. On two V100 32GB: torchrun --nproc_per_node=2 train_rich.py
    "batch_size": 4,
    "eval_batch_size": 0,
    "grad_accum": 8,
    "epochs": 6,
    # 0 means use epochs. Set this higher for short early-stop runs so the LR
    # schedule matches the longer run whose early checkpoint is being reproduced.
    "scheduler_epochs": 0,
    "lr_new": 1e-4,
    "lr_fusion": 5e-5,
    "lr_decoder": 1e-5,
    "lr_ui_function_head": 0.0,
    "weight_decay": 0.01,
    "optimizer_name": "adamw",
    "lr_scheduler_type": "linear",
    "warmup_ratio": 0.05,
    "fp16": True,
    "amp_dtype": "auto",
    "generation_loss_chunk_size": 32,
    "activation_checkpointing": False,
    "cuda_empty_cache_steps": 0,
    "cuda_memory_fraction": 0.0,
    "decoder_gradient_checkpointing": False,
    "vision_gradient_checkpointing": False,
    "freeze_decoder": False,
    "freeze_vision": True,
    "unfreeze_vision_last_ratio": 0.3,
    # Loss weights. Generation is the primary objective; the other losses
    # constrain evidence grounding and output structure.
    "evidence_loss_weight": 0.2,
    "section_loss_weight": 0.1,
    "numeric_loss_weight": 0.1,
    "ui_function_loss_weight": 0.0,
    "search_function_loss_weight": 0.0,
    "search_function_pos_weight": 1.0,
    # Checkpointing and evaluation.
    "save_every_steps": 1000,
    "save_checkpoints": True,
    "eval_every_steps": 0,
    "model_selection_metric": "rich_quality_score",
    "model_selection_mode": "max",
    "early_stopping_patience": 0,
    "early_stopping_min_delta": 0.0,
    "max_train_samples": 0,
    "max_valid_samples": 800,
    "num_beams": 4,
    "generation_no_repeat_ngram_size": 0,
    "generation_repetition_penalty": 1.0,
    "generation_min_new_tokens": 0,
    "generation_block_extra_ids": False,
    "generation_block_title_prefix": False,
    "generation_force_json_start": False,
    "context_summary_repair": False,
    "canonicalize_targets": False,
    "target_schema": "zh",
    "task_intent_context": False,
    "drop_bare_search_functions": False,
    "structured_function_mode": "decoder",
    "structured_function_threshold": 0.5,
    "structured_search_threshold": 0.5,
    "structured_max_functions": 12,
    "structured_strict_search_candidates": False,
    "structured_evidence_mode": "decoder",
    "structured_evidence_threshold": 0.5,
    "structured_max_evidence": 8,
    "structured_evidence_fallback_top1": True,
    # Vision-memory options. Defaults preserve old checkpoints/experiments.
    "direct_visual_tokens": False,
    "direct_element_tokens": False,
    "direct_context_passthrough": False,
    "include_pooled_memory": True,
    "native_context_forward": False,
    "disable_vision": False,
    "init_resize_mismatched_non_decoder": False,
    "grad_clip_strategy": "global",
    "max_grad_norm": 1.0,
    "function_signal_to_decoder": False,
    "function_signal_scale": 1.0,
    "search_signal_to_decoder": False,
    "search_signal_scale": 1.0,
    "visual_memory_scale": 1.0,
    "element_memory_scale": 1.0,
    "pooled_memory_scale": 1.0,
    "decoder_memory_scale": 1.0,
    "data_parallel": False,
    "strict_data_checks": True,
    "max_target_truncation_rate": 0.01,
    "seed": 20260502,
    "num_workers": 4,
}

# Names of the list-valued target sections (English schema keys).
SECTION_NAMES = ["visible_text", "interaction_data", "ui_functions", "key_ui_clues"]


def read_jsonl(path: Path) -> Iterable[Dict[str, Any]]:
    """Lazily yield one parsed JSON value per non-blank line of a UTF-8 JSONL file."""
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if line:
                yield json.loads(line)


def write_json(path: Path, obj: Dict[str, Any]) -> None:
    """Write ``obj`` as pretty-printed UTF-8 JSON, creating parent directories."""
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_text(json.dumps(obj, ensure_ascii=False, indent=2), encoding="utf-8")


def append_jsonl(path: Path, obj: Dict[str, Any]) -> None:
    """Append ``obj`` as one JSON line (``\\n`` terminated), creating parent directories."""
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("a", encoding="utf-8", newline="\n") as f:
        f.write(json.dumps(obj, ensure_ascii=False) + "\n")


def set_seed(seed: int) -> None:
    """Seed Python ``random``, NumPy, torch CPU and every CUDA device."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


def normalize_model_reference(value: Any) -> str:
    """Normalize a model path/name: backslashes become ``/``.

    Repeated slashes are collapsed unless the value looks like a URL (contains ``://``).
    ``None``/empty input yields ``""``.
    """
    normalized = str(value or "").replace("\\", "/")
    if "://" not in normalized:
        normalized = re.sub(r"/+", "/", normalized)
    return normalized


def init_distributed() -> Tuple[bool, int, int, int]:
    """Initialise NCCL distributed training when launched via torchrun-style env vars.

    Returns:
        ``(is_distributed, rank, world_size, local_rank)``; ``(False, 0, 1, 0)``
        when ``RANK``/``WORLD_SIZE`` are not set.
    """
    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        return False, 0, 1, 0
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    dist.init_process_group(backend="nccl")
    # NOTE(review): the CUDA device is selected after the process group is
    # created; confirm this ordering is fine for the installed torch/NCCL.
    torch.cuda.set_device(local_rank)
    return True, rank, world_size, local_rank


def is_main(rank: int) -> bool:
    """Return True for the rank-0 process."""
    return rank == 0


def unwrap_parallel_model(model: nn.Module) -> nn.Module:
    """Return the wrapped module of a DDP/DataParallel wrapper, else ``model`` itself."""
    if isinstance(model, (DistributedDataParallel, nn.DataParallel)):
        return model.module
    return model


def safe_text(value: Any) -> str:
    """Stringify ``value``, collapse every whitespace run to one space and strip; None -> ""."""
    if value is None:
        return ""
    return re.sub(r"\s+", " ", str(value)).strip()


def target_to_text(target: Dict[str, Any], target_schema: str = "zh") -> str:
    """Render a target record as the decoder's training text for ``target_schema``.

    Schemas:
    - summary-only names: just ``summary_zh``;
    - summary_visible names: summary + visible text + interaction lines;
    - natural-text names: all five sections as labelled lines;
    - ``alias``/``en``: compact JSON with English keys;
    - anything else: compact JSON with Chinese keys.

    List sections are truncated (16 visible texts, 12 for the rest).
    """
    schema = str(target_schema or "zh").lower()
    if target_schema_is_summary(schema):
        return safe_text(target.get("summary_zh"))
    # Must be checked before the natural-text test: the summary_visible names
    # are also members of the natural-text set.
    if target_schema_is_summary_visible(schema):
        return target_to_summary_visible_text(target)
    if target_schema_is_natural_text(schema):
        return target_to_natural_text(target)
    # NOTE(review): the slices below raise TypeError if a section key is
    # present with value None (the [] default only covers a missing key).
    if schema in {"alias", "aliases", "en", "english"}:
        payload = {
            "summary_zh": safe_text(target.get("summary_zh")),
            "visible_text": target.get("visible_text", [])[:16],
            "interaction_data": target.get("interaction_data", [])[:12],
            "ui_functions": target.get("ui_functions", [])[:12],
            "key_ui_clues": target.get("key_ui_clues", [])[:12],
        }
    else:
        payload = {
            "画面总结": safe_text(target.get("summary_zh")),
            "可见文字": target.get("visible_text", [])[:16],
            "互动数据": target.get("interaction_data", [])[:12],
            "功能入口": target.get("ui_functions", [])[:12],
            "关键证据": target.get("key_ui_clues", [])[:12],
        }
    return json.dumps(payload, ensure_ascii=False, separators=(",", ":"))


def target_schema_is_summary(target_schema: str) -> bool:
    """True for schema names that train on the summary sentence only."""
    return str(target_schema or "").lower() in {"summary", "summary_zh", "summary-only", "summary_only"}


def target_schema_is_natural_text(target_schema: str) -> bool:
    """True for schema names that use labelled natural-language lines instead of JSON."""
    return str(target_schema or "").lower() in {
        "natural_zh",
        "rich_text_zh",
        "zh_text",
        "text_zh",
        "summary_visible_zh",
        "natural_summary_visible_zh",
    }


def target_schema_is_summary_visible(target_schema: str) -> bool:
    """True for the natural-text variants restricted to summary/visible/interaction lines."""
    return str(target_schema or "").lower() in {"summary_visible_zh", "natural_summary_visible_zh"}


def format_evidence_suffix(values: Any) -> str:
    """Format up to 8 deduplicated evidence ids as a parenthesised suffix, or "" if none."""
    evidence_ids = dedupe_texts(values or [], max_items=8)
    if not evidence_ids:
        return ""
    return f"(证据:{'、'.join(evidence_ids)})"


def format_natural_entry(value: Any) -> str:
    """Render one list entry for natural text.

    A dict entry renders as ``name=value`` plus an evidence suffix; anything
    else is just normalized text.
    """
    if isinstance(value, dict):
        name = safe_text(value.get("name") or value.get("text") or value.get("label"))
        detail = safe_text(value.get("value"))
        if name and detail:
            name = f"{name}={detail}"
        elif detail and not name:
            name = detail
        return safe_text(name + format_evidence_suffix(value.get("evidence_ids", [])))
    return safe_text(value)


def format_natural_list(values: Any, max_items: int = 16) -> str:
    """Join the first ``max_items`` distinct rendered entries with "、".

    Returns "无" when the input is not a list or nothing survives rendering.
    """
    if not isinstance(values, list):
        return "无"
    entries = []
    seen = set()
    for value in values[:max_items]:
        text = format_natural_entry(value)
        if text and text not in seen:
            entries.append(text)
            seen.add(text)
    return "、".join(entries) if entries else "无"


def target_to_natural_text(target: Dict[str, Any]) -> str:
    """Render all five target sections as newline-separated labelled lines."""
    summary = safe_text(target.get("summary_zh"))
    visible = format_natural_list(target.get("visible_text", []), max_items=16)
    interactions = format_natural_list(target.get("interaction_data", []), max_items=12)
    functions = format_natural_list(target.get("ui_functions", []), max_items=12)
    evidence = format_natural_list(target.get("key_ui_clues", []), max_items=12)
    return "\n".join(
        [
            f"画面总结:{summary}",
            f"可见文字:{visible}",
            f"互动数据:{interactions}",
            f"功能入口:{functions}",
            f"关键证据:{evidence}",
        ]
    )


def target_to_summary_visible_text(target: Dict[str, Any]) -> str:
    """Render only the summary, visible-text and interaction-data lines."""
    summary = safe_text(target.get("summary_zh"))
    visible = format_natural_list(target.get("visible_text", []), max_items=16)
    interactions = format_natural_list(target.get("interaction_data", []), max_items=12)
    return "\n".join(
        [
            f"画面总结:{summary}",
            f"可见文字:{visible}",
            f"互动数据:{interactions}",
        ]
    )


# Chinese section keys, in the order the decoder emits them.
JSONISH_OUTPUT_KEYS = ("画面总结", "可见文字", "互动数据", "功能入口", "关键证据")
# Every key that repair_generated_json_text re-quotes (sections plus nested entry keys).
JSONISH_ALL_KEYS = JSONISH_OUTPUT_KEYS + ("name", "value", "evidence_ids")


def normalize_evidence_id(value: Any) -> str:
    """Canonicalize ids like ``e 0 1 2`` to ``E012`` (2-8 digits); other text is returned normalized."""
    text = safe_text(value)
    match = re.fullmatch(r"[Ee]\s*((?:\d\s*){2,8})", text)
    if match:
        digits = re.sub(r"\s+", "", match.group(1))
        return f"E{digits}"
    return text


def repair_generated_json_text(text: str) -> str:
    """Heuristically repair near-JSON decoder output.

    Repairs applied:
    - drop any prefix before the first ``{``;
    - re-join ``evidence _ ids``;
    - squash spaced evidence ids into ``E123``;
    - (re)quote known keys that follow ``{``, ``[`` or ``,``.
    """
    candidate = safe_text(text)
    start = candidate.find("{")
    if start >= 0:
        candidate = candidate[start:]
    candidate = re.sub(r"evidence\s*_\s*ids", "evidence_ids", candidate)
    candidate = re.sub(
        r"\b[Ee]\s*((?:\d\s*){2,8})\b",
        lambda match: "E" + re.sub(r"\s+", "", match.group(1)),
        candidate,
    )
    for key in JSONISH_ALL_KEYS:
        # First pass: keys with a closing quote (optionally missing the opening one).
        candidate = re.sub(
            r"([\{\[,])\s*\"?\s*" + re.escape(key) + r"\s*\"\s*:",
            lambda match, key=key: f'{match.group(1)}"{key}":',
            candidate,
        )
        # Second pass: keys missing the closing quote.
        candidate = re.sub(
            r"([\{\[,])\s*\"?\s*" + re.escape(key) + r"\s*:",
            lambda match, key=key: f'{match.group(1)}"{key}":',
            candidate,
        )
    return candidate


def json_object_from_text(text: str) -> Tuple[Optional[Dict[str, Any]], bool]:
    """Parse ``text`` as a JSON object.

    On a decode error, retry on the outermost ``{...}`` substring.

    Returns:
        ``(dict, True)`` on success; ``(None, False)`` otherwise, including when
        the JSON parses to a non-dict.
    """
    text = text.strip()
    if not text:
        return None, False
    try:
        obj = json.loads(text)
        return (obj, True) if isinstance(obj, dict) else (None, False)
    except json.JSONDecodeError:
        start = text.find("{")
        end = text.rfind("}")
        if start >= 0 and end > start:
            try:
                obj = json.loads(text[start : end + 1])
                return (obj, True) if isinstance(obj, dict) else (None, False)
            except json.JSONDecodeError:
                pass
    return None, False


def jsonish_field_payload(text: str, key: str, following_keys: Iterable[str]) -> str:
    """Return the raw text after ``"key":``, up to the earliest of ``following_keys``.

    Returns "" when the key is absent. A trailing comma is stripped.
    """
    match = re.search(r'"' + re.escape(key) + r'"\s*:', text)
    if not match:
        return ""
    tail = text[match.end() :]
    end = len(tail)
    for following_key in following_keys:
        next_match = re.search(r',?\s*"' + re.escape(following_key) + r'"\s*:', tail)
        if next_match:
            end = min(end, next_match.start())
    return tail[:end].strip().rstrip(",")


def clean_jsonish_string(value: Any) -> str:
    """Strip all leading/trailing double quotes, unescape ``\\"`` and re-normalize whitespace."""
    text = safe_text(value).strip()
    while text.startswith('"'):
        text = text[1:].strip()
    while text.endswith('"'):
        text = text[:-1].strip()
    return safe_text(text.replace('\\"', '"'))


def dedupe_texts(values: Iterable[Any], max_items: int = 64) -> List[str]:
    """Normalize values and return them order-preserving and deduplicated.

    Each value is normalized via normalize_evidence_id + clean_jsonish_string.
    Empty results are dropped, and at most ``max_items`` are kept.
    """
    output: List[str] = []
    seen = set()
    for value in values:
        text = clean_jsonish_string(normalize_evidence_id(value))
        if not text or text in seen:
            continue
        output.append(text)
        seen.add(text)
        if len(output) >= max_items:
            break
    return output


def jsonish_string_list(payload: str, max_items: int = 64) -> List[str]:
    """Extract a list of strings from a possibly-broken JSON array payload.

    Quoted strings are taken between the first ``[`` and last ``]``. If none
    are quoted, bare ``E<digits>`` ids are used instead.
    """
    body = payload.strip()
    if "[" in body:
        body = body[body.find("[") + 1 :]
    if "]" in body:
        body = body[: body.rfind("]")]
    values = [match.group(1) for match in re.finditer(r'"([^"\\]*(?:\\.[^"\\]*)*)"', body)]
    if not values:
        values = re.findall(r"\bE\d{2,8}\b", body)
    return dedupe_texts(values, max_items=max_items)


def normalize_function_entries(values: Any) -> List[Any]:
    """Clean a parsed function list.

    - Dicts become ``{"name", "evidence_ids"}`` (at most 8 ids).
    - Other entries become cleaned strings.
    - Entries with an empty name are dropped.
    - Non-list input yields ``[]``.
    """
    if not isinstance(values, list):
        return []
    output: List[Any] = []
    for value in values:
        if isinstance(value, dict):
            name = clean_jsonish_string(value.get("name"))
            evidence_ids = dedupe_texts(value.get("evidence_ids", []) or [], max_items=8)
            if name:
                output.append({"name": name, "evidence_ids": evidence_ids})
        else:
            name = clean_jsonish_string(value)
            if name:
                output.append(name)
    return output


def jsonish_function_entries(payload: str) -> List[Any]:
    """Parse the function-entry section.

    First try strict JSON. If that fails or yields nothing, fall back to
    regex-scanning each ``{...}`` body (or the whole payload) for a ``"name"``
    and any ``E<digits>`` ids.
    """
    body = payload.strip()
    if body.startswith("[") and body.endswith("]"):
        try:
            parsed = json.loads(body)
            normalized = normalize_function_entries(parsed)
            if normalized:
                return normalized
        except json.JSONDecodeError:
            pass
    output: List[Any] = []
    object_bodies = [match.group(1) for match in re.finditer(r"\{([^{}]*)\}", body)] or [body]
    for object_body in object_bodies:
        name_match = re.search(r'"name"\s*:\s*"([^"\\]*(?:\\.[^"\\]*)*)"', object_body)
        if not name_match:
            continue
        name = clean_jsonish_string(name_match.group(1))
        evidence_ids = dedupe_texts(re.findall(r"\bE\d{2,8}\b", object_body), max_items=8)
        if name:
            output.append({"name": name, "evidence_ids": evidence_ids})
    return output


def jsonish_prediction_object(text: str) -> Optional[Dict[str, Any]]:
    """Salvage a prediction dict field-by-field from JSON-like text that json.loads rejects.

    Returns None when no section yields anything.
    """
    repaired = repair_generated_json_text(text)
    summary = clean_jsonish_string(
        jsonish_field_payload(repaired, "画面总结", ("可见文字", "互动数据", "功能入口", "关键证据"))
    )
    visible_text = jsonish_string_list(
        jsonish_field_payload(repaired, "可见文字", ("互动数据", "功能入口", "关键证据")),
        max_items=32,
    )
    interaction_data = jsonish_string_list(
        jsonish_field_payload(repaired, "互动数据", ("功能入口", "关键证据")),
        max_items=16,
    )
    functions = jsonish_function_entries(jsonish_field_payload(repaired, "功能入口", ("关键证据",)))
    evidence_ids = jsonish_string_list(jsonish_field_payload(repaired, "关键证据", ()), max_items=16)
    if not (summary or visible_text or interaction_data or functions or evidence_ids):
        return None
    return {
        "画面总结": summary,
        "可见文字": visible_text,
        "互动数据": interaction_data,
        "功能入口": functions,
        "关键证据": evidence_ids,
    }


# Section labels recognised in natural-text output, in emission order.
NATURAL_OUTPUT_KEYS = ("画面总结", "可见文字", "互动数据", "功能入口", "关键证据")


def has_natural_title_prefix(text: str) -> bool:
    """True if ``text`` starts with a ``Title``/``title``/``标题`` label."""
    return re.search(r"^\s*(?:Title|title|标题)(?:\s*[::]|\s+)", str(text or "")) is not None


def strip_natural_title_prefix(text: str) -> Tuple[str, bool]:
    """Remove a leading title label from natural-text output.

    When a known section label follows, the text between the title and that
    label is kept as the summary line (unless that label already is the
    summary). Line endings are normalized to ``\\n``.

    Returns:
        ``(text, stripped)``.
    """
    raw = str(text or "").strip().replace("\r\n", "\n").replace("\r", "\n")
    if not raw or not has_natural_title_prefix(raw):
        return raw, False
    prefix_match = re.match(r"^\s*(?:Title|title|标题)(?:\s*[::]|\s+)", raw)
    if not prefix_match:
        return raw, False
    key_pattern = r"(" + "|".join(re.escape(key) for key in NATURAL_OUTPUT_KEYS) + r")\s*[::]"
    key_matches = list(re.finditer(key_pattern, raw))
    if not key_matches:
        return re.sub(r"^\s*(?:Title|title|标题)(?:\s*[::]|\s+)", "", raw).strip(), True
    first_key = key_matches[0]
    prefix_body = safe_text(raw[prefix_match.end() : first_key.start()])
    rest = raw[first_key.start() :].strip()
    if first_key.group(1) == "画面总结":
        return rest, True
    if prefix_body:
        return f"画面总结:{prefix_body}\n{rest}", True
    return rest, True


def natural_field_payload(text: str, key: str, following_keys: Iterable[str]) -> str:
    """Return the text after ``key:`` (ASCII or fullwidth colon).

    The span runs up to the earliest following key, or "" if ``key`` is absent.
    """
    match = re.search(re.escape(key) + r"\s*[::]", text)
    if not match:
        return ""
    tail = text[match.end() :]
    end = len(tail)
    for following_key in following_keys:
        next_match = re.search(re.escape(following_key) + r"\s*[::]", tail)
        if next_match:
            end = min(end, next_match.start())
    return tail[:end].strip()


# NOTE(review): the next physical line is CORRUPTED in this copy and is kept
# byte-for-byte. Text between a '<' and a later '>' was stripped, destroying
# the rest of split_natural_items' re.split pattern (a lookbehind) and body,
# plus the header of natural_function_entries (its evidence-label regex is
# also mangled). Restore both functions from version control.
def split_natural_items(payload: str, max_items: int = 64) -> List[str]: payload = safe_text(payload) if not payload or payload in {"无", "暂无", "没有", "[]"}: return [] parts = re.split(r"[\n;;|]+|(? List[Any]: output: List[Any] = [] for item in split_natural_items(payload, max_items=32): evidence_ids = dedupe_texts(re.findall(r"\bE\d{2,8}\b", normalize_evidence_id(item)), max_items=8) name = re.sub(r"(?证据[::].*?)?$", "", item).strip() name = re.sub(r"\(\s*证据[::].*?\)\s*$", "", name).strip() if name: output.append({"name": name, "evidence_ids": evidence_ids}) return output


def natural_prediction_from_text(text: str) -> Dict[str, Any]:
    """Parse labelled natural-text output into the Chinese-keyed prediction dict.

    - Missing summary: fall back to the first line/sentence.
    - Missing key-evidence line: collect ids from the function entries.
    """
    text, _ = strip_natural_title_prefix(text)
    fields = {
        key: natural_field_payload(text, key, NATURAL_OUTPUT_KEYS[index + 1 :])
        for index, key in enumerate(NATURAL_OUTPUT_KEYS)
    }
    summary = safe_text(fields.get("画面总结"))
    if not summary:
        first_line = re.split(r"[\n。]", text, maxsplit=1)[0]
        summary = safe_text(first_line)
    visible_text = split_natural_items(fields.get("可见文字", ""), max_items=32)
    interaction_data = split_natural_items(fields.get("互动数据", ""), max_items=16)
    functions = natural_function_entries(fields.get("功能入口", ""))
    evidence_ids = dedupe_texts(re.findall(r"\bE\d{2,8}\b", normalize_evidence_id(fields.get("关键证据", ""))), max_items=16)
    if not evidence_ids:
        evidence_ids = dedupe_texts(
            eid for func in functions if isinstance(func, dict) for eid in func.get("evidence_ids", [])
        )
    return {
        "画面总结": summary,
        "可见文字": visible_text,
        "互动数据": interaction_data,
        "功能入口": functions,
        "关键证据": evidence_ids,
    }


def safe_json_loads_with_repair(text: str) -> Tuple[Optional[Dict[str, Any]], bool, bool, bool]:
    """Parse decoder output with escalating repair.

    Tries strict JSON, then repaired JSON, then field-by-field salvage.

    Returns:
        ``(obj, parsed_ok, was_repaired, strict_ok)``; ``(None, False, False, False)``
        when every stage fails.
    """
    obj, strict_ok = json_object_from_text(text)
    if strict_ok:
        return obj, True, False, True
    repaired_text = repair_generated_json_text(text)
    obj, repaired_json_ok = json_object_from_text(repaired_text)
    if repaired_json_ok:
        return obj, True, repaired_text.strip() != str(text or "").strip(), False
    obj = jsonish_prediction_object(repaired_text)
    if obj is not None:
        return obj, True, True, False
    return None, False, False, False


def safe_json_loads(text: str) -> Tuple[Optional[Dict[str, Any]], bool]:
    """Convenience wrapper returning only ``(obj, parsed_ok)`` from safe_json_loads_with_repair."""
    obj, ok, _, _ = safe_json_loads_with_repair(text)
    return obj, ok


def load_seq2seq_tokenizer(model_name_or_path: str, model_hint: str = ""):
    """Load the decoder tokenizer.

    Uses slow T5Tokenizer for (m)T5-like models or directories holding
    ``spiece.model``; otherwise uses slow AutoTokenizer. If ``model_name_or_path``
    is a directory without ``spiece.model`` but ``model_hint`` has one, the
    hint directory is loaded instead.
    """
    name = f"{model_name_or_path} {model_hint}".lower()
    load_path = str(model_name_or_path)
    path = Path(load_path)
    has_spiece = path.is_dir() and (path / "spiece.model").exists()
    if has_spiece or "mt5" in name or "t5" in name:
        hint_path = Path(str(model_hint)) if model_hint else None
        if path.is_dir() and not (path / "spiece.model").exists() and hint_path and (hint_path / "spiece.model").exists():
            load_path = str(hint_path)
        # NOTE(review): fix_mistral_regex is forwarded to from_pretrained for a
        # T5 tokenizer; confirm the installed transformers version accepts it.
        return T5Tokenizer.from_pretrained(load_path, fix_mistral_regex=True)
    return AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)


def char_lcs(a: str, b: str) -> int:
    """Length of the character-level longest common subsequence.

    Uses a two-row DP: O(len(a)*len(b)) time, O(len(b)) extra memory.
    """
    if not a or not b:
        return 0
    prev = [0] * (len(b) + 1)
    for ca in a:
        cur = [0]
        for j, cb in enumerate(b, start=1):
            if ca == cb:
                cur.append(prev[j - 1] + 1)
            else:
                cur.append(max(cur[-1], prev[j]))
        prev = cur
    return prev[-1]


def rouge_l_char(pred: str, ref: str) -> float:
    """Character-level ROUGE-L F1 (harmonic mean of LCS precision and recall); 0.0 if either is empty."""
    if not pred or not ref:
        return 0.0
    lcs = char_lcs(pred, ref)
    prec = lcs / max(1, len(pred))
    rec = lcs / max(1, len(ref))
    if prec + rec == 0:
        return 0.0
    return 2 * prec * rec / (prec + rec)


def extract_summary(obj: Optional[Dict[str, Any]]) -> str:
    """Return the summary from a Chinese- or English-keyed prediction, or ""."""
    if not obj:
        return ""
    return safe_text(obj.get("画面总结") or obj.get("summary_zh") or obj.get("summary"))


def extract_evidence_ids(obj: Optional[Dict[str, Any]]) -> List[str]:
    """Collect normalized evidence ids from the key-evidence section.

    String entries are used directly; dict entries contribute their ``evidence_ids``.
    Duplicates are kept.
    """
    if not obj:
        return []
    values = obj.get("关键证据") or obj.get("key_ui_clues") or obj.get("evidence") or []
    out: List[str] = []
    if isinstance(values, list):
        for value in values:
            if isinstance(value, str):
                normalized = normalize_evidence_id(value)
                if normalized:
                    out.append(normalized)
            elif isinstance(value, dict):
                out.extend(normalize_evidence_id(x) for x in value.get("evidence_ids", []) if normalize_evidence_id(x))
    return out


def extract_function_entries(obj: Optional[Dict[str, Any]]) -> List[Any]:
    """Return the raw function-entry list (Chinese or English key), or [] if absent/not a list."""
    if not isinstance(obj, dict):
        return []
    values = obj.get("功能入口") or obj.get("ui_functions") or []
    return values if isinstance(values, list) else []


def extract_function_names(obj: Optional[Dict[str, Any]]) -> List[str]:
    """Return non-empty function names (dict ``name`` or plain string entries)."""
    names = []
    for value in extract_function_entries(obj):
        if isinstance(value, dict):
            names.append(safe_text(value.get("name")))
        else:
            names.append(safe_text(value))
    return [name for name in names if name]


def extract_function_evidence_ids(target: Dict[str, Any]) -> List[str]:
    """Collect normalized evidence ids attached to dict-form function entries."""
    ids: List[str] = []
    for value in extract_function_entries(target):
        if isinstance(value, dict):
            ids.extend(normalize_evidence_id(eid) for eid in value.get("evidence_ids", []) if normalize_evidence_id(eid))
    return ids


def extract_named_function_evidence_ids(target: Dict[str, Any], keyword: str) -> List[str]:
    """Like extract_function_evidence_ids, restricted to entries whose name contains ``keyword``."""
    ids: List[str] = []
    for value in extract_function_entries(target):
        if isinstance(value, dict) and keyword in safe_text(value.get("name")):
            ids.extend(normalize_evidence_id(eid) for eid in value.get("evidence_ids", []) if normalize_evidence_id(eid))
    return ids


def has_search_function(obj: Optional[Dict[str, Any]]) -> bool:
    """True if any function name contains "搜索" (search)."""
    return any("搜索" in name for name in extract_function_names(obj))


def count_search_functions(obj: Optional[Dict[str, Any]]) -> int:
    """Number of function names containing "搜索"."""
    return sum(1 for name in extract_function_names(obj) if "搜索" in name)


def count_bare_search_functions(obj: Optional[Dict[str, Any]]) -> int:
    """Number of function names that are exactly "搜索"."""
    return sum(1 for name in extract_function_names(obj) if is_bare_search_function_name(name))


def normalized_prediction_obj(obj: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Map a Chinese- or English-keyed prediction onto the five Chinese keys.

    Returns None for non-dict input. List values are passed through unvalidated.
    """
    if not isinstance(obj, dict):
        return None
    return {
        "画面总结": safe_text(obj.get("画面总结") or obj.get("summary_zh") or obj.get("summary")),
        "可见文字": obj.get("可见文字") or obj.get("visible_text") or [],
        "互动数据": obj.get("互动数据") or obj.get("interaction_data") or [],
        "功能入口": obj.get("功能入口") or obj.get("ui_functions") or [],
        "关键证据": obj.get("关键证据") or obj.get("key_ui_clues") or obj.get("evidence") or [],
    }


def element_id(item: Dict[str, Any]) -> str:
    """Return an element's ``id``, falling back to ``ocr_id``."""
    return safe_text(item.get("id") or item.get("ocr_id"))


def visible_text_from_row(row: Dict[str, Any], max_items: int = 16) -> List[str]:
    """Collect distinct, meaningful on-screen texts from a row.

    OCR items are read first, then UI items. Skipped texts:
    - bare evidence ids;
    - single non-alphanumeric characters;
    - punctuation-only strings.

    Each text is truncated to 80 chars.
    """
    values: List[str] = []
    seen = set()
    for group_name in ("ocr_items", "ui_items"):
        for item in row.get(group_name, []) or []:
            text = safe_text(item.get("text"))
            if not text or text in seen or re.fullmatch(r"E\d+", text):
                continue
            if len(text) == 1 and not re.search(r"[\u4e00-\u9fffA-Za-z0-9]", text):
                continue
            if re.fullmatch(r"[\s\-_.::/]+", text):
                continue
            values.append(text[:80])
            seen.add(text)
            if len(values) >= max_items:
                return values
    return values


def visible_texts_from_items(items: Iterable[Dict[str, Any]], max_items: int = 16) -> List[str]:
    """Same filtering as visible_text_from_row, applied to a single item list."""
    values: List[str] = []
    seen = set()
    for item in items or []:
        text = safe_text(item.get("text"))
        if not text or text in seen or re.fullmatch(r"E\d+", text):
            continue
        if len(text) == 1 and not re.search(r"[\u4e00-\u9fffA-Za-z0-9]", text):
            continue
        if re.fullmatch(r"[\s\-_.::/]+", text):
            continue
        values.append(text[:80])
        seen.add(text)
        if len(values) >= max_items:
            break
    return values


def maybe_dropout_texts(values: List[str], dropout_rate: float) -> List[str]:
    """Randomly drop each text with probability ``dropout_rate`` (global ``random`` RNG).

    At least one text (chosen at random) is always kept when input is non-empty.
    """
    if dropout_rate <= 0.0 or not values:
        return values
    kept = [value for value in values if random.random() >= dropout_rate]
    if kept:
        return kept
    return [random.choice(values)]


def build_context_text(row: Dict[str, Any], args: argparse.Namespace, screen_text_dropout_rate: float = 0.0) -> str:
    """Build the textual context (app + task, optionally on-screen text) fed to the model.

    ``context_text_format == "text_only"`` produces English-labelled lines with
    separate OCR/UI groups. Any other value produces a single Chinese-labelled line.
    """
    app_text = safe_text(row.get("app"))
    instruction_text = safe_text(row.get("instruction"))
    context_format = str(getattr(args, "context_text_format", "rich") or "rich").lower()
    if context_format == "text_only":
        parts = [f"app: {app_text}", f"task: {instruction_text}"]
        if bool(getattr(args, "context_include_screen_text", False)):
            max_screen_items = int(getattr(args, "context_screen_text_items", 32) or 32)
            ocr_text = visible_texts_from_items(row.get("ocr_items") or [], max_items=max_screen_items)
            ui_text = visible_texts_from_items(row.get("ui_items") or [], max_items=max_screen_items)
            ocr_text = maybe_dropout_texts(ocr_text, screen_text_dropout_rate)
            ui_text = maybe_dropout_texts(ui_text, screen_text_dropout_rate)
            if ocr_text:
                parts.append("ocr: " + " | ".join(ocr_text))
            if ui_text:
                parts.append("ui: " + " | ".join(ui_text))
        return "\n".join(parts)
    context_text = f"应用:{app_text} 任务:{instruction_text}"
    if bool(getattr(args, "context_include_screen_text", False)):
        screen_text = visible_text_from_row(row, max_items=int(getattr(args, "context_screen_text_items", 32) or 32))
        screen_text = maybe_dropout_texts(screen_text, screen_text_dropout_rate)
        if screen_text:
            context_text = f"{context_text} 屏幕文字:{' | '.join(screen_text)}"
    return context_text


def prediction_from_summary(row: Dict[str, Any], summary_text: str) -> Dict[str, Any]:
    """Wrap a summary-only generation as a prediction dict with every other section empty."""
    # In summary mode the decoder only generates "画面总结" (the summary); all
    # other fields must stay empty. Backfilling them from the input row
    # (OCR/UI items) would contaminate the evaluation metrics and distort every
    # field other than ROUGE. If structured fields need filling, they must be
    # written explicitly by the later apply_structured_*_predictions (from the
    # auxiliary heads).
    return {
        "画面总结": safe_text(summary_text),
        "可见文字": [],
        "互动数据": [],
        "功能入口": [],
        "关键证据": [],
    }


def repair_prediction_with_context(row: Dict[str, Any], pred_obj: Optional[Dict[str, Any]]) -> Tuple[Dict[str, Any], bool]:
    """Patch a parsed prediction using the input row.

    Repairs:
    - replace a missing, generic or off-task summary with a templated one;
    - drop bare ids from visible text and cap it at 16;
    - drop generic functions, and search functions unless search is both asked
      for and visible;
    - backfill evidence from ``weak_evidence_ids`` (at most 8).

    Returns:
        ``(repaired, changed)``.
    """
    app = safe_text(row.get("app")) or "移动应用"
    instruction = safe_text(row.get("instruction"))
    has_search_task = row_has_search_task(row)
    has_search_evidence = row_has_visible_search_evidence(row)
    repaired = normalized_prediction_obj(pred_obj)
    changed = repaired is None
    if repaired is None:
        repaired = {"画面总结": "", "可见文字": [], "互动数据": [], "功能入口": [], "关键证据": []}
    summary = safe_text(repaired.get("画面总结"))
    # Summary claims a search page/results although the task never mentions search.
    search_overuse = ("搜索页面" in summary or "搜索结果" in summary) and "搜索" not in instruction
    # Degenerate "App的App"-style summaries.
    generic_app = "App的App" in summary or "手机App的App" in summary or summary.count("App") >= 3
    needs_summary = not summary or app not in summary or generic_app or search_overuse
    if needs_summary:
        if instruction:
            summary = f"这是一个{app}界面,页面包含若干文字内容和可操作的 UI 入口。 当前任务语境是:{instruction}。"
        else:
            summary = f"这是一个{app}界面,页面包含若干文字内容和可操作的 UI 入口。"
        repaired["画面总结"] = summary
        changed = True
    visible = repaired.get("可见文字") or []
    if isinstance(visible, list):
        cleaned_visible = [value for value in visible if not re.fullmatch(r"E\d+", safe_text(value))][:16]
        if cleaned_visible != visible:
            repaired["可见文字"] = cleaned_visible
            changed = True
    else:
        repaired["可见文字"] = []
        changed = True
    if not isinstance(repaired.get("互动数据"), list):
        repaired["互动数据"] = []
        changed = True
    functions = repaired.get("功能入口") or []
    if isinstance(functions, list):
        cleaned_functions = []
        for function in functions:
            name = safe_text(function.get("name") if isinstance(function, dict) else function)
            if is_generic_function_name(name, row):
                changed = True
                continue
            if "搜索" in name and not (has_search_task and has_search_evidence):
                changed = True
                continue
            cleaned_functions.append(function)
        if cleaned_functions != functions:
            repaired["功能入口"] = cleaned_functions
    else:
        repaired["功能入口"] = []
        changed = True
    evidence = repaired.get("关键证据") or []
    if not isinstance(evidence, list):
        evidence = []
    if not evidence:
        evidence = list(row.get("weak_evidence_ids") or [])[:8]
        repaired["关键证据"] = evidence
        changed = True
    return repaired, changed


def row_has_visible_search_evidence(row: Dict[str, Any]) -> bool:
    """True if any UI/OCR item looks like a search or input control (same test as is_search_ui_item)."""
    for item in list(row.get("ui_items", [])) + list(row.get("ocr_items", [])):
        text = safe_text(item.get("text"))
        item_type = safe_text(item.get("type")).lower()
        if "搜索" in text or "search" in item_type:
            return True
        if text in {"搜", "搜索框"} or "输入" in text:
            return True
    return False


def row_has_search_context(row: Dict[str, Any]) -> bool:
    """True if "搜索" appears in the instruction or in the concatenated UI/OCR text."""
    instruction = safe_text(row.get("instruction"))
    screen_text = "".join(safe_text(item.get("text")) for item in (row.get("ui_items") or []))
    screen_text += "".join(safe_text(item.get("text")) for item in (row.get("ocr_items") or []))
    return "搜索" in instruction or "搜索" in screen_text


def row_has_search_task(row: Dict[str, Any]) -> bool:
    """True if the instruction mentions searching/looking up (the "搜" check subsumes "搜索")."""
    instruction = safe_text(row.get("instruction"))
    return "搜索" in instruction or "查找" in instruction or "搜" in instruction


def is_generic_function_name(name: str, row: Dict[str, Any]) -> bool:
    """True for empty or uninformative function names.

    Covers generic "entry"/"open app" labels, app-specific open/enter phrases,
    and any ``进入|打开`` + 1-12 chars pattern.
    """
    text = safe_text(name)
    app = safe_text(row.get("app"))
    if not text:
        return True
    generic_names = {
        "入口",
        "功能入口",
        "功能反馈",
        "打开应用",
        "进入应用",
        "打开App",
        "进入App",
    }
    if text in generic_names:
        return True
    if app and text in {f"{app}入口", f"进入{app}", f"打开{app}", f"进入{app}App", f"打开{app}App"}:
        return True
    if re.fullmatch(r"(进入|打开).{1,12}(App|应用)?", text):
        return True
    return False


def is_bare_search_function_name(name: str) -> bool:
    """True when the function name is exactly "搜索"."""
    return safe_text(name) == "搜索"


def is_search_ui_item(item: Dict[str, Any]) -> bool:
    """True if the item's text mentions search/input or its type contains "search"."""
    text = safe_text(item.get("text"))
    item_type = safe_text(item.get("type")).lower()
    if "搜索" in text or "search" in item_type:
        return True
    return text in {"搜", "搜索框"} or "输入" in text


def is_structured_search_candidate_item(item: Dict[str, Any], strict: bool = False) -> bool:
    """Decide whether a search-like UI item may become a search function.

    In ``strict`` mode, passive search mentions (history, trending, results, …)
    are rejected and only control-like texts are accepted. A short text
    starting with "搜索" (at most 8 chars) is also accepted.
    """
    if not is_search_ui_item(item):
        return False
    if not strict:
        return True
    text = safe_text(item.get("text"))
    if text in {"搜索", "搜", "搜索框"}:
        return True
    passive_terms = [
        "历史搜索",
        "搜索历史",
        "搜索发现",
        "深度搜索",
        "AI搜索",
        "热门搜索",
        "最近搜索",
        "相关搜索",
        "搜索记录",
        "清除搜索记录",
        "搜索结果",
        "搜索来源",
        "搜索推荐",
        "根据你的搜索",
    ]
    if any(term in text for term in passive_terms):
        return False
    control_terms = [
        "搜索框",
        "搜索、提问",
        "搜索地点",
        "搜索店内",
        "搜索商品",
        "搜索景点",
        "搜索机票",
        "附近搜索",
        "请输入",
        "输入",
    ]
    if any(term in text for term in control_terms):
        return True
    return text.startswith("搜索") and len(text) <= 8


def structured_function_name(item: Dict[str, Any], row: Dict[str, Any]) -> str:
    """Derive a function-entry name for a UI item selected by the function head.

    - Search items get a fixed name.
    - Texts up to 12 chars are used as-is; longer texts map to known keywords
      or are truncated to 12.
    - Action targets without text fall back to keywords from the instruction.
    - Otherwise a generic button/entry label is returned.
    """
    text = safe_text(item.get("text"))
    item_type = safe_text(item.get("type")).lower()
    if is_search_ui_item(item):
        return "搜索功能入口"
    if text:
        if len(text) <= 12:
            return text
        if "购物车" in text:
            return "购物车"
        if "消息" in text:
            return "消息入口"
        if "设置" in text:
            return "设置入口"
        return text[:12]
    if item.get("is_action_target"):
        instruction = safe_text(row.get("instruction"))
        for keyword in ["商城", "购物车", "消息", "订单", "收藏", "关注", "分享", "返回"]:
            if keyword in instruction:
                # Verb-like keywords are used bare; noun-like ones get an "entry" suffix.
                return f"{keyword}入口" if keyword not in {"收藏", "关注", "分享", "返回"} else keyword
    if "button" in item_type:
        return "按钮入口"
    return "功能入口"


def apply_structured_function_predictions(
    row: Dict[str, Any],
    pred_obj: Optional[Dict[str, Any]],
    function_scores: torch.Tensor,
    search_scores: torch.Tensor,
    args: argparse.Namespace,
) -> Dict[str, Any]:
    """Replace the prediction's function entries with ones chosen from per-UI-item head scores.

    In "decoder" mode the decoder's output is returned untouched. Otherwise:
    - items are scored via ``function_scores``, or ``search_scores`` for
      search-candidate items (these also require a search task and visible
      search evidence);
    - items without an id, or with generic names, are dropped;
    - survivors are ranked by score and deduplicated by (name, evidence);
    - at most ``structured_max_functions`` are kept.

    Scores are indexed in step with ``row["ui_items"]`` (presumably one score
    per item; TODO confirm against the head output).
    """
    mode = str(getattr(args, "structured_function_mode", "decoder") or "decoder").lower()
    if mode == "decoder":
        # NOTE(review): returns None (not a dict) when pred_obj is not a dict,
        # despite the return annotation; the evidence variant returns an empty
        # skeleton instead.
        return pred_obj if isinstance(pred_obj, dict) else normalized_prediction_obj(pred_obj)
    # NOTE(review): any mode other than "decoder" is treated as the heads mode;
    # apply_structured_evidence_predictions raises ValueError for unknown modes.
    structured = normalized_prediction_obj(pred_obj)
    if structured is None:
        structured = {"画面总结": "", "可见文字": [], "互动数据": [], "功能入口": [], "关键证据": []}
    # NOTE(review): `x or default` turns an explicit 0.0 threshold into the
    # default (0.5 / function_threshold); confirm this is intended.
    function_threshold = float(getattr(args, "structured_function_threshold", 0.5) or 0.5)
    search_threshold = float(getattr(args, "structured_search_threshold", function_threshold) or function_threshold)
    max_functions = int(getattr(args, "structured_max_functions", 12) or 12)
    has_search_task = row_has_search_task(row)
    has_search_evidence = row_has_visible_search_evidence(row)
    strict_search_candidates = bool(getattr(args, "structured_strict_search_candidates", False))
    items = (row.get("ui_items") or [])[: int(getattr(args, "max_elements", len(function_scores)))]
    candidates: List[Tuple[float, Dict[str, Any]]] = []
    for idx, item in enumerate(items[: len(function_scores)]):
        function_score = float(function_scores[idx])
        search_score = float(search_scores[idx]) if idx < len(search_scores) else 0.0
        raw_item_is_search = is_search_ui_item(item)
        item_is_search = is_structured_search_candidate_item(item, strict=strict_search_candidates)
        # Search-looking items rejected by the strict filter are dropped entirely,
        # not re-scored as ordinary functions.
        if raw_item_is_search and not item_is_search:
            continue
        if item_is_search:
            keep = has_search_task and has_search_evidence and search_score >= search_threshold
        else:
            keep = function_score >= function_threshold
        if not keep:
            continue
        evidence_id = safe_text(item.get("id") or item.get("ocr_id"))
        if not evidence_id:
            continue
        name = structured_function_name(item, row)
        if is_generic_function_name(name, row):
            continue
        if "搜索" in name and not (has_search_task and has_search_evidence):
            continue
        score = search_score if item_is_search else function_score
        candidates.append((score, {"name": name, "evidence_ids": [evidence_id]}))
    candidates.sort(key=lambda value: value[0], reverse=True)
    functions: List[Dict[str, Any]] = []
    seen = set()
    for _, function in candidates:
        key = (function["name"], tuple(function.get("evidence_ids", [])))
        if key in seen:
            continue
        seen.add(key)
        functions.append(function)
        if len(functions) >= max_functions:
            break
    structured["功能入口"] = functions
    return structured


def apply_structured_evidence_predictions(
    row: Dict[str, Any],
    pred_obj: Optional[Dict[str, Any]],
    evidence_scores: torch.Tensor,
    args: argparse.Namespace,
) -> Dict[str, Any]:
    """Optionally replace the key-evidence list with UI-item ids chosen by the evidence head.

    Modes:
    - "decoder": return the normalized prediction unchanged.
    - "heads": keep distinct ids scoring at least the threshold, highest first,
      up to ``structured_max_evidence``; if none qualify and
      ``structured_evidence_fallback_top1`` is set, keep the single top-scoring id.

    Raises:
        ValueError: for any other mode.
    """
    mode = str(getattr(args, "structured_evidence_mode", "decoder") or "decoder").lower()
    structured = normalized_prediction_obj(pred_obj)
    if structured is None:
        structured = {"画面总结": "", "可见文字": [], "互动数据": [], "功能入口": [], "关键证据": []}
    if mode == "decoder":
        return structured
    if mode != "heads":
        raise ValueError("structured_evidence_mode must be one of: decoder, heads")
    # NOTE(review): `x or 0.5` turns an explicit 0.0 threshold into 0.5.
    threshold = float(getattr(args, "structured_evidence_threshold", 0.5) or 0.5)
    max_evidence = int(getattr(args, "structured_max_evidence", 8) or 8)
    fallback_top1 = bool(getattr(args, "structured_evidence_fallback_top1", True))
    items = (row.get("ui_items") or [])[: int(getattr(args, "max_elements", len(evidence_scores)))]
    candidates: List[Tuple[float, str]] = []
    for idx, item in enumerate(items[: len(evidence_scores)]):
        eid = element_id(item)
        if not eid:
            continue
        score = float(evidence_scores[idx])
        candidates.append((score, eid))
    candidates.sort(key=lambda value: value[0], reverse=True)
    selected: List[str] = []
    seen = set()
    for score, eid in candidates:
        if score < threshold:
            continue
        if eid in seen:
            continue
        selected.append(eid)
        seen.add(eid)
        if len(selected) >= max_evidence:
            break
    if not selected and fallback_top1 and candidates:
        selected.append(candidates[0][1])
    structured["关键证据"] = selected
    return structured


def build_context_summary(row: Dict[str, Any]) -> str:
    """Return the templated fallback summary built from the row's app name and instruction."""
    app = safe_text(row.get("app")) or "移动应用"
    instruction = safe_text(row.get("instruction"))
    if instruction:
        return f"这是一个{app}界面,页面包含若干文字内容和可操作的 UI 入口。 当前任务语境是:{instruction}。"
    return f"这是一个{app}界面,页面包含若干文字内容和可操作的 UI 入口。"


# NOTE(review): canonicalize_target_with_context continues past the end of this
# section; its visible part is kept exactly as found (collapsed onto one line).
def canonicalize_target_with_context( row: Dict[str, Any], target: Dict[str, Any], drop_bare_search_functions: bool = False, ) -> Dict[str, Any]: canonical = dict(target or {}) # 只在 summary_zh 缺失时回填模板,绝对不要覆盖已有标签—— # 之前无条件覆盖会把任何真实摘要替换成 "这是一个X界面...当前任务语境是..." 模板, # 是导致全量训练目标 100% 同质化、模型只学到模板的根因之一。 existing_summary = safe_text(canonical.get("summary_zh")) if not existing_summary: canonical["summary_zh"] = build_context_summary(row) visible = canonical.get("visible_text") or [] if isinstance(visible, list): canonical["visible_text"] = [value for value in visible if not re.fullmatch(r"E\d+", safe_text(value))][:16] else: canonical["visible_text"] = [] functions = canonical.get("ui_functions") or [] if isinstance(functions, list): has_search_task = row_has_search_task(row) has_search_evidence = row_has_visible_search_evidence(row) cleaned_functions = [] for function in functions: name = safe_text(function.get("name") if isinstance(function, dict) else function) if is_generic_function_name(name, row): continue if drop_bare_search_functions and is_bare_search_function_name(name): continue if "搜索" in name and not (has_search_task and has_search_evidence): continue cleaned_functions.append(function) canonical["ui_functions"] = cleaned_functions[:12] else: canonical["ui_functions"] = [] if not
isinstance(canonical.get("interaction_data"), list): canonical["interaction_data"] = [] evidence = canonical.get("key_ui_clues") or [] if not isinstance(evidence, list): evidence = [] if not evidence: evidence = list(row.get("weak_evidence_ids") or [])[:8] canonical["key_ui_clues"] = evidence[:12] return canonical class RichScreenshotDataset(Dataset): def __init__(self, path: str, max_samples: int = 0, sample_seed: Optional[int] = None): self.path = Path(path) self.rows = list(read_jsonl(self.path)) if max_samples and max_samples < len(self.rows): if sample_seed is None: self.rows = self.rows[:max_samples] else: rng = random.Random(sample_seed) indices = sorted(rng.sample(range(len(self.rows)), max_samples)) self.rows = [self.rows[index] for index in indices] if not self.rows: raise ValueError(f"No rows loaded from {path}") def __len__(self) -> int: return len(self.rows) def __getitem__(self, idx: int) -> Dict[str, Any]: return self.rows[idx] def dataset_diagnostics(rows: List[Dict[str, Any]]) -> Dict[str, Any]: ocr_counts: List[int] = [] ui_counts: List[int] = [] target_chars: List[int] = [] missing_target = 0 short_summary = 0 missing_image_path = 0 missing_image_file = 0 missing_ocr_items = 0 missing_ui_items = 0 for row in rows: target = row.get("target") if not isinstance(target, dict): missing_target += 1 summary = "" else: summary = safe_text(target.get("summary_zh")) target_chars.append(len(summary)) if len(summary) < 8: short_summary += 1 image_path = safe_text(row.get("image_path")) if not image_path: missing_image_path += 1 elif not Path(image_path).exists(): missing_image_file += 1 ocr_count = len(row.get("ocr_items") or []) ui_count = len(row.get("ui_items") or []) ocr_counts.append(ocr_count) ui_counts.append(ui_count) if ocr_count <= 0: missing_ocr_items += 1 if ui_count <= 0: missing_ui_items += 1 row_count = len(rows) def mean(values: List[int]) -> float: return float(np.mean(values)) if values else 0.0 return { "rows": row_count, "missing_target": 
missing_target, "short_summary": short_summary, "missing_image_path": missing_image_path, "missing_image_file": missing_image_file, "missing_ocr_items": missing_ocr_items, "missing_ui_items": missing_ui_items, "missing_ocr_rate": missing_ocr_items / max(1, row_count), "missing_ui_rate": missing_ui_items / max(1, row_count), "ocr_items_mean": mean(ocr_counts), "ocr_items_max": int(max(ocr_counts)) if ocr_counts else 0, "ui_items_mean": mean(ui_counts), "ui_items_max": int(max(ui_counts)) if ui_counts else 0, "summary_chars_mean": mean(target_chars), "summary_chars_max": int(max(target_chars)) if target_chars else 0, } def validate_dataset_for_training(split_name: str, diagnostics: Dict[str, Any], args: argparse.Namespace) -> None: if not bool(getattr(args, "strict_data_checks", True)): return errors: List[str] = [] if diagnostics["missing_target"]: errors.append(f"{diagnostics['missing_target']} rows are missing target") if diagnostics["short_summary"]: errors.append(f"{diagnostics['short_summary']} rows have summary_zh shorter than 8 chars") if diagnostics["missing_image_path"]: errors.append(f"{diagnostics['missing_image_path']} rows are missing image_path") vision_enabled = not bool(getattr(args, "disable_vision", False)) and str(getattr(args, "model_variant", "")) != "annotation_only" if vision_enabled and diagnostics["missing_image_file"]: errors.append(f"{diagnostics['missing_image_file']} image files do not exist") needs_ui_context = bool(getattr(args, "direct_element_tokens", False)) or int(getattr(args, "max_elements", 0) or 0) > 0 needs_screen_text = bool(getattr(args, "context_include_screen_text", False)) if needs_ui_context and diagnostics["missing_ui_rate"] > 0.05: errors.append(f"{diagnostics['missing_ui_items']} rows are missing ui_items") if needs_screen_text and diagnostics["missing_ocr_rate"] > 0.05: errors.append(f"{diagnostics['missing_ocr_items']} rows are missing ocr_items") if errors: raise ValueError(f"{split_name} dataset failed strict data 
checks: " + "; ".join(errors)) def token_length_summary(values: List[int], max_length: int) -> Dict[str, Any]: if not values: return { "count": 0, "mean": 0.0, "p50": 0, "p90": 0, "p95": 0, "p99": 0, "max": 0, "over_max": 0, "over_max_rate": 0.0, "at_or_over_max": 0, "at_or_over_max_rate": 0.0, } ordered = sorted(values) def pct(percent: float) -> int: index = int(round((len(ordered) - 1) * percent / 100.0)) return int(ordered[max(0, min(len(ordered) - 1, index))]) over_max = sum(value > max_length for value in values) at_or_over_max = sum(value >= max_length for value in values) return { "count": len(values), "mean": float(np.mean(values)), "p50": pct(50), "p90": pct(90), "p95": pct(95), "p99": pct(99), "max": int(max(values)), "configured_max": int(max_length), "over_max": int(over_max), "over_max_rate": over_max / max(1, len(values)), "at_or_over_max": int(at_or_over_max), "at_or_over_max_rate": at_or_over_max / max(1, len(values)), } def tokenizer_diagnostics(rows: List[Dict[str, Any]], tokenizer, args: argparse.Namespace) -> Dict[str, Any]: target_lengths: List[int] = [] context_lengths: List[int] = [] target_schema = getattr(args, "target_schema", "zh") max_target_tokens = int(getattr(args, "max_target_tokens", 384) or 384) max_context_tokens = int(getattr(args, "max_context_tokens", 64) or 64) for row in rows: target = row.get("target") or {} if bool(getattr(args, "canonicalize_targets", False)): target = canonicalize_target_with_context( row, target, drop_bare_search_functions=bool(getattr(args, "drop_bare_search_functions", False)), ) target_text = target_to_text(target, target_schema) context_text = build_context_text(row, args) target_lengths.append(len(tokenizer(target_text, add_special_tokens=True).input_ids)) context_lengths.append(len(tokenizer(context_text, add_special_tokens=True).input_ids)) return { "target_tokens": token_length_summary(target_lengths, max_target_tokens), "context_tokens": token_length_summary(context_lengths, 
max_context_tokens), } def validate_token_lengths(split_name: str, diagnostics: Dict[str, Any], args: argparse.Namespace) -> None: if not bool(getattr(args, "strict_data_checks", True)): return max_allowed_rate = float(getattr(args, "max_target_truncation_rate", 0.01) or 0.0) target_stats = diagnostics.get("target_tokens") or {} over_rate = float(target_stats.get("over_max_rate", 0.0) or 0.0) if over_rate > max_allowed_rate: raise ValueError( f"{split_name} target token truncation rate {over_rate:.2%} exceeds " f"max_target_truncation_rate={max_allowed_rate:.2%}; " f"max_target_tokens={target_stats.get('configured_max')}, " f"p95={target_stats.get('p95')}, p99={target_stats.get('p99')}, max={target_stats.get('max')}" ) if int(getattr(args, "eval_max_new_tokens", 0) or 0) < int(getattr(args, "max_target_tokens", 0) or 0): raise ValueError("eval_max_new_tokens must be >= max_target_tokens for natural_zh validation.") class RichCollator: def __init__(self, tokenizer, image_processor, args: argparse.Namespace, is_training: bool = False): self.tokenizer = tokenizer self.image_processor = image_processor self.args = args self.is_training = is_training self.use_vision = not bool(getattr(args, "disable_vision", False)) and str(getattr(args, "model_variant", "")) != "annotation_only" def _load_images(self, image_path: str) -> List[Image.Image]: path = Path(image_path) try: image = Image.open(path).convert("RGB") except Exception: image = Image.new("RGB", (self.args.image_size, self.args.image_size), color=(245, 245, 245)) images = [image] if self.args.num_vertical_crops > 0: width, height = image.size crops = self.args.num_vertical_crops for i in range(crops): top = int(height * i / crops) bottom = int(height * (i + 1) / crops) images.append(image.crop((0, top, width, bottom))) return images def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]: bsz = len(batch) if self.use_vision: flat_images: List[Image.Image] = [] crop_counts: List[int] = [] for row in batch: 
images = self._load_images(row.get("image_path", "")) crop_counts.append(len(images)) flat_images.extend(images) image_kwargs = { "return_tensors": "pt", "do_resize": True, "size": {"height": self.args.image_size, "width": self.args.image_size}, } pixel_values = self.image_processor(images=flat_images, **image_kwargs)["pixel_values"] crops = crop_counts[0] pixel_values = pixel_values.view(bsz, crops, *pixel_values.shape[1:]) else: pixel_values = torch.empty(bsz, 0, 3, self.args.image_size, self.args.image_size, dtype=torch.float32) all_elements: List[List[Dict[str, Any]]] = [] flat_element_texts: List[str] = [] element_mask = torch.zeros(bsz, self.args.max_elements, dtype=torch.bool) bboxes = torch.zeros(bsz, self.args.max_elements, 4, dtype=torch.float32) type_ids = torch.zeros(bsz, self.args.max_elements, dtype=torch.long) source_ids = torch.zeros(bsz, self.args.max_elements, dtype=torch.long) loc_ids = torch.zeros(bsz, self.args.max_elements, dtype=torch.long) confs = torch.zeros(bsz, self.args.max_elements, 1, dtype=torch.float32) action_flags = torch.zeros(bsz, self.args.max_elements, 1, dtype=torch.float32) evidence_labels = torch.zeros(bsz, self.args.max_elements, dtype=torch.float32) numeric_labels = torch.zeros(bsz, self.args.max_elements, dtype=torch.float32) ui_function_labels = torch.zeros(bsz, self.args.max_elements, dtype=torch.float32) search_function_labels = torch.zeros(bsz, self.args.max_elements, dtype=torch.float32) type_vocab: Dict[str, int] = {"visual": 0, "text": 1, "button": 2, "label": 3, "text_number": 4} source_vocab: Dict[str, int] = {"unknown": 0, "cmgui": 1, "ocr": 2, "rule": 3} loc_vocab: Dict[str, int] = { "unknown": 0, "top-left": 1, "top-center": 2, "top-right": 3, "middle-left": 4, "middle-center": 5, "middle-right": 6, "bottom-left": 7, "bottom-center": 8, "bottom-right": 9, } target_texts: List[str] = [] context_texts: List[str] = [] section_labels = torch.zeros(bsz, len(SECTION_NAMES), dtype=torch.float32) for row_idx, row in 
enumerate(batch): target = row.get("target") or {} if bool(getattr(self.args, "canonicalize_targets", False)): target = canonicalize_target_with_context( row, target, drop_bare_search_functions=bool(getattr(self.args, "drop_bare_search_functions", False)), ) target_texts.append(target_to_text(target, getattr(self.args, "target_schema", "zh"))) screen_text_dropout_rate = ( float(getattr(self.args, "context_screen_text_dropout_rate", 0.0) or 0.0) if self.is_training else 0.0 ) context_texts.append(build_context_text(row, self.args, screen_text_dropout_rate=screen_text_dropout_rate)) evidence_ids = set(target.get("key_ui_clues", []) or row.get("weak_evidence_ids", [])) function_evidence = set(extract_function_evidence_ids(target)) search_function_evidence = set(extract_named_function_evidence_ids(target, "搜索")) interaction_evidence = set() for interaction in target.get("interaction_data", []) or []: interaction_evidence.update(interaction.get("evidence_ids", []) or []) for sec_idx, name in enumerate(SECTION_NAMES): if target.get(name): section_labels[row_idx, sec_idx] = 1.0 if bool(getattr(self.args, "task_intent_context", False)): task_type = "检索" if row_has_search_task(row) else "常规" context_texts[-1] = f"{context_texts[-1]} 任务类别:{task_type}" elements = (row.get("ui_items") or [])[: self.args.max_elements] all_elements.append(elements) for elem_idx in range(self.args.max_elements): if elem_idx < len(elements): elem = elements[elem_idx] text = safe_text(elem.get("text")) flat_element_texts.append(f"{elem.get('type','text')} {elem.get('source','')} {text}") element_mask[row_idx, elem_idx] = True bbox = elem.get("bbox") or [0, 0, 1, 1] if len(bbox) == 4: bboxes[row_idx, elem_idx] = torch.tensor(bbox, dtype=torch.float32) type_ids[row_idx, elem_idx] = type_vocab.get(safe_text(elem.get("type")), 1) source_ids[row_idx, elem_idx] = source_vocab.get(safe_text(elem.get("source")), 0) loc_ids[row_idx, elem_idx] = loc_vocab.get(safe_text(elem.get("location")), 0) 
confs[row_idx, elem_idx, 0] = float(elem.get("ocr_conf", elem.get("conf", 1.0)) or 1.0) action_flags[row_idx, elem_idx, 0] = 1.0 if elem.get("is_action_target") else 0.0 if elem.get("id") in evidence_ids or elem.get("ocr_id") in evidence_ids: evidence_labels[row_idx, elem_idx] = 1.0 if elem.get("id") in function_evidence or elem.get("ocr_id") in function_evidence: ui_function_labels[row_idx, elem_idx] = 1.0 if elem.get("id") in search_function_evidence or elem.get("ocr_id") in search_function_evidence: search_function_labels[row_idx, elem_idx] = 1.0 if elem.get("id") in interaction_evidence or re.search(r"\d", text): numeric_labels[row_idx, elem_idx] = 1.0 else: flat_element_texts.append("") elem_tokens = self.tokenizer( flat_element_texts, padding=True, truncation=True, max_length=self.args.max_element_tokens, return_tensors="pt", ) elem_input_ids = elem_tokens.input_ids.view(bsz, self.args.max_elements, -1) elem_attention_mask = elem_tokens.attention_mask.view(bsz, self.args.max_elements, -1) context_tokens = self.tokenizer( context_texts, padding=True, truncation=True, max_length=self.args.max_context_tokens, return_tensors="pt", ) target_tokens = self.tokenizer( target_texts, padding=True, truncation=True, max_length=self.args.max_target_tokens, return_tensors="pt", ) labels = target_tokens.input_ids labels[labels == self.tokenizer.pad_token_id] = -100 return { "rows": batch, "pixel_values": pixel_values, "element_input_ids": elem_input_ids, "element_attention_mask": elem_attention_mask, "element_mask": element_mask, "bboxes": bboxes, "type_ids": type_ids, "source_ids": source_ids, "loc_ids": loc_ids, "confs": confs, "action_flags": action_flags, "context_input_ids": context_tokens.input_ids, "context_attention_mask": context_tokens.attention_mask, "labels": labels, "evidence_labels": evidence_labels, "ui_function_labels": ui_function_labels, "search_function_labels": search_function_labels, "section_labels": section_labels, "numeric_labels": numeric_labels, } 
class BottleneckPooler(nn.Module): def __init__(self, hidden_size: int, num_queries: int = 64, num_heads: int = 8): super().__init__() self.queries = nn.Parameter(torch.randn(num_queries, hidden_size) * 0.02) self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True) self.norm = nn.LayerNorm(hidden_size) def forward(self, memory: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor: bsz = memory.size(0) queries = self.queries.unsqueeze(0).expand(bsz, -1, -1) pooled, _ = self.attn(queries, memory, memory, key_padding_mask=key_padding_mask) return self.norm(pooled + queries) class RichGroundedModel(nn.Module): def __init__(self, args: argparse.Namespace): super().__init__() self.args = args self.model_variant = args.model_variant self.use_vision = not bool(getattr(args, "disable_vision", False)) and self.model_variant != "annotation_only" self.vision = AutoModel.from_pretrained(args.vision_model) if self.use_vision else None self.decoder = AutoModelForSeq2SeqLM.from_pretrained(args.decoder_model) if bool(getattr(args, "decoder_gradient_checkpointing", False)) and hasattr(self.decoder, "gradient_checkpointing_enable"): self.decoder.gradient_checkpointing_enable() if hasattr(self.decoder.config, "use_cache"): self.decoder.config.use_cache = False if bool(getattr(args, "freeze_decoder", False)): for param in self.decoder.parameters(): param.requires_grad = False if self.vision is not None and bool(getattr(args, "vision_gradient_checkpointing", False)) and hasattr(self.vision, "gradient_checkpointing_enable"): self.vision.gradient_checkpointing_enable() self.hidden_size = self.decoder.config.d_model vision_hidden = self._vision_hidden_size() if self.vision is not None else self.hidden_size self.visual_proj = nn.Linear(vision_hidden, self.hidden_size) self.elem_text_proj = nn.Linear(self.hidden_size, self.hidden_size) self.type_emb = nn.Embedding(16, self.hidden_size) self.source_emb = nn.Embedding(8, self.hidden_size) 
self.loc_emb = nn.Embedding(16, self.hidden_size) self.bbox_proj = nn.Sequential(nn.Linear(6, self.hidden_size), nn.GELU(), nn.Linear(self.hidden_size, self.hidden_size)) self.roi_proj = nn.Linear(vision_hidden, self.hidden_size) self.context_proj = nn.Linear(self.hidden_size, self.hidden_size) enc_layer = nn.TransformerEncoderLayer( d_model=self.hidden_size, nhead=8, dim_feedforward=self.hidden_size * 4, dropout=0.1, batch_first=True, norm_first=False, ) self.layout_encoder = nn.TransformerEncoder(enc_layer, num_layers=2) fusion_layer = nn.TransformerEncoderLayer( d_model=self.hidden_size, nhead=8, dim_feedforward=self.hidden_size * 4, dropout=0.1, batch_first=True, norm_first=False, ) self.fusion = nn.TransformerEncoder(fusion_layer, num_layers=2) self.pooler = BottleneckPooler(self.hidden_size, num_queries=args.bottleneck_queries, num_heads=8) self.evidence_head = nn.Linear(self.hidden_size, 1) self.ui_function_head = nn.Linear(self.hidden_size, 1) self.search_function_head = nn.Linear(self.hidden_size, 1) self.function_signal_proj = nn.Linear(1, self.hidden_size) self.search_signal_proj = nn.Linear(1, self.hidden_size) nn.init.zeros_(self.function_signal_proj.weight) nn.init.zeros_(self.function_signal_proj.bias) nn.init.zeros_(self.search_signal_proj.weight) nn.init.zeros_(self.search_signal_proj.bias) self.numeric_head = nn.Linear(self.hidden_size, 1) self.section_head = nn.Linear(self.hidden_size, len(SECTION_NAMES)) if self.vision is not None and args.freeze_vision: self.freeze_vision(args.unfreeze_vision_last_ratio) def _vision_hidden_size(self) -> int: if self.vision is None: return self.hidden_size config = self.vision.config if hasattr(config, "vision_config"): return int(config.vision_config.hidden_size) return int(getattr(config, "hidden_size")) def freeze_vision(self, unfreeze_last_ratio: float) -> None: if self.vision is None: return for param in self.vision.parameters(): param.requires_grad = False if unfreeze_last_ratio <= 0: return layers = None 
vision_model = getattr(self.vision, "vision_model", self.vision) encoder = getattr(vision_model, "encoder", None) if encoder is not None: layers = getattr(encoder, "layers", None) or getattr(encoder, "layer", None) if layers: keep = max(1, int(len(layers) * unfreeze_last_ratio)) for layer in layers[-keep:]: for param in layer.parameters(): param.requires_grad = True def use_native_context_forward(self) -> bool: return bool(getattr(self.args, "native_context_forward", False)) def mean_embed(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: emb = self.decoder.get_input_embeddings()(input_ids) mask = attention_mask.unsqueeze(-1).float() return (emb * mask).sum(dim=-2) / mask.sum(dim=-2).clamp_min(1.0) def encode_vision(self, pixel_values: torch.Tensor) -> torch.Tensor: bsz, crops, channels, height, width = pixel_values.shape flat = pixel_values.view(bsz * crops, channels, height, width) vision_model = getattr(self.vision, "vision_model", self.vision) vision_trainable = any(param.requires_grad for param in self.vision.parameters()) with torch.set_grad_enabled(vision_trainable): try: out = vision_model(pixel_values=flat, interpolate_pos_encoding=True) except TypeError: out = vision_model(pixel_values=flat) tokens = out.last_hidden_state tokens = tokens.view(bsz, crops, tokens.size(1), tokens.size(2)) return tokens def reduce_visual_tokens(self, visual: torch.Tensor) -> torch.Tensor: bsz, crops, num_tokens, hidden = visual.shape max_visual_tokens = int(getattr(self.args, "max_visual_tokens", 0) or 0) if max_visual_tokens <= 0 or crops * num_tokens <= max_visual_tokens: return self.visual_proj(visual.flatten(1, 2)) per_crop = max(1, max_visual_tokens // max(1, crops)) reduced_crops: List[torch.Tensor] = [] for crop_idx in range(crops): crop_tokens = visual[:, crop_idx, :, :] grid = int(math.sqrt(num_tokens)) cls_token = None patch_tokens = crop_tokens if grid * grid != num_tokens: maybe_grid = int(math.sqrt(num_tokens - 1)) if maybe_grid * 
maybe_grid == num_tokens - 1: cls_token = crop_tokens[:, :1, :] patch_tokens = crop_tokens[:, 1:, :] grid = maybe_grid else: target = min(per_crop, num_tokens) indices = torch.linspace(0, num_tokens - 1, steps=target, device=visual.device).round().long() reduced_crops.append(crop_tokens.index_select(1, indices)) continue cls_budget = 1 if cls_token is not None and per_crop > 1 else 0 patch_budget = max(1, per_crop - cls_budget) out_grid = max(1, int(math.sqrt(patch_budget))) patch_tokens = patch_tokens.view(bsz, grid, grid, hidden).permute(0, 3, 1, 2) pooled = F.adaptive_avg_pool2d(patch_tokens, (out_grid, out_grid)) pooled = pooled.permute(0, 2, 3, 1).reshape(bsz, out_grid * out_grid, hidden) if cls_budget: pooled = torch.cat([cls_token, pooled], dim=1) reduced_crops.append(pooled) return self.visual_proj(torch.cat(reduced_crops, dim=1)) def roi_pool(self, full_tokens: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor: bsz, num_tokens, hidden = full_tokens.shape grid = int(math.sqrt(num_tokens)) if grid * grid != num_tokens: patch_tokens = full_tokens grid = int(math.sqrt(num_tokens - 1)) if grid * grid == num_tokens - 1: patch_tokens = full_tokens[:, 1:, :] else: pooled = full_tokens.mean(dim=1, keepdim=True).expand(-1, bboxes.size(1), -1) return pooled else: patch_tokens = full_tokens patch_tokens = patch_tokens.view(bsz, grid, grid, hidden) outputs = [] for b in range(bsz): elem_vecs = [] for bbox in bboxes[b]: x1, y1, x2, y2 = bbox.tolist() ix1 = max(0, min(grid - 1, int(x1 * grid))) iy1 = max(0, min(grid - 1, int(y1 * grid))) ix2 = max(ix1 + 1, min(grid, int(math.ceil(x2 * grid)))) iy2 = max(iy1 + 1, min(grid, int(math.ceil(y2 * grid)))) region = patch_tokens[b, iy1:iy2, ix1:ix2, :] elem_vecs.append(region.reshape(-1, hidden).mean(dim=0)) outputs.append(torch.stack(elem_vecs, dim=0)) return torch.stack(outputs, dim=0) def build_memory(self, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: 
use_activation_checkpointing = self.training and bool(getattr(self.args, "activation_checkpointing", False)) bsz = batch["element_mask"].size(0) if self.vision is not None: pixel_values = batch["pixel_values"] visual = self.encode_vision(pixel_values) full_visual = visual[:, 0, :, :] visual_tokens = self.reduce_visual_tokens(visual) else: full_visual = None visual_tokens = batch["element_mask"].new_zeros((bsz, 0, self.hidden_size), dtype=self.decoder.get_input_embeddings().weight.dtype) elem_ids = batch["element_input_ids"] elem_mask_tok = batch["element_attention_mask"] bsz, max_elements, elem_len = elem_ids.shape elem_text = self.mean_embed(elem_ids.view(bsz * max_elements, elem_len), elem_mask_tok.view(bsz * max_elements, elem_len)) elem_text = self.elem_text_proj(elem_text.view(bsz, max_elements, -1)) bbox = batch["bboxes"] wh = (bbox[..., 2:] - bbox[..., :2]).clamp_min(0) bbox_feat = torch.cat([bbox, wh], dim=-1) elem_tokens = ( elem_text + self.type_emb(batch["type_ids"]) + self.source_emb(batch["source_ids"]) + self.loc_emb(batch["loc_ids"]) + self.bbox_proj(bbox_feat) + batch["confs"] * 0.1 + batch["action_flags"] * 0.1 ) if self.vision is not None and self.model_variant in {"full", "late_fusion"}: roi = self.roi_proj(self.roi_pool(full_visual, bbox)) elem_tokens = elem_tokens + roi elem_key_padding = ~batch["element_mask"] if use_activation_checkpointing: elem_tokens = checkpoint( lambda tokens, padding: self.layout_encoder(tokens, src_key_padding_mask=padding), elem_tokens, elem_key_padding, use_reentrant=False, ) else: elem_tokens = self.layout_encoder(elem_tokens, src_key_padding_mask=elem_key_padding) head_elem_tokens = elem_tokens decoder_elem_tokens = elem_tokens if bool(getattr(self.args, "function_signal_to_decoder", False)): function_signal = torch.sigmoid(self.ui_function_head(head_elem_tokens)).detach() function_signal = function_signal.masked_fill(elem_key_padding.unsqueeze(-1), 0.0) signal_scale = float(getattr(self.args, 
"function_signal_scale", 1.0) or 1.0) decoder_elem_tokens = decoder_elem_tokens + signal_scale * self.function_signal_proj(function_signal) if bool(getattr(self.args, "search_signal_to_decoder", False)): search_signal = torch.sigmoid(self.search_function_head(head_elem_tokens)).detach() search_signal = search_signal.masked_fill(elem_key_padding.unsqueeze(-1), 0.0) search_signal_scale = float(getattr(self.args, "search_signal_scale", 1.0) or 1.0) decoder_elem_tokens = decoder_elem_tokens + search_signal_scale * self.search_signal_proj(search_signal) context_mode = str(getattr(self.args, "context_mode", "mean") or "mean").lower() direct_context_tokens = None if context_mode in {"tokens_encoder", "tokens_direct_encoder"}: context_encoded = self.decoder.get_encoder()(input_ids=batch["context_input_ids"], attention_mask=batch["context_attention_mask"], return_dict=True) direct_context_tokens = context_encoded.last_hidden_state context_tokens = self.context_proj(context_encoded.last_hidden_state) context_padding = ~batch["context_attention_mask"].bool() elif context_mode in {"tokens", "tokens_direct"}: context_emb = self.decoder.get_input_embeddings()(batch["context_input_ids"]) direct_context_tokens = context_emb context_tokens = self.context_proj(context_emb) context_padding = ~batch["context_attention_mask"].bool() else: context = self.mean_embed(batch["context_input_ids"], batch["context_attention_mask"]) context_tokens = self.context_proj(context).unsqueeze(1) context_padding = torch.zeros(bsz, 1, dtype=torch.bool, device=context_tokens.device) context_len = context_tokens.size(1) visual_start = -1 visual_len = 0 elem_start = -1 elem_len_for_fusion = 0 if self.model_variant == "annotation_only": elem_start = context_len elem_len_for_fusion = decoder_elem_tokens.size(1) fusion_input = torch.cat([context_tokens, decoder_elem_tokens], dim=1) fusion_padding = torch.cat([context_padding, elem_key_padding], dim=1) elif self.model_variant == "image_only": if self.vision is 
None: raise ValueError("image_only requires vision; do not set disable_vision=true") visual_start = 0 visual_len = visual_tokens.size(1) fusion_input = visual_tokens fusion_padding = torch.zeros(bsz, fusion_input.size(1), dtype=torch.bool, device=fusion_input.device) elif self.model_variant == "late_fusion": visual_summary = visual_tokens.mean(dim=1, keepdim=True) visual_start = context_len visual_len = 1 elem_start = context_len + 1 elem_len_for_fusion = decoder_elem_tokens.size(1) fusion_input = torch.cat([context_tokens, visual_summary, decoder_elem_tokens], dim=1) fusion_padding = torch.cat( [context_padding, torch.zeros(bsz, 1, dtype=torch.bool, device=elem_key_padding.device), elem_key_padding], dim=1, ) else: visual_start = context_len visual_len = visual_tokens.size(1) elem_start = context_len + visual_len elem_len_for_fusion = decoder_elem_tokens.size(1) fusion_input = torch.cat([context_tokens, visual_tokens, decoder_elem_tokens], dim=1) fusion_padding = torch.cat( [ context_padding, torch.zeros(bsz, visual_tokens.size(1), dtype=torch.bool, device=elem_key_padding.device), elem_key_padding, ], dim=1, ) if use_activation_checkpointing: fused = checkpoint( lambda tokens, padding: self.fusion(tokens, src_key_padding_mask=padding), fusion_input, fusion_padding, use_reentrant=False, ) else: fused = self.fusion(fusion_input, src_key_padding_mask=fusion_padding) if elem_start >= 0 and elem_len_for_fusion > 0: head_elem_tokens = fused[:, elem_start : elem_start + elem_len_for_fusion, :] if use_activation_checkpointing: pooled = checkpoint( lambda tokens, padding: self.pooler(tokens, key_padding_mask=padding), fused, fusion_padding, use_reentrant=False, ) else: pooled = self.pooler(fused, key_padding_mask=fusion_padding) pooled_padding = torch.zeros(bsz, pooled.size(1), dtype=torch.bool, device=pooled.device) memory_parts = [] memory_padding_parts = [] if context_mode in {"tokens_direct", "tokens_direct_encoder"} and self.model_variant != "image_only": if 
bool(getattr(self.args, "direct_context_passthrough", False)) and direct_context_tokens is not None: memory_parts.append(direct_context_tokens) else: memory_parts.append(fused[:, :context_len, :]) memory_padding_parts.append(context_padding) if bool(getattr(self.args, "direct_visual_tokens", False)) and visual_start >= 0 and visual_len > 0: visual_scale = float(getattr(self.args, "visual_memory_scale", 1.0) or 1.0) memory_parts.append(fused[:, visual_start : visual_start + visual_len, :] * visual_scale) memory_padding_parts.append(torch.zeros(bsz, visual_len, dtype=torch.bool, device=pooled.device)) if bool(getattr(self.args, "direct_element_tokens", False)) and elem_start >= 0 and elem_len_for_fusion > 0: element_scale = float(getattr(self.args, "element_memory_scale", 1.0) or 1.0) memory_parts.append(fused[:, elem_start : elem_start + elem_len_for_fusion, :] * element_scale) memory_padding_parts.append(elem_key_padding) if bool(getattr(self.args, "include_pooled_memory", True)): pooled_scale = float(getattr(self.args, "pooled_memory_scale", 1.0) or 1.0) memory_parts.append(pooled * pooled_scale) memory_padding_parts.append(pooled_padding) if not memory_parts: pooled_scale = float(getattr(self.args, "pooled_memory_scale", 1.0) or 1.0) memory_parts.append(pooled * pooled_scale) memory_padding_parts.append(pooled_padding) memory = torch.cat(memory_parts, dim=1) memory_scale = float(getattr(self.args, "decoder_memory_scale", 1.0) or 1.0) if memory_scale != 1.0: memory = memory * memory_scale memory_padding = torch.cat(memory_padding_parts, dim=1) memory_attention_mask = (~memory_padding).long() return memory, memory_attention_mask, head_elem_tokens, elem_key_padding def forward(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]: if self.use_native_context_forward(): labels = batch["labels"] decoder_out = self.decoder( input_ids=batch["context_input_ids"], attention_mask=batch["context_attention_mask"], labels=labels, use_cache=False, ) gen_loss = 
decoder_out.loss element_shape = batch["element_mask"].shape empty_logits = labels.new_full(element_shape, -20.0, dtype=torch.float32) section_logits = labels.new_zeros((labels.size(0), int(batch["section_labels"].shape[-1])), dtype=torch.float32) zero_loss = gen_loss.detach().new_zeros(()) return { "loss": gen_loss, "generation_loss": gen_loss.detach(), "evidence_loss": zero_loss, "ui_function_loss": zero_loss, "search_function_loss": zero_loss, "section_loss": zero_loss, "numeric_loss": zero_loss, "evidence_logits": empty_logits, "ui_function_logits": empty_logits, "search_function_logits": empty_logits, "section_logits": section_logits, } memory, memory_attention_mask, elem_tokens, elem_key_padding = self.build_memory(batch) encoder_outputs = BaseModelOutput(last_hidden_state=memory) labels = batch["labels"] if hasattr(self.decoder, "prepare_decoder_input_ids_from_labels"): decoder_input_ids = self.decoder.prepare_decoder_input_ids_from_labels(labels=labels) elif hasattr(self.decoder, "_shift_right"): decoder_input_ids = self.decoder._shift_right(labels) else: pad_token_id = int(getattr(self.decoder.config, "pad_token_id", 0) or 0) start_token_id = int(getattr(self.decoder.config, "decoder_start_token_id", pad_token_id) or pad_token_id) decoder_input_ids = labels.new_full(labels.shape, pad_token_id) decoder_input_ids[:, 0] = start_token_id decoder_input_ids[:, 1:] = labels[:, :-1].masked_fill(labels[:, :-1] == -100, pad_token_id) decoder_out = self.decoder( encoder_outputs=encoder_outputs, attention_mask=memory_attention_mask, decoder_input_ids=decoder_input_ids, use_cache=False, ) logits = decoder_out.logits flat_logits = logits.reshape(-1, logits.size(-1)) flat_labels = labels.reshape(-1) valid_token_count = flat_labels.ne(-100).sum().clamp_min(1) chunk_size = max(1, int(getattr(self.args, "generation_loss_chunk_size", 32) or 32)) gen_loss_sum = logits.new_zeros((), dtype=torch.float32) for start in range(0, flat_logits.size(0), chunk_size): end = min(start + 
chunk_size, flat_logits.size(0)) gen_loss_sum = gen_loss_sum + F.cross_entropy( flat_logits[start:end].float(), flat_labels[start:end], ignore_index=-100, reduction="sum", ) gen_loss = gen_loss_sum / valid_token_count evidence_logits = self.evidence_head(elem_tokens).squeeze(-1) ui_function_logits = self.ui_function_head(elem_tokens).squeeze(-1) search_function_logits = self.search_function_head(elem_tokens).squeeze(-1) numeric_logits = self.numeric_head(elem_tokens).squeeze(-1) valid = (~elem_key_padding).float() evidence_loss = F.binary_cross_entropy_with_logits( evidence_logits, batch["evidence_labels"], reduction="none" ) evidence_loss = (evidence_loss * valid).sum() / valid.sum().clamp_min(1.0) ui_function_loss = F.binary_cross_entropy_with_logits( ui_function_logits, batch["ui_function_labels"], reduction="none" ) ui_function_loss = (ui_function_loss * valid).sum() / valid.sum().clamp_min(1.0) search_pos_weight = float(getattr(self.args, "search_function_pos_weight", 1.0) or 1.0) search_loss_kwargs: Dict[str, Any] = {"reduction": "none"} if search_pos_weight != 1.0: search_loss_kwargs["pos_weight"] = torch.tensor(search_pos_weight, device=search_function_logits.device) search_function_loss = F.binary_cross_entropy_with_logits( search_function_logits, batch["search_function_labels"], **search_loss_kwargs ) search_function_loss = (search_function_loss * valid).sum() / valid.sum().clamp_min(1.0) numeric_loss = F.binary_cross_entropy_with_logits( numeric_logits, batch["numeric_labels"], reduction="none" ) numeric_loss = (numeric_loss * valid).sum() / valid.sum().clamp_min(1.0) memory_mask = memory_attention_mask.unsqueeze(-1).float() memory_summary = (memory * memory_mask).sum(dim=1) / memory_mask.sum(dim=1).clamp_min(1.0) section_logits = self.section_head(memory_summary) section_loss = F.binary_cross_entropy_with_logits(section_logits, batch["section_labels"]) total = ( gen_loss + self.args.evidence_loss_weight * evidence_loss + float(getattr(self.args, 
"ui_function_loss_weight", 0.0) or 0.0) * ui_function_loss + float(getattr(self.args, "search_function_loss_weight", 0.0) or 0.0) * search_function_loss + self.args.section_loss_weight * section_loss + self.args.numeric_loss_weight * numeric_loss ) return { "loss": total, "generation_loss": gen_loss.detach(), "evidence_loss": evidence_loss.detach(), "ui_function_loss": ui_function_loss.detach(), "search_function_loss": search_function_loss.detach(), "section_loss": section_loss.detach(), "numeric_loss": numeric_loss.detach(), "evidence_logits": evidence_logits.detach(), "ui_function_logits": ui_function_logits.detach(), "search_function_logits": search_function_logits.detach(), "section_logits": section_logits.detach(), } @torch.no_grad() def generate_text(self, batch: Dict[str, torch.Tensor], tokenizer, num_beams: int = 4, max_new_tokens: int = 384) -> List[str]: # 评估时也按训练 amp_dtype 走 autocast,避免 mt5-large 在 valid 阶段以 fp32 # 跑生成时显存峰值远高于训练而 OOM,同时也明显加快评估速度。 amp_dtype_name = str(getattr(self.args, "amp_dtype", "auto") or "auto").lower() if amp_dtype_name == "auto": amp_dtype_name = "fp16" if bool(getattr(self.args, "fp16", False)) else "fp32" device_is_cuda = batch["pixel_values"].is_cuda if torch.is_tensor(batch.get("pixel_values")) else False amp_enabled = device_is_cuda and amp_dtype_name in {"fp16", "bf16"} amp_dtype = torch.float16 if amp_dtype_name == "fp16" else torch.bfloat16 with torch.amp.autocast("cuda", enabled=amp_enabled, dtype=amp_dtype): generation_kwargs = build_generation_kwargs(self.args, tokenizer) if self.use_native_context_forward(): generated = self.decoder.generate( input_ids=batch["context_input_ids"], attention_mask=batch["context_attention_mask"], num_beams=num_beams, max_new_tokens=max_new_tokens, **generation_kwargs, ) else: memory, memory_attention_mask, _, _ = self.build_memory(batch) encoder_outputs = BaseModelOutput(last_hidden_state=memory) generated = self.decoder.generate( encoder_outputs=encoder_outputs, 
attention_mask=memory_attention_mask, num_beams=num_beams, max_new_tokens=max_new_tokens, **generation_kwargs, ) return tokenizer.batch_decode(generated, skip_special_tokens=True) def move_batch(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]: out = {} for key, value in batch.items(): if key == "rows": out[key] = value elif torch.is_tensor(value): out[key] = value.to(device, non_blocking=True) else: out[key] = value return out def trainable_parameter_stats(model: nn.Module) -> Dict[str, Any]: stats = { "trainable_params_total": 0, "trainable_params_vision": 0, "trainable_params_decoder": 0, "trainable_params_other": 0, } for name, param in model.named_parameters(): if not param.requires_grad: continue count = int(param.numel()) stats["trainable_params_total"] += count if name.startswith("vision."): stats["trainable_params_vision"] += count elif name.startswith("decoder."): stats["trainable_params_decoder"] += count else: stats["trainable_params_other"] += count stats["vision_trainable"] = stats["trainable_params_vision"] > 0 return stats def reduce_parallel_losses(output: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]: for key in [ "loss", "generation_loss", "evidence_loss", "ui_function_loss", "search_function_loss", "section_loss", "numeric_loss", ]: value = output.get(key) if torch.is_tensor(value) and value.ndim > 0: output[key] = value.mean() return output def build_optimizer(model: nn.Module, args: argparse.Namespace) -> torch.optim.Optimizer: decoder_params = [] vision_params = [] ui_function_params = [] other_params = [] lr_ui_function_head = float(getattr(args, "lr_ui_function_head", 0.0) or 0.0) for name, param in model.named_parameters(): if not param.requires_grad: continue if lr_ui_function_head > 0 and ( name.startswith("ui_function_head.") or name.startswith("search_function_head.") or name.startswith("function_signal_proj.") or name.startswith("search_signal_proj.") ): ui_function_params.append(param) elif 
name.startswith("decoder."): decoder_params.append(param) elif name.startswith("vision."): vision_params.append(param) else: other_params.append(param) groups = [ {"params": other_params, "lr": args.lr_new}, {"params": vision_params, "lr": args.lr_fusion}, {"params": decoder_params, "lr": args.lr_decoder}, ] if ui_function_params: groups.append({"params": ui_function_params, "lr": lr_ui_function_head}) optimizer_name = str(getattr(args, "optimizer_name", "adamw") or "adamw").lower() if optimizer_name == "adamw": return torch.optim.AdamW(groups, weight_decay=args.weight_decay) if optimizer_name == "adafactor": return Adafactor( groups, scale_parameter=False, relative_step=False, warmup_init=False, weight_decay=args.weight_decay, ) raise ValueError("optimizer_name must be one of: adamw, adafactor") def clip_gradients(model: nn.Module, optimizer: torch.optim.Optimizer, args: argparse.Namespace) -> torch.Tensor: max_grad_norm = float(getattr(args, "max_grad_norm", 1.0) or 0.0) first_param = next(model.parameters()) if max_grad_norm <= 0: return first_param.detach().new_tensor(0.0) strategy = str(getattr(args, "grad_clip_strategy", "global") or "global").lower() if strategy == "per_group": norms: List[torch.Tensor] = [] for group in optimizer.param_groups: params = [param for param in group["params"] if param.grad is not None] if not params: continue group_norm = torch.nn.utils.clip_grad_norm_(params, max_grad_norm, foreach=False) norms.append(group_norm.detach().to(device=first_param.device, dtype=torch.float32)) if not norms: return first_param.detach().new_tensor(0.0) return torch.stack(norms).max() if strategy != "global": raise ValueError("grad_clip_strategy must be one of: global, per_group") return torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm, foreach=False) def get_lr_schedule(optimizer, total_steps: int, warmup_ratio: float, scheduler_type: str = "linear"): warmup = int(total_steps * warmup_ratio) scheduler_type = str(scheduler_type or 
"linear").lower() def lr_lambda(step: int) -> float: if step < warmup: return max(1e-8, step / max(1, warmup)) progress = (step - warmup) / max(1, total_steps - warmup) if scheduler_type == "cosine": return max(0.0, 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))) return max(0.0, 1.0 - progress) return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda) def save_checkpoint( output_dir: Path, name: str, model: nn.Module, tokenizer, image_processor, args: argparse.Namespace, metrics: Dict[str, Any], ) -> None: ckpt_dir = output_dir / name ckpt_dir.mkdir(parents=True, exist_ok=True) raw_model = unwrap_parallel_model(model) torch.save(raw_model.state_dict(), ckpt_dir / "pytorch_model.bin") write_json(ckpt_dir / "rich_config.json", vars(args)) write_json(ckpt_dir / "metrics.json", metrics) tokenizer_dir = ckpt_dir / "decoder_tokenizer" tokenizer.save_pretrained(tokenizer_dir) vocab_file = getattr(tokenizer, "vocab_file", None) if vocab_file and Path(vocab_file).exists() and not (tokenizer_dir / "spiece.model").exists(): shutil.copyfile(vocab_file, tokenizer_dir / "spiece.model") decoder_model_dir = Path(str(getattr(args, "decoder_model", "") or "")) if decoder_model_dir.is_dir(): for tokenizer_asset in ("spiece.model", "config.json"): source_asset = decoder_model_dir / tokenizer_asset target_asset = tokenizer_dir / tokenizer_asset if source_asset.exists() and not target_asset.exists(): shutil.copyfile(source_asset, target_asset) image_processor.save_pretrained(ckpt_dir / "image_processor") def resize_checkpoint_tensor(value: torch.Tensor, target: torch.Tensor) -> Optional[torch.Tensor]: if value.ndim != target.ndim or value.ndim == 0: return None resized = value.detach().cpu() for dim, target_size in enumerate(target.shape): current_size = resized.shape[dim] if current_size == target_size: continue if current_size <= 0 or target_size <= 0: return None if current_size > target_size: resized = resized.narrow(dim, 0, target_size) continue repeats = [1] * 
resized.ndim repeats[dim] = math.ceil(target_size / current_size) resized = resized.repeat(*repeats).narrow(dim, 0, target_size) return resized.to(dtype=target.dtype) def load_compatible_model_state( model: nn.Module, state: Dict[str, torch.Tensor], allow_missing_prefixes: Tuple[str, ...] = (), resize_mismatched_non_decoder: bool = False, ) -> Tuple[List[str], List[str], List[str], List[str]]: model_state = model.state_dict() compatible_state: Dict[str, torch.Tensor] = {} resized_incompatible: List[str] = [] skipped_incompatible: List[str] = [] skipped_unexpected: List[str] = [] for key, value in state.items(): if key not in model_state: skipped_unexpected.append(key) continue if tuple(model_state[key].shape) != tuple(value.shape): if resize_mismatched_non_decoder and not key.startswith("decoder."): resized_value = resize_checkpoint_tensor(value, model_state[key]) if resized_value is not None and tuple(resized_value.shape) == tuple(model_state[key].shape): compatible_state[key] = resized_value resized_incompatible.append(key) continue skipped_incompatible.append(key) continue compatible_state[key] = value load_result = model.load_state_dict(compatible_state, strict=False) allowed_missing = [ key for key in load_result.missing_keys if key.startswith("ui_function_head.") or key.startswith("search_function_head.") or key.startswith("function_signal_proj.") or key.startswith("search_signal_proj.") or key in skipped_incompatible or key.startswith(allow_missing_prefixes) ] bad_missing = [key for key in load_result.missing_keys if key not in allowed_missing] if bad_missing or load_result.unexpected_keys or skipped_unexpected: raise RuntimeError( f"Checkpoint mismatch. 
missing={bad_missing}, unexpected={list(load_result.unexpected_keys) + skipped_unexpected}" ) return list(load_result.missing_keys), list(load_result.unexpected_keys), skipped_incompatible, resized_incompatible def load_rich_checkpoint(checkpoint: str, device: torch.device) -> Tuple[RichGroundedModel, Any, Any, argparse.Namespace]: ckpt_dir = Path(checkpoint) config = json.loads((ckpt_dir / "rich_config.json").read_text(encoding="utf-8")) merged_config = dict(DEFAULT_CONFIG) merged_config.update(config) args = argparse.Namespace(**merged_config) tokenizer = load_seq2seq_tokenizer(str(ckpt_dir / "decoder_tokenizer"), str(args.decoder_model)) image_processor = AutoImageProcessor.from_pretrained(ckpt_dir / "image_processor") model = RichGroundedModel(args) state = torch.load(ckpt_dir / "pytorch_model.bin", map_location="cpu") load_compatible_model_state(model, state) model.to(device) model.eval() return model, tokenizer, image_processor, args def get_eval_max_new_tokens(args: argparse.Namespace) -> int: return int(getattr(args, "eval_max_new_tokens", getattr(args, "max_target_tokens", 384))) def build_generation_kwargs(args: argparse.Namespace, tokenizer) -> Dict[str, Any]: kwargs: Dict[str, Any] = {} no_repeat_ngram_size = int(getattr(args, "generation_no_repeat_ngram_size", 0) or 0) if no_repeat_ngram_size > 0: kwargs["no_repeat_ngram_size"] = no_repeat_ngram_size repetition_penalty = float(getattr(args, "generation_repetition_penalty", 1.0) or 1.0) if repetition_penalty != 1.0: kwargs["repetition_penalty"] = repetition_penalty min_new_tokens = int(getattr(args, "generation_min_new_tokens", 0) or 0) if min_new_tokens > 0: kwargs["min_new_tokens"] = min_new_tokens bad_words_ids = [] if bool(getattr(args, "generation_block_extra_ids", False)): vocab_len = len(tokenizer) for token_id in range(max(0, vocab_len - 256), vocab_len): token = tokenizer.convert_ids_to_tokens(token_id) if isinstance(token, str) and " Dict[str, Any]: raw_model = unwrap_parallel_model(model) 
raw_model.eval() losses: List[float] = [] rouges: List[float] = [] json_valid = 0 generation_json_valid = 0 generation_json_strict_valid = 0 generation_json_repair_applied = 0 total = 0 evidence_precisions: List[float] = [] ui_function_tp = 0 ui_function_fp = 0 ui_function_fn = 0 ui_function_pred_positive = 0 ui_function_ref_positive = 0 ui_function_valid_elements = 0 generation_char_lengths: List[int] = [] pred_summary_char_lengths: List[int] = [] ref_summary_char_lengths: List[int] = [] empty_summary_count = 0 extra_id_count = 0 title_prefix_count = 0 natural_title_prefix_stripped_count = 0 whitespace_only_count = 0 search_function_tp = 0 search_function_fp = 0 search_function_fn = 0 search_function_pred_positive = 0 search_function_ref_positive = 0 search_tp = 0 search_fp = 0 search_fn = 0 pred_search_count = 0 ref_search_count = 0 pred_function_count = 0 ref_function_count = 0 pred_bare_search_count = 0 predictions: List[Dict[str, Any]] = [] eval_max_new_tokens = get_eval_max_new_tokens(args) context_summary_repair = bool(getattr(args, "context_summary_repair", False)) structured_mode = str(getattr(args, "structured_function_mode", "decoder") or "decoder").lower() target_schema = str(getattr(args, "target_schema", "zh") or "zh") summary_output_mode = target_schema_is_summary(target_schema) natural_output_mode = target_schema_is_natural_text(target_schema) function_metric_threshold = 0.5 search_metric_threshold = 0.5 if structured_mode == "heads": function_metric_threshold = float(getattr(args, "structured_function_threshold", 0.5) or 0.5) search_metric_threshold = float(getattr(args, "structured_search_threshold", function_metric_threshold) or function_metric_threshold) repair_count = 0 # 诊断指标:记录预测摘要是否与“这是一个{app}界面...当前任务语境是:{instruction}。”模板完全一致。 # 在 context_mode=tokens_direct 下 decoder 能直接看到 app+instruction token,只要模型能复制这些 token 到 # 模板位置,就能不看屏幕拿到高 ROUGE。template_match_rate 偏高表示训练信号实际是“拼模板”而非屏幕理解。 template_exact_match = 0 template_app_in_pred = 0 
template_instruction_in_pred = 0 for batch in tqdm(loader, desc="valid", disable=not is_main(rank)): rows = batch["rows"] batch = move_batch(batch, device) out = raw_model(**{k: v for k, v in batch.items() if k != "rows"}) losses.append(float(out["loss"].detach().cpu())) valid_elements = batch["element_mask"].bool() function_labels = batch["ui_function_labels"] > 0.5 function_preds = torch.sigmoid(out["ui_function_logits"]) >= function_metric_threshold search_function_labels = batch["search_function_labels"] > 0.5 search_function_preds = torch.sigmoid(out["search_function_logits"]) >= search_metric_threshold evidence_scores_batch = torch.sigmoid(out["evidence_logits"]).detach().cpu() function_scores_batch = torch.sigmoid(out["ui_function_logits"]).detach().cpu() search_scores_batch = torch.sigmoid(out["search_function_logits"]).detach().cpu() ui_function_tp += int((function_preds & function_labels & valid_elements).sum().detach().cpu()) ui_function_fp += int((function_preds & ~function_labels & valid_elements).sum().detach().cpu()) ui_function_fn += int((~function_preds & function_labels & valid_elements).sum().detach().cpu()) ui_function_pred_positive += int((function_preds & valid_elements).sum().detach().cpu()) ui_function_ref_positive += int((function_labels & valid_elements).sum().detach().cpu()) search_function_tp += int((search_function_preds & search_function_labels & valid_elements).sum().detach().cpu()) search_function_fp += int((search_function_preds & ~search_function_labels & valid_elements).sum().detach().cpu()) search_function_fn += int((~search_function_preds & search_function_labels & valid_elements).sum().detach().cpu()) search_function_pred_positive += int((search_function_preds & valid_elements).sum().detach().cpu()) search_function_ref_positive += int((search_function_labels & valid_elements).sum().detach().cpu()) ui_function_valid_elements += int(valid_elements.sum().detach().cpu()) texts = raw_model.generate_text(batch, tokenizer, 
num_beams=args.num_beams, max_new_tokens=eval_max_new_tokens) for row_idx, (row, text) in enumerate(zip(rows, texts)): generation_char_lengths.append(len(text)) if text and not text.strip(): whitespace_only_count += 1 if " str: return str(getattr(args, "model_selection_metric", "rich_quality_score") or "rich_quality_score") def selection_metric_value(metrics: Dict[str, Any], args: argparse.Namespace) -> float: metric_name = selection_metric_name(args) value = metrics.get(metric_name) if value is None: value = metrics.get("rich_quality_score", 0.0) return float(value) def metric_is_better(current: float, best: float, args: argparse.Namespace) -> bool: mode = str(getattr(args, "model_selection_mode", "max") or "max").lower() min_delta = float(getattr(args, "early_stopping_min_delta", 0.0) or 0.0) if mode == "min": return current < best - min_delta return current > best + min_delta def train(args: argparse.Namespace) -> None: set_seed(args.seed) distributed, rank, world_size, local_rank = init_distributed() device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") if device.type == "cuda": cuda_memory_fraction = float(getattr(args, "cuda_memory_fraction", 0.0) or 0.0) if cuda_memory_fraction < 0.0 or cuda_memory_fraction > 1.0: raise ValueError("cuda_memory_fraction must be in [0, 1]. 
Use 0 to disable the limit.") if cuda_memory_fraction > 0.0: torch.cuda.set_per_process_memory_fraction(cuda_memory_fraction, device=device) if is_main(rank): total_gb = torch.cuda.get_device_properties(device).total_memory / 1024**3 print(f"CUDA memory fraction limit: {cuda_memory_fraction:.3f} (~{total_gb * cuda_memory_fraction:.2f}GB of {total_gb:.2f}GB)") torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.allow_tf32 = True if hasattr(torch, "set_float32_matmul_precision"): torch.set_float32_matmul_precision("high") output_dir = Path(args.output_dir) if is_main(rank): output_dir.mkdir(parents=True, exist_ok=True) write_json(output_dir / "config.json", vars(args)) max_train_samples = int(getattr(args, "max_train_samples", 0) or 0) train_dataset = RichScreenshotDataset(args.train_file, max_samples=max_train_samples, sample_seed=args.seed if max_train_samples else None) valid_dataset = RichScreenshotDataset(args.valid_file, max_samples=args.max_valid_samples) train_data_diagnostics = dataset_diagnostics(train_dataset.rows) valid_data_diagnostics = dataset_diagnostics(valid_dataset.rows) validate_dataset_for_training("train", train_data_diagnostics, args) validate_dataset_for_training("valid", valid_data_diagnostics, args) tokenizer = load_seq2seq_tokenizer(args.decoder_model) train_token_diagnostics = tokenizer_diagnostics(train_dataset.rows, tokenizer, args) valid_token_diagnostics = tokenizer_diagnostics(valid_dataset.rows, tokenizer, args) validate_token_lengths("train", train_token_diagnostics, args) validate_token_lengths("valid", valid_token_diagnostics, args) if is_main(rank): diagnostics_payload = { "train_file": args.train_file, "valid_file": args.valid_file, "strict_data_checks": bool(getattr(args, "strict_data_checks", True)), "train": train_data_diagnostics, "valid": valid_data_diagnostics, "tokenizer": args.decoder_model, "max_target_tokens": int(getattr(args, "max_target_tokens", 0) or 0), "eval_max_new_tokens": int(getattr(args, 
"eval_max_new_tokens", 0) or 0), "max_target_truncation_rate": float(getattr(args, "max_target_truncation_rate", 0.01) or 0.0), "train_token_lengths": train_token_diagnostics, "valid_token_lengths": valid_token_diagnostics, } write_json(output_dir / "data_diagnostics.json", diagnostics_payload) print("Data diagnostics:") print(json.dumps(diagnostics_payload, ensure_ascii=False, indent=2)) image_processor = AutoImageProcessor.from_pretrained(args.vision_model) train_collator = RichCollator(tokenizer, image_processor, args, is_training=True) valid_collator = RichCollator(tokenizer, image_processor, args, is_training=False) train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=args.seed) if distributed else None valid_sampler = DistributedSampler(valid_dataset, shuffle=False) if distributed else None train_loader = DataLoader( train_dataset, batch_size=args.batch_size, shuffle=train_sampler is None, sampler=train_sampler, collate_fn=train_collator, num_workers=args.num_workers, pin_memory=torch.cuda.is_available(), ) valid_loader = DataLoader( valid_dataset, batch_size=args.eval_batch_size if getattr(args, "eval_batch_size", 0) else max(1, args.batch_size // 2), shuffle=False, sampler=valid_sampler, collate_fn=valid_collator, num_workers=args.num_workers, pin_memory=torch.cuda.is_available(), ) model = RichGroundedModel(args) init_checkpoint = str(getattr(args, "init_checkpoint", "") or "") if init_checkpoint: checkpoint_decoder_model = "" init_config_path = Path(init_checkpoint) / "rich_config.json" init_config: Dict[str, Any] = {} if init_config_path.exists(): init_config = json.loads(init_config_path.read_text(encoding="utf-8")) checkpoint_decoder_model = str(init_config.get("decoder_model", "") or "") allow_missing_prefixes_list: List[str] = [] if checkpoint_decoder_model and checkpoint_decoder_model != str(getattr(args, "decoder_model", "") or ""): allow_missing_prefixes_list.append("decoder.") if bool(init_config.get("disable_vision", False)) and 
not bool(getattr(args, "disable_vision", False)): allow_missing_prefixes_list.append("vision.") state = torch.load(Path(init_checkpoint) / "pytorch_model.bin", map_location="cpu") missing_keys, _, skipped_incompatible, resized_incompatible = load_compatible_model_state( model, state, allow_missing_prefixes=tuple(allow_missing_prefixes_list), resize_mismatched_non_decoder=bool(getattr(args, "init_resize_mismatched_non_decoder", False)), ) if is_main(rank): missing_preview = missing_keys[:12] skipped_preview = skipped_incompatible[:12] resized_preview = resized_incompatible[:12] print( f"Loaded init checkpoint: {init_checkpoint}; " f"missing_count={len(missing_keys)} preview={missing_preview}; " f"skipped_incompatible_count={len(skipped_incompatible)} preview={skipped_preview}; " f"resized_incompatible_count={len(resized_incompatible)} preview={resized_preview}" ) model = model.to(device) trainable_stats = trainable_parameter_stats(model) if is_main(rank): print( "Trainable parameters: " f"total={trainable_stats['trainable_params_total']:,}, " f"vision={trainable_stats['trainable_params_vision']:,}, " f"decoder={trainable_stats['trainable_params_decoder']:,}, " f"other={trainable_stats['trainable_params_other']:,}", flush=True, ) if distributed: model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=True) elif bool(getattr(args, "data_parallel", False)) and device.type == "cuda" and torch.cuda.device_count() > 1: model = nn.DataParallel(model) if is_main(rank): print(f"Using nn.DataParallel on {torch.cuda.device_count()} CUDA devices.", flush=True) optimizer = build_optimizer(unwrap_parallel_model(model), args) scheduler_epochs = int(getattr(args, "scheduler_epochs", 0) or args.epochs) scheduler_epochs = max(int(args.epochs), scheduler_epochs) total_update_steps = math.ceil(len(train_loader) / args.grad_accum) * scheduler_epochs scheduler = get_lr_schedule(optimizer, total_update_steps, args.warmup_ratio, getattr(args, 
"lr_scheduler_type", "linear")) amp_dtype_name = str(getattr(args, "amp_dtype", "auto") or "auto").lower() if amp_dtype_name == "auto": amp_dtype_name = "fp16" if bool(getattr(args, "fp16", False)) else "fp32" if amp_dtype_name not in {"fp32", "fp16", "bf16"}: raise ValueError("amp_dtype must be one of: auto, fp32, fp16, bf16") amp_enabled = device.type == "cuda" and amp_dtype_name in {"fp16", "bf16"} amp_dtype = torch.float16 if amp_dtype_name == "fp16" else torch.bfloat16 scaler = torch.amp.GradScaler("cuda", enabled=amp_enabled and amp_dtype_name == "fp16") model_selection_mode = str(getattr(args, "model_selection_mode", "max") or "max").lower() best_score = math.inf if model_selection_mode == "min" else -math.inf epochs_without_improvement = 0 global_step = 0 train_start_time = time.time() recent_losses: deque[float] = deque(maxlen=100) optimizer.zero_grad(set_to_none=True) if device.type == "cuda": gc.collect() torch.cuda.empty_cache() torch.cuda.reset_peak_memory_stats(device) skip_optimizer_window = False skipped_nonfinite_windows = 0 for epoch in range(args.epochs): if train_sampler is not None: train_sampler.set_epoch(epoch) progress = tqdm(train_loader, desc=f"epoch {epoch + 1}/{args.epochs}", disable=not is_main(rank)) for step, batch in enumerate(progress, start=1): batch = move_batch(batch, device) with torch.amp.autocast("cuda", enabled=amp_enabled, dtype=amp_dtype): out = model(**{k: v for k, v in batch.items() if k != "rows"}) out = reduce_parallel_losses(out) loss = out["loss"] / args.grad_accum loss_is_finite = bool(torch.isfinite(loss.detach()).all().item()) if not loss_is_finite: skip_optimizer_window = True optimizer.zero_grad(set_to_none=True) if is_main(rank): append_jsonl( output_dir / "metrics.jsonl", { "event": "skip_nonfinite_microbatch", "epoch": epoch + 1, "micro_step": step, "loss": float(out["loss"].detach().cpu()), "generation_loss": float(out["generation_loss"].detach().cpu()), "evidence_loss": 
float(out["evidence_loss"].detach().cpu()), "ui_function_loss": float(out["ui_function_loss"].detach().cpu()), "search_function_loss": float(out["search_function_loss"].detach().cpu()), "section_loss": float(out["section_loss"].detach().cpu()), "numeric_loss": float(out["numeric_loss"].detach().cpu()), }, ) elif not skip_optimizer_window: scaler.scale(loss).backward() step_log = None if step % args.grad_accum == 0: if skip_optimizer_window: skipped_nonfinite_windows += 1 optimizer.zero_grad(set_to_none=True) skip_optimizer_window = False if is_main(rank): append_jsonl( output_dir / "metrics.jsonl", { "event": "skip_nonfinite_optimizer_window", "epoch": epoch + 1, "micro_step": step, "skipped_nonfinite_windows": skipped_nonfinite_windows, "lr": max(scheduler.get_last_lr()), }, ) else: lrs_used = list(scheduler.get_last_lr()) scaler.unscale_(optimizer) grad_norm = clip_gradients(model, optimizer, args) grad_is_finite = bool(torch.isfinite(grad_norm.detach()).all().item()) if torch.is_tensor(grad_norm) else math.isfinite(float(grad_norm)) if not grad_is_finite: skipped_nonfinite_windows += 1 optimizer.zero_grad(set_to_none=True) if scaler.is_enabled(): scaler.update() if is_main(rank): append_jsonl( output_dir / "metrics.jsonl", { "event": "skip_nonfinite_grad", "epoch": epoch + 1, "micro_step": step, "grad_norm": float(grad_norm.detach().cpu()) if torch.is_tensor(grad_norm) else float(grad_norm), "skipped_nonfinite_windows": skipped_nonfinite_windows, "lr": max(scheduler.get_last_lr()), }, ) else: scaler.step(optimizer) scaler.update() optimizer.zero_grad(set_to_none=True) scheduler.step() global_step += 1 if is_main(rank): loss_value = float(out["loss"].detach().cpu()) recent_losses.append(loss_value) label_counts = (batch["labels"] != -100).sum(dim=1).detach().float().cpu() context_counts = batch["context_attention_mask"].sum(dim=1).detach().float().cpu() element_counts = batch["element_mask"].sum(dim=1).detach().float().cpu() loss_window = list(recent_losses) 
loss_window_20 = loss_window[-20:] elapsed_seconds = max(time.time() - train_start_time, 1e-6) next_lrs = list(scheduler.get_last_lr()) step_log = { "step": global_step, "epoch": epoch + 1, "loss": loss_value, "loss_ma20": float(np.mean(loss_window_20)) if loss_window_20 else loss_value, "loss_ma100": float(np.mean(loss_window)) if loss_window else loss_value, "generation_loss": float(out["generation_loss"].cpu()), "evidence_loss": float(out["evidence_loss"].cpu()), "ui_function_loss": float(out["ui_function_loss"].cpu()), "search_function_loss": float(out["search_function_loss"].cpu()), "section_loss": float(out["section_loss"].cpu()), "numeric_loss": float(out["numeric_loss"].cpu()), "lr": max(lrs_used) if lrs_used else 0.0, "lr_other": lrs_used[0] if len(lrs_used) > 0 else 0.0, "lr_vision": lrs_used[1] if len(lrs_used) > 1 else 0.0, "lr_decoder": lrs_used[2] if len(lrs_used) > 2 else 0.0, "lr_ui_function_head": lrs_used[3] if len(lrs_used) > 3 else 0.0, "lr_next": max(next_lrs) if next_lrs else 0.0, "lr_scheduler_type": str(getattr(args, "lr_scheduler_type", "linear") or "linear"), "grad_norm": float(grad_norm.detach().cpu()) if torch.is_tensor(grad_norm) else float(grad_norm), "skipped_nonfinite_windows": skipped_nonfinite_windows, "optimizer_steps_total": total_update_steps, "train_progress_pct": 100.0 * global_step / max(1, total_update_steps), "elapsed_seconds": elapsed_seconds, "optimizer_steps_per_sec": global_step / elapsed_seconds, "target_tokens_mean": float(label_counts.mean().item()) if label_counts.numel() else 0.0, "target_tokens_max": int(label_counts.max().item()) if label_counts.numel() else 0, "target_at_max_rate": float((label_counts >= int(args.max_target_tokens)).float().mean().item()) if label_counts.numel() else 0.0, "context_tokens_mean": float(context_counts.mean().item()) if context_counts.numel() else 0.0, "context_tokens_max": int(context_counts.max().item()) if context_counts.numel() else 0, "context_at_max_rate": 
float((context_counts >= int(args.max_context_tokens)).float().mean().item()) if context_counts.numel() else 0.0, "element_count_mean": float(element_counts.mean().item()) if element_counts.numel() else 0.0, "element_count_max": int(element_counts.max().item()) if element_counts.numel() else 0, "model_variant": str(getattr(args, "model_variant", "")), "vision_enabled": not bool(getattr(args, "disable_vision", False)) and str(getattr(args, "model_variant", "")) != "annotation_only", "native_context_forward": bool(getattr(args, "native_context_forward", False)), "freeze_decoder": bool(getattr(args, "freeze_decoder", False)), "image_size": int(getattr(args, "image_size", 0) or 0), "image_crops": int(batch["pixel_values"].shape[1]) if torch.is_tensor(batch.get("pixel_values")) and batch["pixel_values"].ndim >= 5 else 0, "max_visual_tokens": int(getattr(args, "max_visual_tokens", 0) or 0), "direct_visual_tokens": bool(getattr(args, "direct_visual_tokens", False)), "direct_element_tokens": bool(getattr(args, "direct_element_tokens", False)), "visual_memory_scale": float(getattr(args, "visual_memory_scale", 1.0) or 1.0), "element_memory_scale": float(getattr(args, "element_memory_scale", 1.0) or 1.0), "pooled_memory_scale": float(getattr(args, "pooled_memory_scale", 1.0) or 1.0), "cuda_memory_fraction": float(getattr(args, "cuda_memory_fraction", 0.0) or 0.0), "data_parallel": isinstance(model, nn.DataParallel), "distributed": bool(distributed), **trainable_stats, } del loss del out del batch empty_cache_steps = int(getattr(args, "cuda_empty_cache_steps", 0) or 0) if ( device.type == "cuda" and empty_cache_steps > 0 and step % args.grad_accum == 0 and global_step % empty_cache_steps == 0 ): gc.collect() torch.cuda.empty_cache() if step_log is not None and is_main(rank): mem = {} if torch.cuda.is_available(): mem = { "gpu_allocated_gb": round(torch.cuda.memory_allocated(device) / 1024**3, 3), "gpu_reserved_gb": round(torch.cuda.memory_reserved(device) / 1024**3, 3), 
"gpu_peak_allocated_gb": round(torch.cuda.max_memory_allocated(device) / 1024**3, 3), "gpu_peak_reserved_gb": round(torch.cuda.max_memory_reserved(device) / 1024**3, 3), } log = { **step_log, **mem, } append_jsonl(output_dir / "metrics.jsonl", log) if device.type == "cuda": torch.cuda.reset_peak_memory_stats(device) progress.set_postfix(loss=f"{log['loss']:.3f}", ma20=f"{log['loss_ma20']:.3f}", pct=f"{log['train_progress_pct']:.1f}%") if ( step_log is not None and args.save_checkpoints and args.save_every_steps and global_step > 0 and global_step % args.save_every_steps == 0 and is_main(rank) ): save_checkpoint(output_dir, "checkpoint-last", model, tokenizer, image_processor, args, {"step": global_step}) if step_log is not None and args.eval_every_steps and global_step > 0 and global_step % args.eval_every_steps == 0: if is_main(rank) and args.save_checkpoints: save_checkpoint(output_dir, "checkpoint-last", model, tokenizer, image_processor, args, {"step": global_step, "pre_eval": True}) metrics = evaluate(model, valid_loader, tokenizer, device, args, rank=rank) current_score = selection_metric_value(metrics, args) metrics["selection_metric"] = selection_metric_name(args) metrics["selection_score"] = current_score if metric_is_better(current_score, best_score, args): best_score = current_score if is_main(rank) and args.save_checkpoints: save_checkpoint(output_dir, "checkpoint-best", model, tokenizer, image_processor, args, metrics) if device.type == "cuda" and empty_cache_steps > 0: gc.collect() torch.cuda.empty_cache() if is_main(rank): append_jsonl(output_dir / "metrics.jsonl", {"step_eval": global_step, **metrics}) if is_main(rank) and args.save_checkpoints: save_checkpoint(output_dir, "checkpoint-last", model, tokenizer, image_processor, args, {"epoch": epoch + 1, "step": global_step, "pre_epoch_eval": True}) metrics = evaluate(model, valid_loader, tokenizer, device, args, rank=rank) current_score = selection_metric_value(metrics, args) 
metrics["selection_metric"] = selection_metric_name(args) metrics["selection_score"] = current_score improved = metric_is_better(current_score, best_score, args) if improved: best_score = current_score epochs_without_improvement = 0 else: epochs_without_improvement += 1 if is_main(rank): append_jsonl(output_dir / "metrics.jsonl", {"epoch_eval": epoch + 1, **metrics}) if args.save_checkpoints: save_checkpoint(output_dir, "checkpoint-last", model, tokenizer, image_processor, args, metrics) if args.save_checkpoints and improved: save_checkpoint(output_dir, "checkpoint-best", model, tokenizer, image_processor, args, metrics) patience = int(getattr(args, "early_stopping_patience", 0) or 0) if patience > 0: print( f"selection_metric={metrics['selection_metric']} score={current_score:.6f} " f"best={best_score:.6f} no_improve={epochs_without_improvement}/{patience}" ) patience = int(getattr(args, "early_stopping_patience", 0) or 0) if patience > 0 and epochs_without_improvement >= patience: if is_main(rank): print(f"Early stopping at epoch {epoch + 1}: no improvement for {patience} evals.") break if distributed: dist.destroy_process_group() def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Train rich CMGUI screenshot summarization models.") for key, value in DEFAULT_CONFIG.items(): arg_type = type(value) if isinstance(value, bool): parser.add_argument(f"--{key}", type=lambda x: str(x).lower() in {"1", "true", "yes"}, default=value) else: parser.add_argument(f"--{key}", type=arg_type, default=value) parser.add_argument("--bottleneck_queries", type=int, default=64) args = parser.parse_args() args.decoder_model = normalize_model_reference(args.decoder_model) args.vision_model = normalize_model_reference(args.vision_model) if args.model_variant not in {"annotation_only", "image_only", "late_fusion", "full"}: raise ValueError("model_variant must be one of annotation_only, image_only, late_fusion, full") args.context_mode = str(getattr(args, 
"context_mode", "mean") or "mean").lower() if args.context_mode not in {"mean", "tokens", "tokens_direct", "tokens_encoder", "tokens_direct_encoder"}: raise ValueError("context_mode must be one of: mean, tokens, tokens_direct, tokens_encoder, tokens_direct_encoder") args.lr_scheduler_type = str(getattr(args, "lr_scheduler_type", "linear") or "linear").lower() if args.lr_scheduler_type not in {"linear", "cosine"}: raise ValueError("lr_scheduler_type must be one of: linear, cosine") args.target_schema = str(getattr(args, "target_schema", "zh") or "zh").lower() if args.target_schema not in { "zh", "alias", "aliases", "en", "english", "summary", "summary_zh", "summary-only", "summary_only", "natural_zh", "rich_text_zh", "zh_text", "text_zh", "summary_visible_zh", "natural_summary_visible_zh", }: raise ValueError("target_schema must be one of: zh, aliases, summary_zh, natural_zh, summary_visible_zh") args.grad_clip_strategy = str(getattr(args, "grad_clip_strategy", "global") or "global").lower() if args.grad_clip_strategy not in {"global", "per_group"}: raise ValueError("grad_clip_strategy must be one of: global, per_group") args.structured_function_mode = str(getattr(args, "structured_function_mode", "decoder") or "decoder").lower() if args.structured_function_mode not in {"decoder", "heads"}: raise ValueError("structured_function_mode must be one of: decoder, heads") args.structured_evidence_mode = str(getattr(args, "structured_evidence_mode", "decoder") or "decoder").lower() if args.structured_evidence_mode not in {"decoder", "heads"}: raise ValueError("structured_evidence_mode must be one of: decoder, heads") return args if __name__ == "__main__": train(parse_args())