|
|
| """Train rich Chinese screenshot summarization models on CMGUI-style data."""
|
|
|
| from __future__ import annotations
|
|
|
| import argparse
|
| import gc
|
| import json |
| import math |
| import os |
| import random |
| import re |
| import shutil |
| import time |
| from collections import deque
|
| from dataclasses import asdict, dataclass
|
| from pathlib import Path
|
| from typing import Any, Dict, Iterable, List, Optional, Tuple
|
|
|
| import numpy as np
|
| import torch
|
| import torch.distributed as dist
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from PIL import Image
|
| from torch.nn.parallel import DistributedDataParallel
|
| from torch.utils.checkpoint import checkpoint
|
| from torch.utils.data import DataLoader, Dataset, DistributedSampler
|
| from tqdm import tqdm
|
| from transformers import Adafactor, AutoImageProcessor, AutoModel, AutoModelForSeq2SeqLM, AutoTokenizer, T5Tokenizer
|
| from transformers.modeling_outputs import BaseModelOutput
|
|
|
|
|
# Default settings for this training script, keyed by option name.
# NOTE(review): the code that consumes these keys is not visible in this part
# of the file; the group headings below are inferred from key names only and
# should be confirmed against the argument parser / trainer.
DEFAULT_CONFIG = {
    # --- data files, output directory, warm-start checkpoint ---
    "train_file": "data/rich_cmgui/processed/smoke_train_rich.jsonl",
    "valid_file": "data/rich_cmgui/processed/smoke_valid_rich.jsonl",
    "output_dir": "runs/rich_cmgui_20260502/rich_grounded_siglip2_mt5_v1",
    "init_checkpoint": "",
    # --- model selection ---
    "model_variant": "full",
    "vision_model": "models/siglip2-base-patch16-224",
    "decoder_model": "google/mt5-base",
    # --- image preprocessing ---
    "image_size": 384,
    "num_vertical_crops": 3,
    # --- token / element budgets and context text construction ---
    "max_visual_tokens": 0,
    "max_elements": 80,
    "max_element_tokens": 16,
    "max_context_tokens": 64,
    # The next four keys are read by build_context_text via getattr(args, ...).
    "context_text_format": "rich",
    "context_include_screen_text": False,
    "context_screen_text_items": 32,
    "context_screen_text_dropout_rate": 0.0,
    "context_mode": "mean",
    "max_target_tokens": 384,
    "eval_max_new_tokens": 384,
    # --- batching and schedule length ---
    "batch_size": 4,
    "eval_batch_size": 0,
    "grad_accum": 8,
    "epochs": 6,
    # --- optimizer, learning rates, precision, memory ---
    "scheduler_epochs": 0,
    "lr_new": 1e-4,
    "lr_fusion": 5e-5,
    "lr_decoder": 1e-5,
    "lr_ui_function_head": 0.0,
    "weight_decay": 0.01,
    "optimizer_name": "adamw",
    "lr_scheduler_type": "linear",
    "warmup_ratio": 0.05,
    "fp16": True,
    "amp_dtype": "auto",
    "generation_loss_chunk_size": 32,
    "activation_checkpointing": False,
    "cuda_empty_cache_steps": 0,
    "cuda_memory_fraction": 0.0,
    "decoder_gradient_checkpointing": False,
    "vision_gradient_checkpointing": False,
    "freeze_decoder": False,
    "freeze_vision": True,
    "unfreeze_vision_last_ratio": 0.3,
    # --- auxiliary loss weights ---
    "evidence_loss_weight": 0.2,
    "section_loss_weight": 0.1,
    "numeric_loss_weight": 0.1,
    "ui_function_loss_weight": 0.0,
    "search_function_loss_weight": 0.0,
    "search_function_pos_weight": 1.0,
    # --- checkpointing, evaluation, model selection, early stopping ---
    "save_every_steps": 1000,
    "save_checkpoints": True,
    "eval_every_steps": 0,
    "model_selection_metric": "rich_quality_score",
    "model_selection_mode": "max",
    "early_stopping_patience": 0,
    "early_stopping_min_delta": 0.0,
    "max_train_samples": 0,
    "max_valid_samples": 800,
    # --- generation / decoding ---
    "num_beams": 4,
    "generation_no_repeat_ngram_size": 0,
    "generation_repetition_penalty": 1.0,
    "generation_min_new_tokens": 0,
    "generation_block_extra_ids": False,
    "generation_block_title_prefix": False,
    "generation_force_json_start": False,
    # --- target formatting and post-processing of predictions ---
    "context_summary_repair": False,
    "canonicalize_targets": False,
    # Schema name interpreted by target_to_text (default "zh" -> Chinese-key JSON).
    "target_schema": "zh",
    "task_intent_context": False,
    "drop_bare_search_functions": False,
    # "decoder" makes apply_structured_function_predictions return the decoder
    # prediction unchanged.
    "structured_function_mode": "decoder",
    "structured_function_threshold": 0.5,
    "structured_search_threshold": 0.5,
    "structured_max_functions": 12,
    "structured_strict_search_candidates": False,
    "structured_evidence_mode": "decoder",
    "structured_evidence_threshold": 0.5,
    "structured_max_evidence": 8,
    "structured_evidence_fallback_top1": True,
    # --- architecture / ablation switches ---
    "direct_visual_tokens": False,
    "direct_element_tokens": False,
    "direct_context_passthrough": False,
    "include_pooled_memory": True,
    "native_context_forward": False,
    "disable_vision": False,
    "init_resize_mismatched_non_decoder": False,
    "grad_clip_strategy": "global",
    "max_grad_norm": 1.0,
    "function_signal_to_decoder": False,
    "function_signal_scale": 1.0,
    "search_signal_to_decoder": False,
    "search_signal_scale": 1.0,
    "visual_memory_scale": 1.0,
    "element_memory_scale": 1.0,
    "pooled_memory_scale": 1.0,
    "decoder_memory_scale": 1.0,
    # --- runtime ---
    "data_parallel": False,
    "strict_data_checks": True,
    "max_target_truncation_rate": 0.01,
    "seed": 20260502,
    "num_workers": 4,
}
|
|
# English keys of the four list-valued target sections (the same keys that
# target_to_text reads from a target record).
# NOTE(review): where this list is consumed is not visible in this chunk.
SECTION_NAMES = ["visible_text", "interaction_data", "ui_functions", "key_ui_clues"]
|
|
|
|
|
def read_jsonl(path: Path) -> Iterable[Dict[str, Any]]:
    """Lazily yield one decoded JSON value per non-blank line of ``path``.

    The file is opened as UTF-8.  Each line is stripped of surrounding
    whitespace and skipped when nothing is left.
    """
    with path.open("r", encoding="utf-8") as handle:
        for raw_line in handle:
            record = raw_line.strip()
            if not record:
                continue
            yield json.loads(record)
|
|
|
|
|
def write_json(path: Path, obj: Dict[str, Any]) -> None:
    """Write ``obj`` to ``path`` as indented, non-ASCII-preserving UTF-8 JSON.

    Missing parent directories are created first.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    serialized = json.dumps(obj, ensure_ascii=False, indent=2)
    path.write_text(serialized, encoding="utf-8")
|
|
|
|
|
def append_jsonl(path: Path, obj: Dict[str, Any]) -> None:
    """Append ``obj`` to ``path`` as one JSON line (UTF-8, ``\\n`` line ending).

    Missing parent directories are created first.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    # newline="\n" keeps the line terminator fixed regardless of platform.
    with path.open("a", encoding="utf-8", newline="\n") as sink:
        sink.write(f"{json.dumps(obj, ensure_ascii=False)}\n")
|
|
|
|
|
def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy, and PyTorch (CPU and all CUDA devices)."""
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
|
|
|
|
|
def normalize_model_reference(value: Any) -> str:
    """Return ``value`` as a forward-slash path string.

    Backslashes become ``/``.  Runs of slashes are collapsed to one unless the
    string contains ``://``, in which case it is left as is.  Falsy input
    yields ``""``.
    """
    reference = str(value or "").replace("\\", "/")
    if "://" in reference:
        return reference
    return re.sub(r"/+", "/", reference)
|
|
|
|
|
def init_distributed() -> Tuple[bool, int, int, int]:
    """Initialise ``torch.distributed`` from the launcher environment.

    Returns ``(is_distributed, rank, world_size, local_rank)``.  When either
    ``RANK`` or ``WORLD_SIZE`` is missing from the environment, nothing is
    initialised and ``(False, 0, 1, 0)`` is returned.  Otherwise an NCCL
    process group is created and the CUDA device is set to ``LOCAL_RANK``
    (default 0).
    """
    env = os.environ
    if not ("RANK" in env and "WORLD_SIZE" in env):
        return False, 0, 1, 0
    rank = int(env["RANK"])
    world_size = int(env["WORLD_SIZE"])
    local_rank = int(env.get("LOCAL_RANK", 0))
    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(local_rank)
    return True, rank, world_size, local_rank
|
|
|
|
|
def is_main(rank: int) -> bool:
    """Return True when ``rank`` is 0."""
    main_rank = 0
    return rank == main_rank
|
|
|
|
|
def unwrap_parallel_model(model: nn.Module) -> nn.Module:
    """Return the wrapped ``.module`` of a DDP/DataParallel model, else ``model``."""
    wrapper_types = (DistributedDataParallel, nn.DataParallel)
    return model.module if isinstance(model, wrapper_types) else model
|
|
|
|
|
def safe_text(value: Any) -> str:
    """Stringify ``value``, collapse whitespace runs to one space, and trim.

    ``None`` maps to the empty string.
    """
    return "" if value is None else re.sub(r"\s+", " ", str(value)).strip()
|
|
|
|
|
def target_to_text(target: Dict[str, Any], target_schema: str = "zh") -> str:
    """Serialize ``target`` into the string form selected by ``target_schema``.

    Summary-only schemas return just the cleaned ``summary_zh``.  The
    summary+visible and natural-text schemas are delegated to their dedicated
    renderers.  Any other schema yields compact JSON: English keys for the
    alias/English schema names, Chinese keys otherwise.  List sections are
    truncated (16 visible texts, 12 entries for the other sections).
    """
    schema = str(target_schema or "zh").lower()
    if target_schema_is_summary(schema):
        return safe_text(target.get("summary_zh"))
    # Must run before the natural-text test: the natural-text name set also
    # contains the summary+visible schema names.
    if target_schema_is_summary_visible(schema):
        return target_to_summary_visible_text(target)
    if target_schema_is_natural_text(schema):
        return target_to_natural_text(target)

    if schema in {"alias", "aliases", "en", "english"}:
        labels = ("summary_zh", "visible_text", "interaction_data", "ui_functions", "key_ui_clues")
    else:
        labels = ("画面总结", "可见文字", "互动数据", "功能入口", "关键证据")
    section_values = (
        safe_text(target.get("summary_zh")),
        target.get("visible_text", [])[:16],
        target.get("interaction_data", [])[:12],
        target.get("ui_functions", [])[:12],
        target.get("key_ui_clues", [])[:12],
    )
    payload = dict(zip(labels, section_values))
    return json.dumps(payload, ensure_ascii=False, separators=(",", ":"))
|
|
|
|
|
def target_schema_is_summary(target_schema: str) -> bool:
    """Return True when the schema name (case-insensitive) is a summary-only schema."""
    normalized = str(target_schema or "").lower()
    summary_names = ("summary", "summary_zh", "summary-only", "summary_only")
    return normalized in summary_names
|
|
|
|
|
def target_schema_is_natural_text(target_schema: str) -> bool:
    """Return True when the schema name (case-insensitive) is a natural-text schema.

    The accepted names include the two summary+visible schema names.
    """
    normalized = str(target_schema or "").lower()
    natural_names = (
        "natural_zh",
        "rich_text_zh",
        "zh_text",
        "text_zh",
        "summary_visible_zh",
        "natural_summary_visible_zh",
    )
    return normalized in natural_names
|
|
|
|
def target_schema_is_summary_visible(target_schema: str) -> bool:
    """Return True when the schema name (case-insensitive) is a summary+visible schema."""
    normalized = str(target_schema or "").lower()
    return normalized in ("summary_visible_zh", "natural_summary_visible_zh")
|
|
|
|
def format_evidence_suffix(values: Any) -> str:
    """Render evidence ids as a parenthesised ``证据`` suffix.

    Ids are cleaned and de-duplicated (at most 8 kept) and joined with ``、``.
    Returns ``""`` when no id survives.
    """
    evidence_ids = dedupe_texts(values or [], max_items=8)
    return f"(证据:{'、'.join(evidence_ids)})" if evidence_ids else ""
|
|
|
|
|
def format_natural_entry(value: Any) -> str:
    """Render one section entry as plain text.

    A dict becomes ``name=value`` (or whichever of the two is present) taken
    from ``name``/``text``/``label`` and ``value``, followed by the evidence
    suffix built from ``evidence_ids``.  Anything else is passed through
    ``safe_text``.
    """
    if not isinstance(value, dict):
        return safe_text(value)
    label = safe_text(value.get("name") or value.get("text") or value.get("label"))
    detail = safe_text(value.get("value"))
    if label and detail:
        rendered = f"{label}={detail}"
    else:
        # At most one of the two is non-empty here.
        rendered = label or detail
    suffix = format_evidence_suffix(value.get("evidence_ids", []))
    return safe_text(rendered + suffix)
|
|
|
|
|
def format_natural_list(values: Any, max_items: int = 16) -> str:
    """Join the first ``max_items`` entries of ``values`` with ``、``.

    Entries are rendered with ``format_natural_entry``; empty and duplicate
    renderings are dropped, keeping first-seen order.  Returns ``无`` when
    ``values`` is not a list or nothing remains.
    """
    if not isinstance(values, list):
        return "无"
    rendered = (format_natural_entry(entry) for entry in values[:max_items])
    # dict.fromkeys keeps insertion order, giving ordered de-duplication.
    unique = list(dict.fromkeys(text for text in rendered if text))
    return "、".join(unique) if unique else "无"
|
|
|
|
|
def target_to_natural_text(target: Dict[str, Any]) -> str:
    """Render ``target`` as five labelled lines of Chinese plain text.

    The lines are the summary, visible text (max 16), interaction data,
    UI functions and key UI clues (max 12 each), in that order.
    """
    summary_line = f"画面总结:{safe_text(target.get('summary_zh'))}"
    visible_line = f"可见文字:{format_natural_list(target.get('visible_text', []), max_items=16)}"
    interaction_line = f"互动数据:{format_natural_list(target.get('interaction_data', []), max_items=12)}"
    function_line = f"功能入口:{format_natural_list(target.get('ui_functions', []), max_items=12)}"
    evidence_line = f"关键证据:{format_natural_list(target.get('key_ui_clues', []), max_items=12)}"
    lines = (summary_line, visible_line, interaction_line, function_line, evidence_line)
    return "\n".join(lines)
|
|
|
|
def target_to_summary_visible_text(target: Dict[str, Any]) -> str:
    """Render ``target`` as three labelled lines.

    The lines are the summary, visible text (max 16) and interaction data
    (max 12).
    """
    summary_line = f"画面总结:{safe_text(target.get('summary_zh'))}"
    visible_line = f"可见文字:{format_natural_list(target.get('visible_text', []), max_items=16)}"
    interaction_line = f"互动数据:{format_natural_list(target.get('interaction_data', []), max_items=12)}"
    return "\n".join((summary_line, visible_line, interaction_line))
|
|
|
|
# Top-level keys of the Chinese-key JSON payload produced by target_to_text.
JSONISH_OUTPUT_KEYS = ("画面总结", "可见文字", "互动数据", "功能入口", "关键证据")

# Top-level keys plus the nested entry keys; repair_generated_json_text
# re-quotes each of these when it appears as an object key.
JSONISH_ALL_KEYS = JSONISH_OUTPUT_KEYS + ("name", "value", "evidence_ids")
|
|
|
|
|
def normalize_evidence_id(value: Any) -> str:
    """Canonicalise an evidence id such as ``e 0 1`` to ``E01``.

    The cleaned text is rewritten only when it consists entirely of ``E``/``e``
    followed by 2-8 digits, with optional whitespace between them.  Any other
    text is returned cleaned but otherwise unchanged.
    """
    text = safe_text(value)
    id_match = re.fullmatch(r"[Ee]\s*((?:\d\s*){2,8})", text)
    if id_match is None:
        return text
    digits = re.sub(r"\s+", "", id_match.group(1))
    return f"E{digits}"
|
|
|
|
|
def repair_generated_json_text(text: str) -> str:
    """Apply textual repairs to generated JSON-like text.

    Steps, in order: collapse whitespace; drop everything before the first
    ``{``; rejoin a split ``evidence_ids`` key; compact spaced evidence ids
    (``E 0 1`` -> ``E01``); and re-quote every known key that follows ``{``,
    ``[`` or ``,`` with a missing opening quote or no quotes at all.
    """
    candidate = safe_text(text)
    brace_at = candidate.find("{")
    if brace_at >= 0:
        candidate = candidate[brace_at:]

    candidate = re.sub(r"evidence\s*_\s*ids", "evidence_ids", candidate)

    def compact_evidence_id(match):
        return "E" + re.sub(r"\s+", "", match.group(1))

    candidate = re.sub(r"\b[Ee]\s*((?:\d\s*){2,8})\b", compact_evidence_id, candidate)

    key_prefix = r"([\{\[,])\s*\"?\s*"
    for key in JSONISH_ALL_KEYS:
        escaped = re.escape(key)
        # First the form with a closing quote, then the form without one.
        for key_pattern in (key_prefix + escaped + r"\s*\"\s*:", key_prefix + escaped + r"\s*:"):
            candidate = re.sub(
                key_pattern,
                lambda match, key=key: f'{match.group(1)}"{key}":',
                candidate,
            )
    return candidate
|
|
|
|
|
def json_object_from_text(text: str) -> Tuple[Optional[Dict[str, Any]], bool]:
    """Parse ``text`` as a JSON object.

    The whole (stripped) text is tried first.  If that is not valid JSON, the
    span from the first ``{`` to the last ``}`` is tried.  Only a parsed
    ``dict`` counts as success.

    ``None`` or other falsy input is treated as empty text instead of raising,
    matching how ``safe_json_loads_with_repair`` already treats its input.

    Returns:
        ``(obj, True)`` on success, ``(None, False)`` otherwise.
    """
    text = str(text or "").strip()
    if not text:
        return None, False

    candidates = [text]
    start, end = text.find("{"), text.rfind("}")
    if start >= 0 and end > start:
        candidates.append(text[start : end + 1])

    for candidate in candidates:
        try:
            obj = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        # The first candidate that parses decides the outcome, even when it is
        # not a dict; the brace-delimited fallback is only for parse failures.
        return (obj, True) if isinstance(obj, dict) else (None, False)
    return None, False
|
|
|
|
|
def jsonish_field_payload(text: str, key: str, following_keys: Iterable[str]) -> str:
    """Return the raw text that follows ``"key":`` in ``text``.

    The payload ends at the earliest occurrence of any quoted key from
    ``following_keys``, or at the end of the text.  Surrounding whitespace and
    trailing commas are removed.  Returns ``""`` when the key is absent.
    """
    header = re.search(r'"' + re.escape(key) + r'"\s*:', text)
    if header is None:
        return ""
    tail = text[header.end() :]
    cut_points = [len(tail)]
    for following_key in following_keys:
        hit = re.search(r',?\s*"' + re.escape(following_key) + r'"\s*:', tail)
        if hit:
            cut_points.append(hit.start())
    return tail[: min(cut_points)].strip().rstrip(",")
|
|
|
|
|
def clean_jsonish_string(value: Any) -> str:
    """Clean a string taken from JSON-like text.

    Leading and trailing double quotes (and the whitespace next to them) are
    peeled off, escaped quotes are unescaped, and whitespace is normalised.
    """
    text = safe_text(value).strip()
    while text[:1] == '"':
        text = text[1:].strip()
    while text[-1:] == '"':
        text = text[:-1].strip()
    unescaped = text.replace('\\"', '"')
    return safe_text(unescaped)
|
|
|
|
|
def dedupe_texts(values: Iterable[Any], max_items: int = 64) -> List[str]:
    """Clean each value and return the distinct non-empty results in order.

    Each value goes through ``normalize_evidence_id`` and then
    ``clean_jsonish_string``.  Collection stops once ``max_items`` texts have
    been gathered.
    """
    unique: List[str] = []
    seen = set()
    for value in values:
        text = clean_jsonish_string(normalize_evidence_id(value))
        if text and text not in seen:
            unique.append(text)
            seen.add(text)
            if len(unique) >= max_items:
                break
    return unique
|
|
|
|
|
def jsonish_string_list(payload: str, max_items: int = 64) -> List[str]:
    """Extract a de-duplicated list of strings from a JSON-like array payload.

    The text between the first ``[`` and the last ``]`` is scanned when those
    brackets are present.  Double-quoted strings are collected; if there are
    none, bare evidence ids (``E`` plus 2-8 digits) are collected instead.
    """
    body = payload.strip()
    open_at = body.find("[")
    if open_at >= 0:
        body = body[open_at + 1 :]
    close_at = body.rfind("]")
    if close_at >= 0:
        body = body[:close_at]
    # findall with one capturing group returns that group's text per match.
    values = re.findall(r'"([^"\\]*(?:\\.[^"\\]*)*)"', body)
    if not values:
        values = re.findall(r"\bE\d{2,8}\b", body)
    return dedupe_texts(values, max_items=max_items)
|
|
|
|
|
def normalize_function_entries(values: Any) -> List[Any]:
    """Clean a parsed list of function entries.

    Dict entries become ``{"name", "evidence_ids"}`` with a cleaned name and
    up to 8 de-duplicated evidence ids.  Other entries become cleaned strings.
    Entries whose name cleans to empty are dropped.  Non-list input yields
    ``[]``.
    """
    if not isinstance(values, list):
        return []
    entries: List[Any] = []
    for value in values:
        if not isinstance(value, dict):
            plain_name = clean_jsonish_string(value)
            if plain_name:
                entries.append(plain_name)
            continue
        name = clean_jsonish_string(value.get("name"))
        evidence_ids = dedupe_texts(value.get("evidence_ids", []) or [], max_items=8)
        if name:
            entries.append({"name": name, "evidence_ids": evidence_ids})
    return entries
|
|
|
|
|
def jsonish_function_entries(payload: str) -> List[Any]:
    """Recover function entries from a JSON-like payload.

    A payload that looks like a complete JSON array is parsed and normalised
    first.  If that is not possible or yields nothing, every flat ``{...}``
    group (or the whole payload when there is none) is scanned with regexes
    for a ``"name"`` string and evidence ids.
    """
    body = payload.strip()
    if body.startswith("[") and body.endswith("]"):
        try:
            parsed = json.loads(body)
        except json.JSONDecodeError:
            parsed = None
        # normalize_function_entries returns [] for non-list input, so a failed
        # parse simply falls through to the regex scan below.
        normalized = normalize_function_entries(parsed)
        if normalized:
            return normalized

    object_bodies = re.findall(r"\{([^{}]*)\}", body) or [body]
    recovered: List[Any] = []
    for object_body in object_bodies:
        name_match = re.search(r'"name"\s*:\s*"([^"\\]*(?:\\.[^"\\]*)*)"', object_body)
        if name_match is None:
            continue
        name = clean_jsonish_string(name_match.group(1))
        evidence_ids = dedupe_texts(re.findall(r"\bE\d{2,8}\b", object_body), max_items=8)
        if name:
            recovered.append({"name": name, "evidence_ids": evidence_ids})
    return recovered
|
|
|
|
|
def jsonish_prediction_object(text: str) -> Optional[Dict[str, Any]]:
    """Salvage a prediction dict from JSON-like text that does not parse.

    The text is repaired, then each Chinese section is cut out by key and
    decoded with the lenient helpers.  Returns ``None`` when every section
    comes back empty.
    """
    repaired = repair_generated_json_text(text)

    summary_payload = jsonish_field_payload(repaired, "画面总结", ("可见文字", "互动数据", "功能入口", "关键证据"))
    visible_payload = jsonish_field_payload(repaired, "可见文字", ("互动数据", "功能入口", "关键证据"))
    interaction_payload = jsonish_field_payload(repaired, "互动数据", ("功能入口", "关键证据"))
    function_payload = jsonish_field_payload(repaired, "功能入口", ("关键证据",))
    evidence_payload = jsonish_field_payload(repaired, "关键证据", ())

    prediction = {
        "画面总结": clean_jsonish_string(summary_payload),
        "可见文字": jsonish_string_list(visible_payload, max_items=32),
        "互动数据": jsonish_string_list(interaction_payload, max_items=16),
        "功能入口": jsonish_function_entries(function_payload),
        "关键证据": jsonish_string_list(evidence_payload, max_items=16),
    }
    if not any(prediction.values()):
        return None
    return prediction
|
|
|
|
|
# Section labels of the natural-text format, in the order written by
# target_to_natural_text; natural_prediction_from_text relies on this order
# to decide where each field's payload ends.
NATURAL_OUTPUT_KEYS = ("画面总结", "可见文字", "互动数据", "功能入口", "关键证据")
|
|
|
|
def has_natural_title_prefix(text: str) -> bool:
    """Return True when ``text`` starts with a ``Title``/``title``/``标题`` prefix.

    The prefix word must be followed by a colon (ASCII ``:`` or full-width,
    U+FF1A) or by whitespace.  The colon class previously listed the ASCII
    colon twice, so a full-width colon was not recognised.  It is written as a
    ``\\uff1a`` escape so Unicode normalisation of this file cannot fold it
    back to ASCII.
    """
    title_pattern = r"^\s*(?:Title|title|标题)(?:\s*[:\uff1a]|\s+)"
    return re.search(title_pattern, str(text or "")) is not None
|
|
|
|
def strip_natural_title_prefix(text: str) -> Tuple[str, bool]:
    """Remove a leading ``Title``/``标题`` prefix from generated natural text.

    Line endings are normalised to ``\\n`` first.  If the text that follows
    the prefix and precedes the first known section label is non-empty, and
    that first label is not ``画面总结``, the text is promoted to a ``画面总结``
    line.  Otherwise it is dropped.

    Both ASCII and full-width colons (U+FF1A) are accepted after the prefix
    and after section labels.  The original character classes listed the
    ASCII colon twice.

    Returns:
        ``(cleaned_text, changed)`` where ``changed`` is True iff a prefix was
        found.
    """
    raw = str(text or "").strip().replace("\r\n", "\n").replace("\r", "\n")
    prefix_match = re.match(r"^\s*(?:Title|title|标题)(?:\s*[:\uff1a]|\s+)", raw) if raw else None
    if prefix_match is None:
        return raw, False

    key_pattern = r"(" + "|".join(re.escape(key) for key in NATURAL_OUTPUT_KEYS) + r")\s*[:\uff1a]"
    first_key = re.search(key_pattern, raw)
    if first_key is None:
        # No section labels at all: just drop the prefix.
        return raw[prefix_match.end() :].strip(), True

    prefix_body = safe_text(raw[prefix_match.end() : first_key.start()])
    rest = raw[first_key.start() :].strip()
    if first_key.group(1) == "画面总结" or not prefix_body:
        return rest, True
    return f"画面总结:{prefix_body}\n{rest}", True
|
|
|
|
def natural_field_payload(text: str, key: str, following_keys: Iterable[str]) -> str:
    """Return the text that follows the ``key`` label in natural-format text.

    The payload ends at the earliest later label from ``following_keys``, or
    at the end of the text.  Returns ``""`` when the label is absent.

    A label is the key followed by optional whitespace and a colon.  Both the
    ASCII colon and the full-width colon (U+FF1A) are accepted.  The original
    character class listed the ASCII colon twice, so full-width-delimited
    fields were missed.
    """
    colon = r"\s*[:\uff1a]"
    header = re.search(re.escape(key) + colon, text)
    if header is None:
        return ""
    tail = text[header.end() :]
    end = len(tail)
    for following_key in following_keys:
        next_match = re.search(re.escape(following_key) + colon, tail)
        if next_match:
            end = min(end, next_match.start())
    return tail[:end].strip()
|
|
|
|
|
def split_natural_items(payload: str, max_items: int = 64) -> List[str]:
    """Split a natural-format list payload into cleaned, distinct items.

    Separators are newlines, semicolons (ASCII or full-width U+FF1B), ``|``,
    and ``、`` when it is not directly preceded by a digit.  If none of those
    occur, the payload is split on ASCII commas instead.  The placeholders
    ``无``/``暂无``/``没有``/``[]`` and empty input give ``[]``.

    The original separator class listed the ASCII semicolon twice.  The
    full-width variant is restored as a ``\\uff1b`` escape.
    """
    payload = safe_text(payload)
    if not payload or payload in {"无", "暂无", "没有", "[]"}:
        return []
    parts = re.split(r"[\n;\uff1b|]+|(?<!\d)、", payload)
    if len(parts) == 1:
        parts = re.split(r"\s*,\s*", payload)
    return dedupe_texts(parts, max_items=max_items)
|
|
|
|
|
def natural_function_entries(payload: str) -> List[Any]:
    """Parse the natural-format function-entry payload into entry dicts.

    Each item produced by ``split_natural_items`` (max 32) yields
    ``{"name", "evidence_ids"}``.  ``evidence_ids`` holds up to 8 distinct
    ``E`` + 2-8 digit ids found in the item.  ``name`` is the item with its
    trailing evidence annotation removed.  Items whose name ends up empty are
    skipped.

    The annotation starts at an optional opening parenthesis (ASCII or
    full-width), followed by ``证据`` and a colon (ASCII or full-width), and
    runs to the end of the item.  This replaces the previous pattern
    ``(?证据...``, which is not a valid regular expression and made
    ``re.sub`` raise ``re.error`` for every item.
    """
    # Escapes keep the full-width forms intact even if this file is
    # Unicode-normalised again.
    annotation = r"\s*[(\uff08]?\s*证据[:\uff1a].*$"
    entries: List[Any] = []
    for item in split_natural_items(payload, max_items=32):
        evidence_ids = dedupe_texts(re.findall(r"\bE\d{2,8}\b", normalize_evidence_id(item)), max_items=8)
        name = re.sub(annotation, "", item).strip()
        if name:
            entries.append({"name": name, "evidence_ids": evidence_ids})
    return entries
|
|
|
|
|
def natural_prediction_from_text(text: str) -> Dict[str, Any]:
    """Parse natural-format generated text into a Chinese-key prediction dict.

    A title prefix is stripped first.  Each labelled section is cut out by
    label order.  A missing summary falls back to the text before the first
    newline or ``。``.  When the evidence section holds no ids, the ids
    attached to function entries are used instead.
    """
    text, _ = strip_natural_title_prefix(text)
    fields: Dict[str, str] = {}
    for index, key in enumerate(NATURAL_OUTPUT_KEYS):
        fields[key] = natural_field_payload(text, key, NATURAL_OUTPUT_KEYS[index + 1 :])

    summary = safe_text(fields.get("画面总结"))
    if not summary:
        leading_sentence = re.split(r"[\n。]", text, maxsplit=1)[0]
        summary = safe_text(leading_sentence)

    visible_text = split_natural_items(fields.get("可见文字", ""), max_items=32)
    interaction_data = split_natural_items(fields.get("互动数据", ""), max_items=16)
    functions = natural_function_entries(fields.get("功能入口", ""))

    evidence_source = normalize_evidence_id(fields.get("关键证据", ""))
    evidence_ids = dedupe_texts(re.findall(r"\bE\d{2,8}\b", evidence_source), max_items=16)
    if not evidence_ids:
        evidence_ids = dedupe_texts(
            eid for func in functions if isinstance(func, dict) for eid in func.get("evidence_ids", [])
        )

    return {
        "画面总结": summary,
        "可见文字": visible_text,
        "互动数据": interaction_data,
        "功能入口": functions,
        "关键证据": evidence_ids,
    }
|
|
|
|
|
def safe_json_loads_with_repair(text: str) -> Tuple[Optional[Dict[str, Any]], bool, bool, bool]:
    """Parse generated text into a dict, escalating through repair strategies.

    Order: strict JSON parse, then JSON parse of the repaired text, then
    field-by-field salvage of the repaired text.

    Returns:
        ``(obj, parsed, repaired, strict)``.  ``parsed`` says whether any
        strategy succeeded, ``repaired`` whether the result relied on a
        modification of the text, and ``strict`` whether the untouched text
        parsed.
    """
    obj, strict_ok = json_object_from_text(text)
    if strict_ok:
        return obj, True, False, True

    repaired_text = repair_generated_json_text(text)
    obj, repaired_ok = json_object_from_text(repaired_text)
    if repaired_ok:
        text_changed = repaired_text.strip() != str(text or "").strip()
        return obj, True, text_changed, False

    salvaged = jsonish_prediction_object(repaired_text)
    if salvaged is None:
        return None, False, False, False
    return salvaged, True, True, False
|
|
|
|
|
def safe_json_loads(text: str) -> Tuple[Optional[Dict[str, Any]], bool]:
    """Return only ``(obj, parsed)`` from ``safe_json_loads_with_repair``."""
    result = safe_json_loads_with_repair(text)
    return result[0], result[1]
|
|
|
|
|
def load_seq2seq_tokenizer(model_name_or_path: str, model_hint: str = ""):
    """Load the tokenizer for the decoder model.

    ``T5Tokenizer`` is used when the path is a directory containing
    ``spiece.model`` or when the path or hint mentions ``mt5``/``t5``
    (case-insensitive).  In that case, if the path is a directory without
    ``spiece.model`` and the hint directory has one, the tokenizer is loaded
    from the hint directory instead.  Everything else goes through
    ``AutoTokenizer`` with ``use_fast=False``.
    """
    load_path = str(model_name_or_path)
    model_dir = Path(load_path)
    is_local_dir = model_dir.is_dir()
    has_spiece = is_local_dir and (model_dir / "spiece.model").exists()
    name = f"{model_name_or_path} {model_hint}".lower()

    if not (has_spiece or "mt5" in name or "t5" in name):
        return AutoTokenizer.from_pretrained(model_name_or_path, use_fast=False)

    # A local directory without the SentencePiece model: borrow the vocabulary
    # from the hint directory if it has one.
    if is_local_dir and not has_spiece and model_hint:
        hint_dir = Path(str(model_hint))
        if (hint_dir / "spiece.model").exists():
            load_path = str(hint_dir)
    return T5Tokenizer.from_pretrained(load_path, fix_mistral_regex=True)
|
|
|
|
def char_lcs(a: str, b: str) -> int:
    """Return the length of the longest common subsequence of ``a`` and ``b``.

    Uses a row-by-row dynamic programme that keeps only the previous row.
    Returns 0 when either string is empty.
    """
    if not (a and b):
        return 0
    width = len(b)
    previous = [0] * (width + 1)
    for char_a in a:
        current = [0] * (width + 1)
        for col in range(1, width + 1):
            if char_a == b[col - 1]:
                current[col] = previous[col - 1] + 1
            else:
                current[col] = max(current[col - 1], previous[col])
        previous = current
    return previous[width]
|
|
|
|
|
def rouge_l_char(pred: str, ref: str) -> float:
    """Return the character-level ROUGE-L F1 between ``pred`` and ``ref``.

    Precision and recall are the LCS length divided by the length of ``pred``
    and ``ref`` respectively.  Returns 0.0 when either string is empty or
    there is no overlap.
    """
    if not pred or not ref:
        return 0.0
    overlap = char_lcs(pred, ref)
    precision = overlap / max(1, len(pred))
    recall = overlap / max(1, len(ref))
    total = precision + recall
    return 0.0 if total == 0 else 2 * precision * recall / total
|
|
|
|
|
def extract_summary(obj: Optional[Dict[str, Any]]) -> str:
    """Return the cleaned summary from a prediction or target dict.

    Looks up ``画面总结``, then ``summary_zh``, then ``summary``.  Returns
    ``""`` for a missing or empty ``obj``.
    """
    if not obj:
        return ""
    raw_summary = obj.get("画面总结") or obj.get("summary_zh") or obj.get("summary")
    return safe_text(raw_summary)
|
|
|
|
|
def extract_evidence_ids(obj: Optional[Dict[str, Any]]) -> List[str]:
    """Collect normalised evidence ids from the evidence section of ``obj``.

    The section is looked up under ``关键证据``, ``key_ui_clues`` or
    ``evidence``.  String entries count as ids.  Dict entries contribute their
    ``evidence_ids``.  Empty ids are dropped.  Duplicates and order are kept.

    A dict entry whose ``evidence_ids`` is ``None`` or missing is treated as
    having no ids, and a non-dict ``obj`` yields ``[]`` instead of raising.
    Each id is normalised once.
    """
    if not isinstance(obj, dict):
        return []
    values = obj.get("关键证据") or obj.get("key_ui_clues") or obj.get("evidence") or []
    if not isinstance(values, list):
        return []

    collected: List[str] = []
    for value in values:
        if isinstance(value, str):
            candidates = [value]
        elif isinstance(value, dict):
            candidates = value.get("evidence_ids") or []
        else:
            continue
        for candidate in candidates:
            normalized = normalize_evidence_id(candidate)
            if normalized:
                collected.append(normalized)
    return collected
|
|
|
|
|
def extract_function_entries(obj: Optional[Dict[str, Any]]) -> List[Any]:
    """Return the raw function-entry list from ``obj``.

    Looks up ``功能入口`` and then ``ui_functions``.  Returns ``[]`` when
    ``obj`` is not a dict or the value found is not a list.
    """
    if not isinstance(obj, dict):
        return []
    entries = obj.get("功能入口") or obj.get("ui_functions") or []
    if isinstance(entries, list):
        return entries
    return []
|
|
|
|
|
def extract_function_names(obj: Optional[Dict[str, Any]]) -> List[str]:
    """Return the non-empty cleaned names of the function entries in ``obj``.

    Dict entries contribute their ``name``.  Any other entry is stringified.
    """
    cleaned = (
        safe_text(entry.get("name")) if isinstance(entry, dict) else safe_text(entry)
        for entry in extract_function_entries(obj)
    )
    return [name for name in cleaned if name]
|
|
|
|
|
def extract_function_evidence_ids(target: Dict[str, Any]) -> List[str]:
    """Collect the normalised evidence ids attached to dict function entries.

    Empty ids are dropped.  Duplicates and order are kept.  An entry whose
    ``evidence_ids`` is ``None`` or missing contributes nothing instead of
    raising, and each id is normalised once.
    """
    ids: List[str] = []
    for entry in extract_function_entries(target):
        if not isinstance(entry, dict):
            continue
        for raw_id in entry.get("evidence_ids") or []:
            normalized = normalize_evidence_id(raw_id)
            if normalized:
                ids.append(normalized)
    return ids
|
|
|
|
|
def extract_named_function_evidence_ids(target: Dict[str, Any], keyword: str) -> List[str]:
    """Collect evidence ids of dict function entries whose name contains ``keyword``.

    Empty ids are dropped.  Duplicates and order are kept.  An entry whose
    ``evidence_ids`` is ``None`` or missing contributes nothing instead of
    raising, and each id is normalised once.
    """
    ids: List[str] = []
    for entry in extract_function_entries(target):
        if not isinstance(entry, dict) or keyword not in safe_text(entry.get("name")):
            continue
        for raw_id in entry.get("evidence_ids") or []:
            normalized = normalize_evidence_id(raw_id)
            if normalized:
                ids.append(normalized)
    return ids
|
|
|
|
|
def has_search_function(obj: Optional[Dict[str, Any]]) -> bool:
    """Return True when any function name in ``obj`` contains ``搜索``."""
    for name in extract_function_names(obj):
        if "搜索" in name:
            return True
    return False
|
|
|
|
|
def count_search_functions(obj: Optional[Dict[str, Any]]) -> int:
    """Count the function names in ``obj`` that contain ``搜索``."""
    search_names = [name for name in extract_function_names(obj) if "搜索" in name]
    return len(search_names)
|
|
|
|
|
def count_bare_search_functions(obj: Optional[Dict[str, Any]]) -> int:
    """Count the function names in ``obj`` that are exactly the bare search name."""
    bare_names = [name for name in extract_function_names(obj) if is_bare_search_function_name(name)]
    return len(bare_names)
|
|
|
|
|
def normalized_prediction_obj(obj: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Map a prediction dict with Chinese or English keys onto the Chinese-key schema.

    For each section the first truthy value among its accepted keys is used.
    List sections default to ``[]``.  Returns ``None`` when ``obj`` is not a
    dict.
    """
    if not isinstance(obj, dict):
        return None

    def first_truthy(*keys):
        for key in keys:
            candidate = obj.get(key)
            if candidate:
                return candidate
        return None

    # The summary falls back to the raw "summary" value, even when it is falsy.
    summary_value = first_truthy("画面总结", "summary_zh") or obj.get("summary")
    return {
        "画面总结": safe_text(summary_value),
        "可见文字": first_truthy("可见文字", "visible_text") or [],
        "互动数据": first_truthy("互动数据", "interaction_data") or [],
        "功能入口": first_truthy("功能入口", "ui_functions") or [],
        "关键证据": first_truthy("关键证据", "key_ui_clues", "evidence") or [],
    }
|
|
|
|
|
def element_id(item: Dict[str, Any]) -> str:
    """Return the cleaned ``id`` of ``item``, falling back to ``ocr_id``."""
    identifier = item.get("id")
    if not identifier:
        identifier = item.get("ocr_id")
    return safe_text(identifier)
|
|
|
|
|
def visible_text_from_row(row: Dict[str, Any], max_items: int = 16) -> List[str]:
    """Return up to ``max_items`` distinct visible texts from a row.

    OCR items are scanned before UI items, and de-duplication spans both
    groups.  Filtering and truncation are delegated to
    ``visible_texts_from_items``, which this function previously copied line
    for line.  The result is unchanged.
    """
    ocr_items = list(row.get("ocr_items") or [])
    ui_items = list(row.get("ui_items") or [])
    return visible_texts_from_items(ocr_items + ui_items, max_items=max_items)
|
|
|
|
|
def visible_texts_from_items(items: Iterable[Dict[str, Any]], max_items: int = 16) -> List[str]:
    """Return up to ``max_items`` distinct, meaningful ``text`` values from ``items``.

    Skipped: empty texts, repeats, bare ``E<digits>`` ids, single characters
    that are neither CJK nor ASCII alphanumerics, and texts made only of
    whitespace and the punctuation ``- _ . : /``.  Kept texts are cut to 80
    characters.  De-duplication uses the uncut text.
    """

    def is_meaningful(text: str) -> bool:
        if not text or re.fullmatch(r"E\d+", text):
            return False
        if len(text) == 1 and not re.search(r"[\u4e00-\u9fffA-Za-z0-9]", text):
            return False
        return re.fullmatch(r"[\s\-_.::/]+", text) is None

    collected: List[str] = []
    seen = set()
    for item in items or []:
        text = safe_text(item.get("text"))
        if text in seen or not is_meaningful(text):
            continue
        collected.append(text[:80])
        seen.add(text)
        if len(collected) >= max_items:
            break
    return collected
|
|
|
|
def maybe_dropout_texts(values: List[str], dropout_rate: float) -> List[str]:
    """Randomly drop entries of ``values`` with probability ``dropout_rate``.

    The input list itself is returned when it is empty or the rate is not
    positive.  If every entry would be dropped, one randomly chosen entry is
    kept so the result is never empty.
    """
    if not values or dropout_rate <= 0.0:
        return values
    survivors = [value for value in values if random.random() >= dropout_rate]
    return survivors if survivors else [random.choice(values)]
|
|
|
|
def build_context_text(row: Dict[str, Any], args: argparse.Namespace, screen_text_dropout_rate: float = 0.0) -> str:
    """Build the textual context (app, task, optional screen text) for a row.

    With ``context_text_format == "text_only"`` the result is newline-separated
    ``app:``/``task:`` lines plus optional ``ocr:``/``ui:`` lines.  Otherwise
    it is a single Chinese line, optionally extended with the screen text.
    Screen text is included only when ``context_include_screen_text`` is set,
    and is subject to random dropout at ``screen_text_dropout_rate``.
    """
    app_text = safe_text(row.get("app"))
    instruction_text = safe_text(row.get("instruction"))
    context_format = str(getattr(args, "context_text_format", "rich") or "rich").lower()
    include_screen_text = bool(getattr(args, "context_include_screen_text", False))

    if context_format == "text_only":
        lines = [f"app: {app_text}", f"task: {instruction_text}"]
        if include_screen_text:
            item_limit = int(getattr(args, "context_screen_text_items", 32) or 32)
            ocr_text = visible_texts_from_items(row.get("ocr_items") or [], max_items=item_limit)
            ui_text = visible_texts_from_items(row.get("ui_items") or [], max_items=item_limit)
            # Dropout order (OCR first, then UI) fixes the random call sequence.
            ocr_text = maybe_dropout_texts(ocr_text, screen_text_dropout_rate)
            ui_text = maybe_dropout_texts(ui_text, screen_text_dropout_rate)
            if ocr_text:
                lines.append("ocr: " + " | ".join(ocr_text))
            if ui_text:
                lines.append("ui: " + " | ".join(ui_text))
        return "\n".join(lines)

    context_text = f"应用:{app_text} 任务:{instruction_text}"
    if not include_screen_text:
        return context_text
    item_limit = int(getattr(args, "context_screen_text_items", 32) or 32)
    screen_text = maybe_dropout_texts(
        visible_text_from_row(row, max_items=item_limit),
        screen_text_dropout_rate,
    )
    if screen_text:
        context_text = f"{context_text} 屏幕文字:{' | '.join(screen_text)}"
    return context_text
|
|
|
|
def prediction_from_summary(row: Dict[str, Any], summary_text: str) -> Dict[str, Any]:
    """Wrap a bare summary string in the Chinese-key prediction schema.

    All list sections are empty.  ``row`` is accepted but not used.
    """
    empty_sections = {key: [] for key in ("可见文字", "互动数据", "功能入口", "关键证据")}
    return {"画面总结": safe_text(summary_text), **empty_sections}
|
|
|
|
|
def repair_prediction_with_context(row: Dict[str, Any], pred_obj: Optional[Dict[str, Any]]) -> Tuple[Dict[str, Any], bool]:
    """Sanitise a prediction using the row's app, instruction and screen items.

    Returns ``(repaired, changed)``: the prediction in the Chinese-key schema
    and a flag that is set by the steps below whenever they replace or drop
    something.
    """
    app = safe_text(row.get("app")) or "移动应用"
    instruction = safe_text(row.get("instruction"))
    has_search_task = row_has_search_task(row)
    has_search_evidence = row_has_visible_search_evidence(row)
    repaired = normalized_prediction_obj(pred_obj)
    # A prediction that is not a dict at all counts as changed from the start.
    changed = repaired is None
    if repaired is None:
        repaired = {"画面总结": "", "可见文字": [], "互动数据": [], "功能入口": [], "关键证据": []}

    # --- summary: replace with a template when missing or implausible ---
    summary = safe_text(repaired.get("画面总结"))
    # Summary talks about a search page although the instruction has no "搜索".
    search_overuse = ("搜索页面" in summary or "搜索结果" in summary) and "搜索" not in instruction
    # Degenerate repetition of the word "App".
    generic_app = "App的App" in summary or "手机App的App" in summary or summary.count("App") >= 3
    # Note: a summary that does not mention the app name is also replaced.
    needs_summary = not summary or app not in summary or generic_app or search_overuse
    if needs_summary:
        if instruction:
            summary = f"这是一个{app}界面,页面包含若干文字内容和可操作的 UI 入口。 当前任务语境是:{instruction}。"
        else:
            summary = f"这是一个{app}界面,页面包含若干文字内容和可操作的 UI 入口。"
        repaired["画面总结"] = summary
        changed = True

    # --- visible text: drop bare evidence ids, cap at 16 entries ---
    visible = repaired.get("可见文字") or []
    if isinstance(visible, list):
        cleaned_visible = [value for value in visible if not re.fullmatch(r"E\d+", safe_text(value))][:16]
        if cleaned_visible != visible:
            repaired["可见文字"] = cleaned_visible
            changed = True
    else:
        repaired["可见文字"] = []
        changed = True

    # --- interaction data: must be a list ---
    if not isinstance(repaired.get("互动数据"), list):
        repaired["互动数据"] = []
        changed = True

    # --- functions: drop generic names and unsupported search entries ---
    functions = repaired.get("功能入口") or []
    if isinstance(functions, list):
        cleaned_functions = []
        for function in functions:
            name = safe_text(function.get("name") if isinstance(function, dict) else function)
            if is_generic_function_name(name, row):
                changed = True
                continue
            # A search function is kept only when the instruction is a search
            # task AND the screen shows search evidence.
            if "搜索" in name and not (has_search_task and has_search_evidence):
                changed = True
                continue
            cleaned_functions.append(function)
        # `changed` was already set inside the loop for every dropped entry.
        if cleaned_functions != functions:
            repaired["功能入口"] = cleaned_functions
    else:
        repaired["功能入口"] = []
        changed = True

    # --- evidence: fall back to the row's weak evidence ids (max 8) ---
    evidence = repaired.get("关键证据") or []
    if not isinstance(evidence, list):
        evidence = []
    if not evidence:
        evidence = list(row.get("weak_evidence_ids") or [])[:8]
        repaired["关键证据"] = evidence
        # Marked as changed even when the fallback list is itself empty.
        changed = True
    return repaired, changed
|
|
|
|
|
def row_has_visible_search_evidence(row: Dict[str, Any]) -> bool:
    """Return True when any UI or OCR item of ``row`` looks like a search control.

    The per-item test is ``is_search_ui_item``, whose logic this function
    previously duplicated inline.  ``ui_items``/``ocr_items`` that are missing
    or ``None`` are treated as empty instead of raising, as
    ``row_has_search_context`` already does.
    """
    items = list(row.get("ui_items") or []) + list(row.get("ocr_items") or [])
    return any(is_search_ui_item(item) for item in items)
|
|
|
|
|
def row_has_search_context(row: Dict[str, Any]) -> bool:
    """Return True when ``搜索`` occurs in the instruction or the joined screen text.

    Screen text is the concatenation, without separators, of all UI item texts
    followed by all OCR item texts.
    """
    instruction = safe_text(row.get("instruction"))
    item_texts = [
        safe_text(item.get("text"))
        for group in ("ui_items", "ocr_items")
        for item in (row.get(group) or [])
    ]
    screen_text = "".join(item_texts)
    return "搜索" in instruction or "搜索" in screen_text
|
|
|
|
|
def row_has_search_task(row: Dict[str, Any]) -> bool:
    """Return True when the instruction contains ``搜索``, ``查找`` or ``搜``."""
    instruction = safe_text(row.get("instruction"))
    return any(term in instruction for term in ("搜索", "查找", "搜"))
|
|
|
|
|
def is_generic_function_name(name: str, row: Dict[str, Any]) -> bool:
    """Return True when ``name`` is too generic to count as a UI function.

    Generic means: empty; one of a fixed list of placeholder names; an
    "enter/open <app>" phrase built from the row's app name; or any text that
    fully matches ``进入``/``打开`` followed by 1-12 characters and an optional
    ``App``/``应用``.
    """
    text = safe_text(name)
    if not text:
        return True

    placeholder_names = (
        "入口",
        "功能入口",
        "功能反馈",
        "打开应用",
        "进入应用",
        "打开App",
        "进入App",
    )
    if text in placeholder_names:
        return True

    app = safe_text(row.get("app"))
    if app:
        app_phrases = (f"{app}入口", f"进入{app}", f"打开{app}", f"进入{app}App", f"打开{app}App")
        if text in app_phrases:
            return True

    return re.fullmatch(r"(进入|打开).{1,12}(App|应用)?", text) is not None
|
|
|
|
|
def is_bare_search_function_name(name: str) -> bool:
    """Return True when the cleaned ``name`` is exactly ``搜索``."""
    cleaned = safe_text(name)
    return cleaned == "搜索"
|
|
|
|
|
def is_search_ui_item(item: Dict[str, Any]) -> bool:
    """Return True when a screen item looks search-related.

    An item qualifies when its text contains ``搜索`` or ``输入``, equals
    ``搜`` or ``搜索框``, or its lower-cased type contains ``search``.
    """
    text = safe_text(item.get("text"))
    item_type = safe_text(item.get("type")).lower()
    mentions_search = "搜索" in text or "search" in item_type
    looks_like_input = text in {"搜", "搜索框"} or "输入" in text
    return mentions_search or looks_like_input
|
|
|
|
|
def is_structured_search_candidate_item(item: Dict[str, Any], strict: bool = False) -> bool:
    """Decide whether a screen item is a candidate search control.

    Non-strict mode accepts every item that ``is_search_ui_item`` accepts.
    Strict mode additionally rejects texts containing a passive search phrase
    (history, results, recommendations, ...) and then accepts only texts that
    are a bare search label, contain a control phrase, or start with ``搜索``
    and are at most 8 characters long.
    """
    if not is_search_ui_item(item):
        return False
    if not strict:
        return True

    text = safe_text(item.get("text"))
    if text in {"搜索", "搜", "搜索框"}:
        return True

    passive_terms = (
        "历史搜索",
        "搜索历史",
        "搜索发现",
        "深度搜索",
        "AI搜索",
        "热门搜索",
        "最近搜索",
        "相关搜索",
        "搜索记录",
        "清除搜索记录",
        "搜索结果",
        "搜索来源",
        "搜索推荐",
        "根据你的搜索",
    )
    for term in passive_terms:
        if term in text:
            return False

    control_terms = (
        "搜索框",
        "搜索、提问",
        "搜索地点",
        "搜索店内",
        "搜索商品",
        "搜索景点",
        "搜索机票",
        "附近搜索",
        "请输入",
        "输入",
    )
    for term in control_terms:
        if term in text:
            return True

    return text.startswith("搜索") and len(text) <= 8
|
|
|
|
|
def structured_function_name(item: Dict[str, Any], row: Dict[str, Any]) -> str:
    """Derive a display name for the UI function represented by ``item``.

    Resolution order:
      1. Search-like items (per ``is_search_ui_item``) -> "搜索功能入口".
      2. Items with text: the text itself when it has at most 12 characters;
         otherwise a canonical label for cart/message/settings texts, else the
         first 12 characters.
      3. Text-less action targets: a label derived from the first matching
         keyword found in the row's instruction.
      4. Fallbacks: "按钮入口" for button-typed items, else "功能入口".

    Args:
        item: UI item dict (``text``, ``type``, ``is_action_target`` are read).
        row: Dataset row; only ``instruction`` is read.
    """
    label = safe_text(item.get("text"))
    kind = safe_text(item.get("type")).lower()

    if is_search_ui_item(item):
        return "搜索功能入口"

    if label:
        if len(label) <= 12:
            return label
        long_text_labels = (
            ("购物车", "购物车"),
            ("消息", "消息入口"),
            ("设置", "设置入口"),
        )
        for needle, canonical in long_text_labels:
            if needle in label:
                return canonical
        return label[:12]

    if item.get("is_action_target"):
        instruction = safe_text(row.get("instruction"))
        # These keywords are used verbatim; the rest get an "入口" suffix.
        verbatim_keywords = {"收藏", "关注", "分享", "返回"}
        for keyword in ("商城", "购物车", "消息", "订单", "收藏", "关注", "分享", "返回"):
            if keyword in instruction:
                return keyword if keyword in verbatim_keywords else f"{keyword}入口"

    # NOTE(review): the source's indentation was lost; this check is assumed to
    # sit at function level (not nested under is_action_target) — confirm.
    if "button" in kind:
        return "按钮入口"
    return "功能入口"
|
|
|
|
|
def apply_structured_function_predictions(
    row: Dict[str, Any],
    pred_obj: Optional[Dict[str, Any]],
    function_scores: torch.Tensor,
    search_scores: torch.Tensor,
    args: argparse.Namespace,
) -> Dict[str, Any]:
    """Fill the "功能入口" section of a prediction from per-element head scores.

    Args:
        row: Dataset row providing ``ui_items`` and the instruction.
        pred_obj: Decoder prediction, possibly None or not yet normalized.
        function_scores: Per-element scores, indexed in ``ui_items`` order.
        search_scores: Per-element search scores; elements beyond its length
            are scored 0.0.
        args: Reads ``structured_function_mode``, ``structured_function_threshold``,
            ``structured_search_threshold``, ``structured_max_functions``,
            ``structured_strict_search_candidates`` and ``max_elements``.

    Returns:
        In "decoder" mode (the default) the prediction is passed through:
        ``pred_obj`` itself when it is a dict, otherwise the result of
        ``normalized_prediction_obj`` (NOTE(review): this may be None, since
        the non-decoder path explicitly guards against None — confirm callers
        handle it). In any other mode, the normalized prediction with
        "功能入口" replaced by the selected, de-duplicated functions sorted by
        descending score and capped at ``structured_max_functions``.
    """
    mode = str(getattr(args, "structured_function_mode", "decoder") or "decoder").lower()
    if mode == "decoder":
        if isinstance(pred_obj, dict):
            return pred_obj
        return normalized_prediction_obj(pred_obj)

    structured = normalized_prediction_obj(pred_obj)
    if structured is None:
        structured = {"画面总结": "", "可见文字": [], "互动数据": [], "功能入口": [], "关键证据": []}

    function_threshold = float(getattr(args, "structured_function_threshold", 0.5) or 0.5)
    search_threshold = float(getattr(args, "structured_search_threshold", function_threshold) or function_threshold)
    max_functions = int(getattr(args, "structured_max_functions", 12) or 12)

    # Search entries are only allowed when the task asks for a search AND the
    # screen shows something search-like.
    has_search_task = row_has_search_task(row)
    has_search_evidence = row_has_visible_search_evidence(row)
    search_allowed = has_search_task and has_search_evidence

    strict_search_candidates = bool(getattr(args, "structured_strict_search_candidates", False))
    element_cap = int(getattr(args, "max_elements", len(function_scores)))
    scored_items = (row.get("ui_items") or [])[:element_cap][: len(function_scores)]

    scored: List[Tuple[float, Dict[str, Any]]] = []
    for idx, item in enumerate(scored_items):
        function_score = float(function_scores[idx])
        search_score = float(search_scores[idx]) if idx < len(search_scores) else 0.0

        looks_like_search = is_search_ui_item(item)
        is_search_candidate = is_structured_search_candidate_item(item, strict=strict_search_candidates)
        if looks_like_search and not is_search_candidate:
            # Search-like text that the strict filter rejected is dropped
            # entirely instead of falling back to the generic function head.
            continue

        if is_search_candidate:
            score = search_score
            accepted = search_allowed and search_score >= search_threshold
        else:
            score = function_score
            accepted = function_score >= function_threshold
        if not accepted:
            continue

        evidence_id = safe_text(item.get("id") or item.get("ocr_id"))
        if not evidence_id:
            continue

        name = structured_function_name(item, row)
        if is_generic_function_name(name, row):
            continue
        if "搜索" in name and not search_allowed:
            continue

        scored.append((score, {"name": name, "evidence_ids": [evidence_id]}))

    ranked = sorted(scored, key=lambda pair: pair[0], reverse=True)

    functions: List[Dict[str, Any]] = []
    emitted = set()
    for _, function in ranked:
        identity = (function["name"], tuple(function.get("evidence_ids", [])))
        if identity in emitted:
            continue
        emitted.add(identity)
        functions.append(function)
        if len(functions) >= max_functions:
            break

    structured["功能入口"] = functions
    return structured
|
|
|
|
|
def apply_structured_evidence_predictions(
    row: Dict[str, Any],
    pred_obj: Optional[Dict[str, Any]],
    evidence_scores: torch.Tensor,
    args: argparse.Namespace,
) -> Dict[str, Any]:
    """Fill the "关键证据" section of a prediction from per-element head scores.

    Args:
        row: Dataset row providing ``ui_items``.
        pred_obj: Decoder prediction, possibly None or not yet normalized.
        evidence_scores: Per-element scores, indexed in ``ui_items`` order.
        args: Reads ``structured_evidence_mode``, ``structured_evidence_threshold``,
            ``structured_max_evidence``, ``structured_evidence_fallback_top1``
            and ``max_elements``.

    Returns:
        The normalized prediction (an empty skeleton if normalization yields
        None). In "decoder" mode it is returned unchanged; in "heads" mode
        "关键证据" is replaced by the ids of elements scoring at least the
        threshold, best first, without duplicates and capped at
        ``structured_max_evidence``. If nothing passes and the fallback flag is
        set, the single best-scoring id is used.

    Raises:
        ValueError: If the mode is neither "decoder" nor "heads".
    """
    mode = str(getattr(args, "structured_evidence_mode", "decoder") or "decoder").lower()

    structured = normalized_prediction_obj(pred_obj)
    if structured is None:
        structured = {"画面总结": "", "可见文字": [], "互动数据": [], "功能入口": [], "关键证据": []}

    if mode == "decoder":
        return structured
    if mode != "heads":
        raise ValueError("structured_evidence_mode must be one of: decoder, heads")

    threshold = float(getattr(args, "structured_evidence_threshold", 0.5) or 0.5)
    max_evidence = int(getattr(args, "structured_max_evidence", 8) or 8)
    fallback_top1 = bool(getattr(args, "structured_evidence_fallback_top1", True))

    element_cap = int(getattr(args, "max_elements", len(evidence_scores)))
    scored_items = (row.get("ui_items") or [])[:element_cap][: len(evidence_scores)]

    ranked: List[Tuple[float, str]] = []
    for idx, item in enumerate(scored_items):
        eid = element_id(item)
        if eid:
            ranked.append((float(evidence_scores[idx]), eid))
    ranked.sort(key=lambda pair: pair[0], reverse=True)

    selected: List[str] = []
    taken = set()
    for score, eid in ranked:
        if score < threshold or eid in taken:
            continue
        selected.append(eid)
        taken.add(eid)
        if len(selected) >= max_evidence:
            break

    if fallback_top1 and ranked and not selected:
        selected.append(ranked[0][1])

    structured["关键证据"] = selected
    return structured
|
|
|
|
|
def build_context_summary(row: Dict[str, Any]) -> str:
    """Build a generic Chinese one-line summary from the row's app and task.

    The app name falls back to "移动应用" when the row has none; the task
    sentence is appended only when the row has a non-empty instruction.
    """
    app = safe_text(row.get("app")) or "移动应用"
    instruction = safe_text(row.get("instruction"))
    summary = f"这是一个{app}界面,页面包含若干文字内容和可操作的 UI 入口。"
    if instruction:
        summary += f" 当前任务语境是:{instruction}。"
    return summary
|
|
|
|
|
def canonicalize_target_with_context(
    row: Dict[str, Any],
    target: Dict[str, Any],
    drop_bare_search_functions: bool = False,
) -> Dict[str, Any]:
    """Return a cleaned shallow copy of ``target`` using context from ``row``.

    Normalization steps:
      * ``summary_zh``: filled with ``build_context_summary(row)`` when empty.
      * ``visible_text``: entries that are bare element ids ("E<digits>") are
        dropped and the list is capped at 16; non-list values become ``[]``.
      * ``ui_functions``: generic names, optionally bare "搜索" names, and any
        search-named function without both a search task and visible search
        evidence are dropped; capped at 12; non-list values become ``[]``.
      * ``interaction_data``: replaced by ``[]`` when it is not a list.
      * ``key_ui_clues``: falls back to the first 8 ``weak_evidence_ids`` of the
        row when empty or not a list; capped at 12.

    Args:
        row: Dataset row supplying app/instruction/items/weak evidence.
        target: Raw target dict (``None`` is treated as empty).
        drop_bare_search_functions: Also drop functions named exactly "搜索".

    Returns:
        A new dict; ``target`` itself is not reassigned, but nested values
        that are kept are shared with it (shallow copy).
    """
    canonical = dict(target or {})

    if not safe_text(canonical.get("summary_zh")):
        canonical["summary_zh"] = build_context_summary(row)

    visible = canonical.get("visible_text") or []
    if isinstance(visible, list):
        readable = []
        for value in visible:
            if re.fullmatch(r"E\d+", safe_text(value)):
                continue
            readable.append(value)
        canonical["visible_text"] = readable[:16]
    else:
        canonical["visible_text"] = []

    functions = canonical.get("ui_functions") or []
    if isinstance(functions, list):
        has_search_task = row_has_search_task(row)
        has_search_evidence = row_has_visible_search_evidence(row)
        search_allowed = has_search_task and has_search_evidence

        def should_keep(function: Any) -> bool:
            # Functions may be dicts with a "name" or plain strings.
            raw_name = function.get("name") if isinstance(function, dict) else function
            name = safe_text(raw_name)
            if is_generic_function_name(name, row):
                return False
            if drop_bare_search_functions and is_bare_search_function_name(name):
                return False
            if "搜索" in name and not search_allowed:
                return False
            return True

        canonical["ui_functions"] = [function for function in functions if should_keep(function)][:12]
    else:
        canonical["ui_functions"] = []

    if not isinstance(canonical.get("interaction_data"), list):
        canonical["interaction_data"] = []

    evidence = canonical.get("key_ui_clues") or []
    if not isinstance(evidence, list):
        evidence = []
    # NOTE(review): indentation was lost in the source; the fallback is assumed
    # to apply to any empty evidence list, not only to non-list values.
    if not evidence:
        evidence = list(row.get("weak_evidence_ids") or [])[:8]
    canonical["key_ui_clues"] = evidence[:12]
    return canonical
|
|
|
|
|
class RichScreenshotDataset(Dataset):
    """In-memory dataset of rows read from a JSONL file.

    Items are returned as the raw row dicts; all tensor construction happens
    later in the collator.
    """

    def __init__(self, path: str, max_samples: int = 0, sample_seed: Optional[int] = None):
        """Load all rows and optionally keep only ``max_samples`` of them.

        Args:
            path: JSONL file to read via ``read_jsonl``.
            max_samples: When non-zero and smaller than the row count, keep
                only this many rows.
            sample_seed: ``None`` keeps the first ``max_samples`` rows;
                otherwise a seeded random subset is taken and kept in file
                order.

        Raises:
            ValueError: If no rows remain after loading.
        """
        self.path = Path(path)
        rows = list(read_jsonl(self.path))

        if max_samples and max_samples < len(rows):
            if sample_seed is None:
                rows = rows[:max_samples]
            else:
                picker = random.Random(sample_seed)
                # Sorting the sampled indices preserves the file's row order.
                chosen = sorted(picker.sample(range(len(rows)), max_samples))
                rows = [rows[index] for index in chosen]

        self.rows = rows
        if not self.rows:
            raise ValueError(f"No rows loaded from {path}")

    def __len__(self) -> int:
        """Return the number of retained rows."""
        return len(self.rows)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Return the raw row at position ``idx``."""
        return self.rows[idx]
|
|
|
|
def dataset_diagnostics(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Compute data-quality counters and size statistics for a list of rows.

    Per row this checks whether ``target`` is a dict, whether its
    ``summary_zh`` has at least 8 characters, whether ``image_path`` is set and
    exists on disk, and how many ``ocr_items`` / ``ui_items`` it has.

    Returns:
        A dict with the row count, the ``missing_*`` / ``short_summary``
        counters, the missing OCR/UI rates, and mean/max of the OCR item
        count, UI item count and summary length.
    """
    tallies = {
        "missing_target": 0,
        "short_summary": 0,
        "missing_image_path": 0,
        "missing_image_file": 0,
        "missing_ocr_items": 0,
        "missing_ui_items": 0,
    }
    ocr_counts: List[int] = []
    ui_counts: List[int] = []
    summary_lengths: List[int] = []

    for row in rows:
        target = row.get("target")
        if isinstance(target, dict):
            summary = safe_text(target.get("summary_zh"))
        else:
            tallies["missing_target"] += 1
            summary = ""
        # NOTE(review): indentation was lost in the source; rows without a
        # target are assumed to contribute a length of 0 here — confirm.
        summary_lengths.append(len(summary))
        if len(summary) < 8:
            tallies["short_summary"] += 1

        image_path = safe_text(row.get("image_path"))
        if not image_path:
            tallies["missing_image_path"] += 1
        elif not Path(image_path).exists():
            tallies["missing_image_file"] += 1

        n_ocr = len(row.get("ocr_items") or [])
        n_ui = len(row.get("ui_items") or [])
        ocr_counts.append(n_ocr)
        ui_counts.append(n_ui)
        if n_ocr <= 0:
            tallies["missing_ocr_items"] += 1
        if n_ui <= 0:
            tallies["missing_ui_items"] += 1

    total_rows = len(rows)
    denominator = max(1, total_rows)

    def average(values: List[int]) -> float:
        return float(np.mean(values)) if values else 0.0

    def largest(values: List[int]) -> int:
        return int(max(values)) if values else 0

    return {
        "rows": total_rows,
        "missing_target": tallies["missing_target"],
        "short_summary": tallies["short_summary"],
        "missing_image_path": tallies["missing_image_path"],
        "missing_image_file": tallies["missing_image_file"],
        "missing_ocr_items": tallies["missing_ocr_items"],
        "missing_ui_items": tallies["missing_ui_items"],
        "missing_ocr_rate": tallies["missing_ocr_items"] / denominator,
        "missing_ui_rate": tallies["missing_ui_items"] / denominator,
        "ocr_items_mean": average(ocr_counts),
        "ocr_items_max": largest(ocr_counts),
        "ui_items_mean": average(ui_counts),
        "ui_items_max": largest(ui_counts),
        "summary_chars_mean": average(summary_lengths),
        "summary_chars_max": largest(summary_lengths),
    }
|
|
|
|
def validate_dataset_for_training(split_name: str, diagnostics: Dict[str, Any], args: argparse.Namespace) -> None:
    """Fail fast when dataset diagnostics show the split is unfit for training.

    Args:
        split_name: Label used as the prefix of the error message.
        diagnostics: Counters in the shape produced by ``dataset_diagnostics``.
        args: Reads ``strict_data_checks`` (default True; False disables all
            checks), ``disable_vision``, ``model_variant``,
            ``direct_element_tokens``, ``max_elements`` and
            ``context_include_screen_text``.

    Raises:
        ValueError: Listing every failed check. Missing targets, short
            summaries and missing image paths always count; missing image
            files count only when vision is enabled; missing UI/OCR items
            count only when the configuration uses them and more than 5% of
            rows are affected.
    """
    if not bool(getattr(args, "strict_data_checks", True)):
        return

    problems: List[str] = []

    unconditional_checks = (
        ("missing_target", "{} rows are missing target"),
        ("short_summary", "{} rows have summary_zh shorter than 8 chars"),
        ("missing_image_path", "{} rows are missing image_path"),
    )
    for key, template in unconditional_checks:
        if diagnostics[key]:
            problems.append(template.format(diagnostics[key]))

    vision_disabled = (
        bool(getattr(args, "disable_vision", False))
        or str(getattr(args, "model_variant", "")) == "annotation_only"
    )
    if not vision_disabled and diagnostics["missing_image_file"]:
        problems.append(f"{diagnostics['missing_image_file']} image files do not exist")

    uses_ui_items = bool(getattr(args, "direct_element_tokens", False)) or int(getattr(args, "max_elements", 0) or 0) > 0
    uses_screen_text = bool(getattr(args, "context_include_screen_text", False))
    if uses_ui_items and diagnostics["missing_ui_rate"] > 0.05:
        problems.append(f"{diagnostics['missing_ui_items']} rows are missing ui_items")
    if uses_screen_text and diagnostics["missing_ocr_rate"] > 0.05:
        problems.append(f"{diagnostics['missing_ocr_items']} rows are missing ocr_items")

    if problems:
        raise ValueError(f"{split_name} dataset failed strict data checks: " + "; ".join(problems))
|
|
|
|
def token_length_summary(values: List[int], max_length: int) -> Dict[str, Any]:
    """Summarize a list of token lengths against a configured maximum.

    Args:
        values: Token counts, one per sample. May be empty.
        max_length: The configured truncation length the counts are compared
            against.

    Returns:
        A dict with ``count``, ``mean``, nearest-rank percentiles ``p50`` /
        ``p90`` / ``p95`` / ``p99``, ``max``, ``configured_max``, and the
        number and rate of values strictly above (``over_max*``) and at or
        above (``at_or_over_max*``) ``max_length``. The same keys are returned
        for empty input (all statistics zero), so consumers can rely on one
        schema — previously ``configured_max`` was missing in that case.
    """
    ordered = sorted(values)
    count = len(ordered)

    def pct(percent: float) -> int:
        """Return the value at the rounded nearest-rank position."""
        if not ordered:
            return 0
        index = int(round((count - 1) * percent / 100.0))
        return int(ordered[max(0, min(count - 1, index))])

    over_max = sum(value > max_length for value in values)
    at_or_over_max = sum(value >= max_length for value in values)
    denominator = max(1, count)  # avoids division by zero for empty input
    return {
        "count": count,
        "mean": float(np.mean(values)) if values else 0.0,
        "p50": pct(50),
        "p90": pct(90),
        "p95": pct(95),
        "p99": pct(99),
        "max": int(ordered[-1]) if ordered else 0,
        "configured_max": int(max_length),
        "over_max": int(over_max),
        "over_max_rate": over_max / denominator,
        "at_or_over_max": int(at_or_over_max),
        "at_or_over_max_rate": at_or_over_max / denominator,
    }
|
|
|
|
def tokenizer_diagnostics(rows: List[Dict[str, Any]], tokenizer, args: argparse.Namespace) -> Dict[str, Any]:
    """Measure tokenized target and context lengths for every row.

    Targets are optionally canonicalized (``args.canonicalize_targets``),
    rendered with ``target_to_text`` and tokenized with special tokens; the
    context text comes from ``build_context_text``.

    Args:
        rows: Dataset rows.
        tokenizer: Callable returning an object with ``input_ids`` when given
            a string and ``add_special_tokens=True``.
        args: Reads ``target_schema``, ``max_target_tokens``,
            ``max_context_tokens``, ``canonicalize_targets`` and
            ``drop_bare_search_functions``.

    Returns:
        ``{"target_tokens": ..., "context_tokens": ...}`` where each value is
        a ``token_length_summary`` against the respective configured maximum.
    """
    target_schema = getattr(args, "target_schema", "zh")
    max_target_tokens = int(getattr(args, "max_target_tokens", 384) or 384)
    max_context_tokens = int(getattr(args, "max_context_tokens", 64) or 64)
    canonicalize = bool(getattr(args, "canonicalize_targets", False))

    def token_count(text: str) -> int:
        return len(tokenizer(text, add_special_tokens=True).input_ids)

    target_lengths: List[int] = []
    context_lengths: List[int] = []
    for row in rows:
        target = row.get("target") or {}
        if canonicalize:
            target = canonicalize_target_with_context(
                row,
                target,
                drop_bare_search_functions=bool(getattr(args, "drop_bare_search_functions", False)),
            )
        target_text = target_to_text(target, target_schema)
        context_text = build_context_text(row, args)
        target_lengths.append(token_count(target_text))
        context_lengths.append(token_count(context_text))

    return {
        "target_tokens": token_length_summary(target_lengths, max_target_tokens),
        "context_tokens": token_length_summary(context_lengths, max_context_tokens),
    }
|
|
|
|
def validate_token_lengths(split_name: str, diagnostics: Dict[str, Any], args: argparse.Namespace) -> None:
    """Reject configurations that would truncate too many targets.

    Args:
        split_name: Label used as the prefix of the error message.
        diagnostics: Output shaped like ``tokenizer_diagnostics``; only
            ``target_tokens`` is read.
        args: Reads ``strict_data_checks`` (default True; False disables all
            checks), ``max_target_truncation_rate`` (default 0.01),
            ``eval_max_new_tokens`` and ``max_target_tokens``.

    Raises:
        ValueError: If the share of targets longer than the configured
            maximum exceeds the allowed rate, or if ``eval_max_new_tokens`` is
            smaller than ``max_target_tokens``.
    """
    if not bool(getattr(args, "strict_data_checks", True)):
        return

    allowed_rate = float(getattr(args, "max_target_truncation_rate", 0.01) or 0.0)
    stats = diagnostics.get("target_tokens") or {}
    observed_rate = float(stats.get("over_max_rate", 0.0) or 0.0)
    if observed_rate > allowed_rate:
        details = (
            f"max_target_tokens={stats.get('configured_max')}, "
            f"p95={stats.get('p95')}, p99={stats.get('p99')}, max={stats.get('max')}"
        )
        raise ValueError(
            f"{split_name} target token truncation rate {observed_rate:.2%} exceeds "
            f"max_target_truncation_rate={allowed_rate:.2%}; " + details
        )

    generation_budget = int(getattr(args, "eval_max_new_tokens", 0) or 0)
    target_budget = int(getattr(args, "max_target_tokens", 0) or 0)
    if generation_budget < target_budget:
        raise ValueError("eval_max_new_tokens must be >= max_target_tokens for natural_zh validation.")
|
|
|
|
class RichCollator:
    """Collate dataset rows into one padded batch of tensors.

    For every row the collator loads the screenshot (plus vertical crops),
    encodes up to ``args.max_elements`` UI elements (text, bbox, categorical
    ids, confidence, action flag), builds the context and target texts, and
    derives per-element supervision labels from the target.
    """

    # Categorical vocabularies for element attributes. Unknown element types
    # map to id 1; unknown sources and locations map to id 0.
    _TYPE_VOCAB: Dict[str, int] = {"visual": 0, "text": 1, "button": 2, "label": 3, "text_number": 4}
    _SOURCE_VOCAB: Dict[str, int] = {"unknown": 0, "cmgui": 1, "ocr": 2, "rule": 3}
    _LOC_VOCAB: Dict[str, int] = {
        "unknown": 0,
        "top-left": 1,
        "top-center": 2,
        "top-right": 3,
        "middle-left": 4,
        "middle-center": 5,
        "middle-right": 6,
        "bottom-left": 7,
        "bottom-center": 8,
        "bottom-right": 9,
    }

    def __init__(self, tokenizer, image_processor, args: argparse.Namespace, is_training: bool = False):
        """Store the processors and resolve whether images are needed.

        Args:
            tokenizer: Text tokenizer used for element, context and target text.
            image_processor: Callable invoked with ``images=...`` whose result
                is indexed with ``"pixel_values"``.
            args: Run configuration namespace.
            is_training: Enables training-only behavior (screen-text dropout).
        """
        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.args = args
        self.is_training = is_training
        self.use_vision = not bool(getattr(args, "disable_vision", False)) and str(getattr(args, "model_variant", "")) != "annotation_only"

    def _load_images(self, image_path: str) -> List[Image.Image]:
        """Return the full screenshot followed by its vertical crops.

        An image that cannot be opened or decoded is replaced by a light-gray
        ``image_size`` x ``image_size`` placeholder, so one bad file does not
        abort the batch.
        """
        path = Path(image_path)
        try:
            # The context manager closes the file handle deterministically;
            # convert() returns an independent in-memory image.
            with Image.open(path) as source:
                image = source.convert("RGB")
        except Exception:
            # Deliberately broad: this is a best-effort fallback.
            image = Image.new("RGB", (self.args.image_size, self.args.image_size), color=(245, 245, 245))

        images = [image]
        crops = self.args.num_vertical_crops
        if crops > 0:
            width, height = image.size
            for i in range(crops):
                top = int(height * i / crops)
                bottom = int(height * (i + 1) / crops)
                images.append(image.crop((0, top, width, bottom)))
        return images

    def _build_pixel_values(self, batch: List[Dict[str, Any]]) -> torch.Tensor:
        """Return pixel values reshaped to (batch, views, *per_image_shape).

        Without vision an empty tensor with zero views is returned instead.
        """
        bsz = len(batch)
        if not self.use_vision:
            return torch.empty(bsz, 0, 3, self.args.image_size, self.args.image_size, dtype=torch.float32)

        flat_images: List[Image.Image] = []
        crop_counts: List[int] = []
        for row in batch:
            images = self._load_images(row.get("image_path", ""))
            crop_counts.append(len(images))
            flat_images.extend(images)

        pixel_values = self.image_processor(
            images=flat_images,
            return_tensors="pt",
            do_resize=True,
            size={"height": self.args.image_size, "width": self.args.image_size},
        )["pixel_values"]
        # Every row yields the same number of views (1 + num_vertical_crops).
        return pixel_values.view(bsz, crop_counts[0], *pixel_values.shape[1:])

    def _tokenize(self, texts: List[str], max_length: int):
        """Tokenize ``texts`` with padding and truncation to ``max_length``."""
        return self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=max_length,
            return_tensors="pt",
        )

    def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Build the model input batch.

        Args:
            batch: Raw dataset rows.

        Returns:
            A dict holding the original rows, ``pixel_values``, the
            per-element tensors (token ids/mask, presence mask, bboxes, type /
            source / location ids, confidences, action flags), tokenized
            context, decoder ``labels`` (padding replaced by -100), and the
            auxiliary labels (evidence, UI function, search function, section,
            numeric).
        """
        args = self.args
        bsz = len(batch)
        max_elements = args.max_elements

        pixel_values = self._build_pixel_values(batch)

        element_mask = torch.zeros(bsz, max_elements, dtype=torch.bool)
        bboxes = torch.zeros(bsz, max_elements, 4, dtype=torch.float32)
        type_ids = torch.zeros(bsz, max_elements, dtype=torch.long)
        source_ids = torch.zeros(bsz, max_elements, dtype=torch.long)
        loc_ids = torch.zeros(bsz, max_elements, dtype=torch.long)
        confs = torch.zeros(bsz, max_elements, 1, dtype=torch.float32)
        action_flags = torch.zeros(bsz, max_elements, 1, dtype=torch.float32)
        evidence_labels = torch.zeros(bsz, max_elements, dtype=torch.float32)
        numeric_labels = torch.zeros(bsz, max_elements, dtype=torch.float32)
        ui_function_labels = torch.zeros(bsz, max_elements, dtype=torch.float32)
        search_function_labels = torch.zeros(bsz, max_elements, dtype=torch.float32)
        section_labels = torch.zeros(bsz, len(SECTION_NAMES), dtype=torch.float32)

        flat_element_texts: List[str] = []
        target_texts: List[str] = []
        context_texts: List[str] = []

        canonicalize = bool(getattr(args, "canonicalize_targets", False))
        drop_bare_search = bool(getattr(args, "drop_bare_search_functions", False))
        target_schema = getattr(args, "target_schema", "zh")
        add_task_intent = bool(getattr(args, "task_intent_context", False))
        # Screen-text dropout is a training-time augmentation only.
        screen_text_dropout_rate = (
            float(getattr(args, "context_screen_text_dropout_rate", 0.0) or 0.0) if self.is_training else 0.0
        )

        for row_idx, row in enumerate(batch):
            target = row.get("target") or {}
            if canonicalize:
                target = canonicalize_target_with_context(
                    row,
                    target,
                    drop_bare_search_functions=drop_bare_search,
                )
            target_texts.append(target_to_text(target, target_schema))

            context_text = build_context_text(row, args, screen_text_dropout_rate=screen_text_dropout_rate)
            if add_task_intent:
                task_type = "检索" if row_has_search_task(row) else "常规"
                context_text = f"{context_text} 任务类别:{task_type}"
            context_texts.append(context_text)

            # Evidence ids come from the target, else from the row's weak
            # evidence (a missing or None value counts as empty).
            evidence_ids = set(target.get("key_ui_clues", []) or row.get("weak_evidence_ids") or [])
            function_evidence = set(extract_function_evidence_ids(target))
            search_function_evidence = set(extract_named_function_evidence_ids(target, "搜索"))
            interaction_evidence = set()
            for interaction in target.get("interaction_data", []) or []:
                interaction_evidence.update(interaction.get("evidence_ids", []) or [])

            for sec_idx, name in enumerate(SECTION_NAMES):
                if target.get(name):
                    section_labels[row_idx, sec_idx] = 1.0

            elements = (row.get("ui_items") or [])[:max_elements]
            for elem_idx in range(max_elements):
                if elem_idx >= len(elements):
                    # Padding slot: keep text rows aligned with the tensors.
                    flat_element_texts.append("")
                    continue

                elem = elements[elem_idx]
                text = safe_text(elem.get("text"))
                flat_element_texts.append(f"{elem.get('type','text')} {elem.get('source','')} {text}")
                element_mask[row_idx, elem_idx] = True

                bbox = elem.get("bbox") or [0, 0, 1, 1]
                if len(bbox) == 4:  # malformed boxes keep the all-zero default
                    bboxes[row_idx, elem_idx] = torch.tensor(bbox, dtype=torch.float32)

                type_ids[row_idx, elem_idx] = self._TYPE_VOCAB.get(safe_text(elem.get("type")), 1)
                source_ids[row_idx, elem_idx] = self._SOURCE_VOCAB.get(safe_text(elem.get("source")), 0)
                loc_ids[row_idx, elem_idx] = self._LOC_VOCAB.get(safe_text(elem.get("location")), 0)
                confs[row_idx, elem_idx, 0] = float(elem.get("ocr_conf", elem.get("conf", 1.0)) or 1.0)
                action_flags[row_idx, elem_idx, 0] = 1.0 if elem.get("is_action_target") else 0.0

                elem_id = elem.get("id")
                ocr_id = elem.get("ocr_id")
                if elem_id in evidence_ids or ocr_id in evidence_ids:
                    evidence_labels[row_idx, elem_idx] = 1.0
                if elem_id in function_evidence or ocr_id in function_evidence:
                    ui_function_labels[row_idx, elem_idx] = 1.0
                if elem_id in search_function_evidence or ocr_id in search_function_evidence:
                    search_function_labels[row_idx, elem_idx] = 1.0
                if elem_id in interaction_evidence or re.search(r"\d", text):
                    numeric_labels[row_idx, elem_idx] = 1.0

        elem_tokens = self._tokenize(flat_element_texts, args.max_element_tokens)
        elem_input_ids = elem_tokens.input_ids.view(bsz, max_elements, -1)
        elem_attention_mask = elem_tokens.attention_mask.view(bsz, max_elements, -1)

        context_tokens = self._tokenize(context_texts, args.max_context_tokens)

        target_tokens = self._tokenize(target_texts, args.max_target_tokens)
        labels = target_tokens.input_ids
        # -100 marks padding positions so they can be excluded from the loss.
        labels[labels == self.tokenizer.pad_token_id] = -100

        return {
            "rows": batch,
            "pixel_values": pixel_values,
            "element_input_ids": elem_input_ids,
            "element_attention_mask": elem_attention_mask,
            "element_mask": element_mask,
            "bboxes": bboxes,
            "type_ids": type_ids,
            "source_ids": source_ids,
            "loc_ids": loc_ids,
            "confs": confs,
            "action_flags": action_flags,
            "context_input_ids": context_tokens.input_ids,
            "context_attention_mask": context_tokens.attention_mask,
            "labels": labels,
            "evidence_labels": evidence_labels,
            "ui_function_labels": ui_function_labels,
            "search_function_labels": search_function_labels,
            "section_labels": section_labels,
            "numeric_labels": numeric_labels,
        }
|
|
|
|
|
class BottleneckPooler(nn.Module):
    """Compress a variable-length memory into a fixed number of tokens.

    A set of learned query vectors cross-attends over the memory; the attended
    result is added back to the queries and layer-normalized.
    """

    def __init__(self, hidden_size: int, num_queries: int = 64, num_heads: int = 8):
        """Create the learned queries, the attention layer and the norm.

        Args:
            hidden_size: Feature size of memory and output tokens.
            num_queries: Number of output tokens per sample.
            num_heads: Attention heads of the cross-attention layer.
        """
        super().__init__()
        initial_queries = torch.randn(num_queries, hidden_size) * 0.02
        self.queries = nn.Parameter(initial_queries)
        self.attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, memory: torch.Tensor, key_padding_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Pool ``memory`` into ``num_queries`` tokens per sample.

        Args:
            memory: Batch-first tensor attended over as keys and values.
            key_padding_mask: Optional mask forwarded unchanged to
                ``nn.MultiheadAttention``.

        Returns:
            Tensor of shape (batch, num_queries, hidden_size).
        """
        batch_queries = self.queries.unsqueeze(0).expand(memory.size(0), -1, -1)
        attended = self.attn(batch_queries, memory, memory, key_padding_mask=key_padding_mask)[0]
        return self.norm(attended + batch_queries)
|
|
|
|
|
| class RichGroundedModel(nn.Module):
|
def __init__(self, args: argparse.Namespace):
    """Build the vision tower, the seq2seq decoder and all fusion layers.

    Args:
        args: Run configuration. Reads ``model_variant``, ``vision_model``,
            ``decoder_model``, ``bottleneck_queries``, ``freeze_vision``,
            ``unfreeze_vision_last_ratio`` and the optional flags
            ``disable_vision``, ``decoder_gradient_checkpointing``,
            ``freeze_decoder`` and ``vision_gradient_checkpointing``.
    """
    super().__init__()

    def flag(name: str) -> bool:
        return bool(getattr(args, name, False))

    self.args = args
    self.model_variant = args.model_variant
    self.use_vision = not flag("disable_vision") and self.model_variant != "annotation_only"

    # Pretrained backbones.
    self.vision = AutoModel.from_pretrained(args.vision_model) if self.use_vision else None
    self.decoder = AutoModelForSeq2SeqLM.from_pretrained(args.decoder_model)
    if flag("decoder_gradient_checkpointing") and hasattr(self.decoder, "gradient_checkpointing_enable"):
        self.decoder.gradient_checkpointing_enable()
        # NOTE(review): indentation was lost in the source; disabling
        # use_cache is assumed to happen only together with gradient
        # checkpointing — confirm.
        if hasattr(self.decoder.config, "use_cache"):
            self.decoder.config.use_cache = False
    if flag("freeze_decoder"):
        for param in self.decoder.parameters():
            param.requires_grad = False
    if (
        self.vision is not None
        and flag("vision_gradient_checkpointing")
        and hasattr(self.vision, "gradient_checkpointing_enable")
    ):
        self.vision.gradient_checkpointing_enable()

    self.hidden_size = self.decoder.config.d_model
    hidden = self.hidden_size
    vision_hidden = self._vision_hidden_size() if self.vision is not None else hidden

    # Projections and embeddings that lift every input into decoder width.
    self.visual_proj = nn.Linear(vision_hidden, hidden)
    self.elem_text_proj = nn.Linear(hidden, hidden)
    self.type_emb = nn.Embedding(16, hidden)
    self.source_emb = nn.Embedding(8, hidden)
    self.loc_emb = nn.Embedding(16, hidden)
    self.bbox_proj = nn.Sequential(nn.Linear(6, hidden), nn.GELU(), nn.Linear(hidden, hidden))
    self.roi_proj = nn.Linear(vision_hidden, hidden)
    self.context_proj = nn.Linear(hidden, hidden)

    def encoder_stack() -> nn.TransformerEncoder:
        # Two post-norm, batch-first layers; used for layout and for fusion.
        layer = nn.TransformerEncoderLayer(
            d_model=hidden,
            nhead=8,
            dim_feedforward=hidden * 4,
            dropout=0.1,
            batch_first=True,
            norm_first=False,
        )
        return nn.TransformerEncoder(layer, num_layers=2)

    self.layout_encoder = encoder_stack()
    self.fusion = encoder_stack()
    self.pooler = BottleneckPooler(hidden, num_queries=args.bottleneck_queries, num_heads=8)

    # Per-element heads and the projections of their scores.
    self.evidence_head = nn.Linear(hidden, 1)
    self.ui_function_head = nn.Linear(hidden, 1)
    self.search_function_head = nn.Linear(hidden, 1)
    self.function_signal_proj = nn.Linear(1, hidden)
    self.search_signal_proj = nn.Linear(1, hidden)
    # Zero-initialized, so these projections output zeros until trained.
    for signal_proj in (self.function_signal_proj, self.search_signal_proj):
        nn.init.zeros_(signal_proj.weight)
        nn.init.zeros_(signal_proj.bias)
    self.numeric_head = nn.Linear(hidden, 1)
    self.section_head = nn.Linear(hidden, len(SECTION_NAMES))

    if self.vision is not None and args.freeze_vision:
        self.freeze_vision(args.unfreeze_vision_last_ratio)
|
|
|
| def _vision_hidden_size(self) -> int:
|
| if self.vision is None:
|
| return self.hidden_size
|
| config = self.vision.config
|
| if hasattr(config, "vision_config"):
|
| return int(config.vision_config.hidden_size)
|
| return int(getattr(config, "hidden_size"))
|
|
|
def freeze_vision(self, unfreeze_last_ratio: float) -> None:
    """Freeze the vision tower, optionally leaving its last layers trainable.

    Args:
        unfreeze_last_ratio: Fraction of encoder layers, counted from the end,
            to keep trainable. Values <= 0 freeze everything; any positive
            value unfreezes at least one layer, provided the encoder exposes
            its layers as ``encoder.layers`` or ``encoder.layer``.
    """
    if self.vision is None:
        return

    self.vision.requires_grad_(False)
    if unfreeze_last_ratio <= 0:
        return

    vision_model = getattr(self.vision, "vision_model", self.vision)
    encoder = getattr(vision_model, "encoder", None)
    if encoder is None:
        return
    layers = getattr(encoder, "layers", None) or getattr(encoder, "layer", None)
    if not layers:
        return

    trainable_count = max(1, int(len(layers) * unfreeze_last_ratio))
    for layer in layers[-trainable_count:]:
        layer.requires_grad_(True)
|
|
|
def use_native_context_forward(self) -> bool:
    """Return whether ``args.native_context_forward`` is set (default False)."""
    enabled = getattr(self.args, "native_context_forward", False)
    return bool(enabled)
|
|
|
def mean_embed(self, input_ids: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Return the mask-weighted mean of the decoder's input embeddings.

    Args:
        input_ids: Token ids; the token axis is the last one.
        attention_mask: Same shape as ``input_ids``; non-zero marks real
            tokens.

    Returns:
        One vector per sequence. A fully masked sequence yields zeros because
        the divisor is clamped to at least 1.
    """
    embedding_layer = self.decoder.get_input_embeddings()
    weights = attention_mask.unsqueeze(-1).float()
    weighted_sum = (embedding_layer(input_ids) * weights).sum(dim=-2)
    token_counts = weights.sum(dim=-2).clamp_min(1.0)
    return weighted_sum / token_counts
|
|
|
def encode_vision(self, pixel_values: torch.Tensor) -> torch.Tensor:
    """Run the vision tower over every view of every sample.

    Args:
        pixel_values: Tensor of shape (batch, views, channels, height, width).

    Returns:
        Token features of shape (batch, views, tokens, hidden).
    """
    bsz, crops, channels, height, width = pixel_values.shape
    flat = pixel_values.view(bsz * crops, channels, height, width)
    vision_model = getattr(self.vision, "vision_model", self.vision)
    vision_trainable = any(param.requires_grad for param in self.vision.parameters())
    # Track gradients only when the tower has trainable parameters AND the
    # caller has not disabled autograd. Passing `vision_trainable` alone would
    # switch autograd back on inside an enclosing `torch.no_grad()` (e.g.
    # during evaluation), building a graph nobody uses.
    track_grad = vision_trainable and torch.is_grad_enabled()
    with torch.set_grad_enabled(track_grad):
        try:
            out = vision_model(pixel_values=flat, interpolate_pos_encoding=True)
        except TypeError:
            # Backbones whose forward() lacks `interpolate_pos_encoding`.
            out = vision_model(pixel_values=flat)
    tokens = out.last_hidden_state
    return tokens.view(bsz, crops, tokens.size(1), tokens.size(2))
|
|
|
def reduce_visual_tokens(self, visual: torch.Tensor) -> torch.Tensor:
    """Project visual tokens, first shrinking them to the configured budget.

    Args:
        visual: Tensor of shape (batch, views, tokens, hidden).

    Returns:
        ``visual_proj`` applied to the tokens of all views concatenated along
        the token axis. When ``args.max_visual_tokens`` is unset/non-positive
        or already satisfied, all tokens are kept. Otherwise each view gets an
        equal share of the budget: square patch grids (optionally preceded by
        one extra leading token) are average-pooled to a smaller square grid,
        and other layouts are subsampled at evenly spaced positions.
    """
    bsz, crops, num_tokens, hidden = visual.shape
    budget = int(getattr(self.args, "max_visual_tokens", 0) or 0)
    if budget <= 0 or crops * num_tokens <= budget:
        return self.visual_proj(visual.flatten(1, 2))

    per_crop = max(1, budget // max(1, crops))

    # The token layout is identical for every view, so resolve it once.
    grid = int(math.sqrt(num_tokens))
    has_leading_token = False
    is_grid = grid * grid == num_tokens
    if not is_grid:
        inner_grid = int(math.sqrt(num_tokens - 1))
        if inner_grid * inner_grid == num_tokens - 1:
            grid = inner_grid
            has_leading_token = True
            is_grid = True

    reduced: List[torch.Tensor] = []
    for crop_idx in range(crops):
        crop_tokens = visual[:, crop_idx, :, :]

        if not is_grid:
            keep = min(per_crop, num_tokens)
            positions = torch.linspace(0, num_tokens - 1, steps=keep, device=visual.device).round().long()
            reduced.append(crop_tokens.index_select(1, positions))
            continue

        if has_leading_token:
            leading = crop_tokens[:, :1, :]
            patches = crop_tokens[:, 1:, :]
        else:
            leading = None
            patches = crop_tokens

        # The leading token is kept only if the budget leaves room for patches.
        keep_leading = leading is not None and per_crop > 1
        patch_budget = max(1, per_crop - (1 if keep_leading else 0))
        out_grid = max(1, int(math.sqrt(patch_budget)))

        as_image = patches.view(bsz, grid, grid, hidden).permute(0, 3, 1, 2)
        pooled = F.adaptive_avg_pool2d(as_image, (out_grid, out_grid))
        pooled = pooled.permute(0, 2, 3, 1).reshape(bsz, out_grid * out_grid, hidden)
        if keep_leading:
            pooled = torch.cat([leading, pooled], dim=1)
        reduced.append(pooled)

    return self.visual_proj(torch.cat(reduced, dim=1))
|
|
|
def roi_pool(self, full_tokens: torch.Tensor, bboxes: torch.Tensor) -> torch.Tensor:
    """Average the visual tokens that fall inside each element's bbox.

    Args:
        full_tokens: Tensor of shape (batch, tokens, hidden) for one view.
        bboxes: Tensor of shape (batch, elements, 4) holding
            ``(x1, y1, x2, y2)``; the coordinates are scaled by the grid size,
            so they are presumably normalized to [0, 1] — confirm with the
            data pipeline.

    Returns:
        Tensor of shape (batch, elements, hidden). Tokens are treated as a
        square grid, ignoring one leading token if that makes the count a
        perfect square; if no square layout fits, every element receives the
        mean over all tokens. Each box covers at least one grid cell.
    """
    bsz, num_tokens, hidden = full_tokens.shape
    grid = int(math.sqrt(num_tokens))
    patch_tokens = full_tokens
    if grid * grid != num_tokens:
        grid = int(math.sqrt(num_tokens - 1))
        if grid * grid != num_tokens - 1:
            return full_tokens.mean(dim=1, keepdim=True).expand(-1, bboxes.size(1), -1)
        patch_tokens = full_tokens[:, 1:, :]
    patch_grid = patch_tokens.view(bsz, grid, grid, hidden)

    def cell_span(low: float, high: float) -> Tuple[int, int]:
        # Clamp to the grid and guarantee a non-empty [start, stop) range.
        start = max(0, min(grid - 1, int(low * grid)))
        stop = max(start + 1, min(grid, int(math.ceil(high * grid))))
        return start, stop

    per_sample = []
    for b in range(bsz):
        element_vectors = []
        for box in bboxes[b]:
            x1, y1, x2, y2 = box.tolist()
            col_start, col_stop = cell_span(x1, x2)
            row_start, row_stop = cell_span(y1, y2)
            region = patch_grid[b, row_start:row_stop, col_start:col_stop, :]
            element_vectors.append(region.reshape(-1, hidden).mean(dim=0))
        per_sample.append(torch.stack(element_vectors, dim=0))
    return torch.stack(per_sample, dim=0)
|
|
|
def build_memory(self, batch: Dict[str, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build the encoder memory that the seq2seq decoder attends to.

    Visible steps:
      1. Encode the image when ``self.vision`` is set, otherwise use a
         zero-length visual sequence.
      2. Build one token per UI element from its mean text embedding plus
         type/source/location embeddings, a bbox projection, and scaled
         confidence/action flags (plus ROI-pooled visual features for the
         ``full`` and ``late_fusion`` variants), then run ``layout_encoder``.
      3. Optionally add detached function/search head signals to the copy of
         the element tokens that goes to the decoder.
      4. Build context tokens according to ``args.context_mode``.
      5. Concatenate context/visual/element tokens according to
         ``self.model_variant`` and run ``self.fusion`` and ``self.pooler``.
      6. Assemble the final memory from optional direct slices of the fused
         sequence and the pooled tokens.

    Returns:
        ``(memory, memory_attention_mask, head_elem_tokens, elem_key_padding)``
        where ``memory_attention_mask`` is 1 for attendable positions,
        ``head_elem_tokens`` feeds the auxiliary element heads, and
        ``elem_key_padding`` is True for padded element slots.

    Raises:
        ValueError: for the ``image_only`` variant when ``self.vision`` is None.
    """
    # Checkpointing is only active in training mode and when enabled in args.
    use_activation_checkpointing = self.training and bool(getattr(self.args, "activation_checkpointing", False))
    bsz = batch["element_mask"].size(0)
    if self.vision is not None:
        pixel_values = batch["pixel_values"]
        visual = self.encode_vision(pixel_values)
        # Index 0 along dim 1 is what ROI pooling reads; presumably the
        # full-image crop -- confirm against encode_vision.
        full_visual = visual[:, 0, :, :]
        visual_tokens = self.reduce_visual_tokens(visual)
    else:
        full_visual = None
        # Zero-length visual sequence so the concatenations below still work.
        visual_tokens = batch["element_mask"].new_zeros((bsz, 0, self.hidden_size), dtype=self.decoder.get_input_embeddings().weight.dtype)

    # ---- Element tokens ----
    elem_ids = batch["element_input_ids"]
    elem_mask_tok = batch["element_attention_mask"]
    bsz, max_elements, elem_len = elem_ids.shape
    # Flatten (batch, elements) so every element text is embedded independently.
    elem_text = self.mean_embed(elem_ids.view(bsz * max_elements, elem_len), elem_mask_tok.view(bsz * max_elements, elem_len))
    elem_text = self.elem_text_proj(elem_text.view(bsz, max_elements, -1))
    bbox = batch["bboxes"]
    # Width/height from the corner coordinates, clamped so they are never negative.
    wh = (bbox[..., 2:] - bbox[..., :2]).clamp_min(0)
    bbox_feat = torch.cat([bbox, wh], dim=-1)
    # NOTE(review): confs/action_flags are added directly, so they must
    # broadcast against the element token shape -- confirm against the collator.
    elem_tokens = (
        elem_text
        + self.type_emb(batch["type_ids"])
        + self.source_emb(batch["source_ids"])
        + self.loc_emb(batch["loc_ids"])
        + self.bbox_proj(bbox_feat)
        + batch["confs"] * 0.1
        + batch["action_flags"] * 0.1
    )
    if self.vision is not None and self.model_variant in {"full", "late_fusion"}:
        roi = self.roi_proj(self.roi_pool(full_visual, bbox))
        elem_tokens = elem_tokens + roi
    # True marks padded element slots (element_mask is True for real elements).
    elem_key_padding = ~batch["element_mask"]
    if use_activation_checkpointing:
        elem_tokens = checkpoint(
            lambda tokens, padding: self.layout_encoder(tokens, src_key_padding_mask=padding),
            elem_tokens,
            elem_key_padding,
            use_reentrant=False,
        )
    else:
        elem_tokens = self.layout_encoder(elem_tokens, src_key_padding_mask=elem_key_padding)
    # head_elem_tokens feeds the auxiliary heads; decoder_elem_tokens is the
    # copy that may receive extra signals before fusion.
    head_elem_tokens = elem_tokens
    decoder_elem_tokens = elem_tokens
    if bool(getattr(self.args, "function_signal_to_decoder", False)):
        # Detached, so the signal path does not backpropagate into the head.
        function_signal = torch.sigmoid(self.ui_function_head(head_elem_tokens)).detach()
        function_signal = function_signal.masked_fill(elem_key_padding.unsqueeze(-1), 0.0)
        # `or 1.0` means a configured scale of 0 or None is treated as 1.0.
        signal_scale = float(getattr(self.args, "function_signal_scale", 1.0) or 1.0)
        decoder_elem_tokens = decoder_elem_tokens + signal_scale * self.function_signal_proj(function_signal)
    if bool(getattr(self.args, "search_signal_to_decoder", False)):
        search_signal = torch.sigmoid(self.search_function_head(head_elem_tokens)).detach()
        search_signal = search_signal.masked_fill(elem_key_padding.unsqueeze(-1), 0.0)
        search_signal_scale = float(getattr(self.args, "search_signal_scale", 1.0) or 1.0)
        decoder_elem_tokens = decoder_elem_tokens + search_signal_scale * self.search_signal_proj(search_signal)

    # ---- Context tokens ----
    context_mode = str(getattr(self.args, "context_mode", "mean") or "mean").lower()
    direct_context_tokens = None
    if context_mode in {"tokens_encoder", "tokens_direct_encoder"}:
        # Per-token context produced by the decoder model's own encoder stack.
        context_encoded = self.decoder.get_encoder()(input_ids=batch["context_input_ids"], attention_mask=batch["context_attention_mask"], return_dict=True)
        direct_context_tokens = context_encoded.last_hidden_state
        context_tokens = self.context_proj(context_encoded.last_hidden_state)
        context_padding = ~batch["context_attention_mask"].bool()
    elif context_mode in {"tokens", "tokens_direct"}:
        # Per-token context taken straight from the input embedding table.
        context_emb = self.decoder.get_input_embeddings()(batch["context_input_ids"])
        direct_context_tokens = context_emb
        context_tokens = self.context_proj(context_emb)
        context_padding = ~batch["context_attention_mask"].bool()
    else:
        # Any other mode: one mean-pooled context token per sample, never padded.
        context = self.mean_embed(batch["context_input_ids"], batch["context_attention_mask"])
        context_tokens = self.context_proj(context).unsqueeze(1)
        context_padding = torch.zeros(bsz, 1, dtype=torch.bool, device=context_tokens.device)

    # ---- Fusion input layout ----
    # The start/len pairs record where each segment sits inside the fused
    # sequence; -1 / 0 mean the segment is absent for this variant.
    context_len = context_tokens.size(1)
    visual_start = -1
    visual_len = 0
    elem_start = -1
    elem_len_for_fusion = 0
    if self.model_variant == "annotation_only":
        elem_start = context_len
        elem_len_for_fusion = decoder_elem_tokens.size(1)
        fusion_input = torch.cat([context_tokens, decoder_elem_tokens], dim=1)
        fusion_padding = torch.cat([context_padding, elem_key_padding], dim=1)
    elif self.model_variant == "image_only":
        if self.vision is None:
            raise ValueError("image_only requires vision; do not set disable_vision=true")
        visual_start = 0
        visual_len = visual_tokens.size(1)
        fusion_input = visual_tokens
        fusion_padding = torch.zeros(bsz, fusion_input.size(1), dtype=torch.bool, device=fusion_input.device)
    elif self.model_variant == "late_fusion":
        # The whole visual sequence is collapsed into one summary token.
        visual_summary = visual_tokens.mean(dim=1, keepdim=True)
        visual_start = context_len
        visual_len = 1
        elem_start = context_len + 1
        elem_len_for_fusion = decoder_elem_tokens.size(1)
        fusion_input = torch.cat([context_tokens, visual_summary, decoder_elem_tokens], dim=1)
        fusion_padding = torch.cat(
            [context_padding, torch.zeros(bsz, 1, dtype=torch.bool, device=elem_key_padding.device), elem_key_padding],
            dim=1,
        )
    else:
        # Every other variant (including "full"): context + all visual tokens + elements.
        visual_start = context_len
        visual_len = visual_tokens.size(1)
        elem_start = context_len + visual_len
        elem_len_for_fusion = decoder_elem_tokens.size(1)
        fusion_input = torch.cat([context_tokens, visual_tokens, decoder_elem_tokens], dim=1)
        fusion_padding = torch.cat(
            [
                context_padding,
                torch.zeros(bsz, visual_tokens.size(1), dtype=torch.bool, device=elem_key_padding.device),
                elem_key_padding,
            ],
            dim=1,
        )
    if use_activation_checkpointing:
        fused = checkpoint(
            lambda tokens, padding: self.fusion(tokens, src_key_padding_mask=padding),
            fusion_input,
            fusion_padding,
            use_reentrant=False,
        )
    else:
        fused = self.fusion(fusion_input, src_key_padding_mask=fusion_padding)
    # When elements took part in fusion, the heads read the fused element slice
    # rather than the pre-fusion layout-encoder output.
    if elem_start >= 0 and elem_len_for_fusion > 0:
        head_elem_tokens = fused[:, elem_start : elem_start + elem_len_for_fusion, :]
    if use_activation_checkpointing:
        pooled = checkpoint(
            lambda tokens, padding: self.pooler(tokens, key_padding_mask=padding),
            fused,
            fusion_padding,
            use_reentrant=False,
        )
    else:
        pooled = self.pooler(fused, key_padding_mask=fusion_padding)
    # Pooled tokens are always marked attendable.
    pooled_padding = torch.zeros(bsz, pooled.size(1), dtype=torch.bool, device=pooled.device)

    # ---- Memory assembly ----
    # memory_parts and memory_padding_parts are appended in lockstep.
    memory_parts = []
    memory_padding_parts = []
    if context_mode in {"tokens_direct", "tokens_direct_encoder"} and self.model_variant != "image_only":
        if bool(getattr(self.args, "direct_context_passthrough", False)) and direct_context_tokens is not None:
            # Pass-through uses the unprojected, unfused context states.
            memory_parts.append(direct_context_tokens)
        else:
            memory_parts.append(fused[:, :context_len, :])
        memory_padding_parts.append(context_padding)
    if bool(getattr(self.args, "direct_visual_tokens", False)) and visual_start >= 0 and visual_len > 0:
        visual_scale = float(getattr(self.args, "visual_memory_scale", 1.0) or 1.0)
        memory_parts.append(fused[:, visual_start : visual_start + visual_len, :] * visual_scale)
        memory_padding_parts.append(torch.zeros(bsz, visual_len, dtype=torch.bool, device=pooled.device))
    if bool(getattr(self.args, "direct_element_tokens", False)) and elem_start >= 0 and elem_len_for_fusion > 0:
        element_scale = float(getattr(self.args, "element_memory_scale", 1.0) or 1.0)
        memory_parts.append(fused[:, elem_start : elem_start + elem_len_for_fusion, :] * element_scale)
        memory_padding_parts.append(elem_key_padding)
    if bool(getattr(self.args, "include_pooled_memory", True)):
        pooled_scale = float(getattr(self.args, "pooled_memory_scale", 1.0) or 1.0)
        memory_parts.append(pooled * pooled_scale)
        memory_padding_parts.append(pooled_padding)
    if not memory_parts:
        # Nothing was selected above: fall back to pooled memory so the
        # decoder always has something to attend to.
        pooled_scale = float(getattr(self.args, "pooled_memory_scale", 1.0) or 1.0)
        memory_parts.append(pooled * pooled_scale)
        memory_padding_parts.append(pooled_padding)
    memory = torch.cat(memory_parts, dim=1)
    memory_scale = float(getattr(self.args, "decoder_memory_scale", 1.0) or 1.0)
    if memory_scale != 1.0:
        memory = memory * memory_scale
    memory_padding = torch.cat(memory_padding_parts, dim=1)
    # Convert the padding mask (True = pad) into an attention mask (1 = attend).
    memory_attention_mask = (~memory_padding).long()
    return memory, memory_attention_mask, head_elem_tokens, elem_key_padding
|
|
|
def forward(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
    """Compute the training losses for one batch.

    Two paths exist:

    * When ``self.use_native_context_forward()`` is true, only the wrapped
      seq2seq model is run on the context tokens; the auxiliary losses are
      zero and the auxiliary logits are constant placeholders.
    * Otherwise the memory from ``build_memory`` is used as encoder output,
      the generation loss is computed as a chunked token-level cross entropy,
      and element-level (evidence, UI function, search function, numeric) and
      section-level BCE losses are added with weights read from ``self.args``.

    Returns:
        Dict with ``loss`` (carries gradients) plus detached per-task losses
        and detached logits for evidence, UI function, search function and
        section heads.
    """
    if self.use_native_context_forward():
        labels = batch["labels"]
        decoder_out = self.decoder(
            input_ids=batch["context_input_ids"],
            attention_mask=batch["context_attention_mask"],
            labels=labels,
            use_cache=False,
        )
        gen_loss = decoder_out.loss
        element_shape = batch["element_mask"].shape
        # Placeholder logits: sigmoid(-20) is ~0, so no element is predicted positive.
        empty_logits = labels.new_full(element_shape, -20.0, dtype=torch.float32)
        section_logits = labels.new_zeros((labels.size(0), int(batch["section_labels"].shape[-1])), dtype=torch.float32)
        zero_loss = gen_loss.detach().new_zeros(())
        return {
            "loss": gen_loss,
            "generation_loss": gen_loss.detach(),
            "evidence_loss": zero_loss,
            "ui_function_loss": zero_loss,
            "search_function_loss": zero_loss,
            "section_loss": zero_loss,
            "numeric_loss": zero_loss,
            "evidence_logits": empty_logits,
            "ui_function_logits": empty_logits,
            "search_function_logits": empty_logits,
            "section_logits": section_logits,
        }
    memory, memory_attention_mask, elem_tokens, elem_key_padding = self.build_memory(batch)
    # Wrap the custom memory so the decoder skips its own encoder.
    encoder_outputs = BaseModelOutput(last_hidden_state=memory)
    labels = batch["labels"]
    # Build teacher-forcing inputs by shifting labels right, preferring the
    # decoder's own helpers when it exposes them.
    if hasattr(self.decoder, "prepare_decoder_input_ids_from_labels"):
        decoder_input_ids = self.decoder.prepare_decoder_input_ids_from_labels(labels=labels)
    elif hasattr(self.decoder, "_shift_right"):
        decoder_input_ids = self.decoder._shift_right(labels)
    else:
        # Manual shift: start token first, and -100 (ignored) labels become pad ids.
        pad_token_id = int(getattr(self.decoder.config, "pad_token_id", 0) or 0)
        start_token_id = int(getattr(self.decoder.config, "decoder_start_token_id", pad_token_id) or pad_token_id)
        decoder_input_ids = labels.new_full(labels.shape, pad_token_id)
        decoder_input_ids[:, 0] = start_token_id
        decoder_input_ids[:, 1:] = labels[:, :-1].masked_fill(labels[:, :-1] == -100, pad_token_id)
    decoder_out = self.decoder(
        encoder_outputs=encoder_outputs,
        attention_mask=memory_attention_mask,
        decoder_input_ids=decoder_input_ids,
        use_cache=False,
    )
    logits = decoder_out.logits
    flat_logits = logits.reshape(-1, logits.size(-1))
    flat_labels = labels.reshape(-1)
    # clamp_min(1) avoids division by zero when every label is ignored.
    valid_token_count = flat_labels.ne(-100).sum().clamp_min(1)
    chunk_size = max(1, int(getattr(self.args, "generation_loss_chunk_size", 32) or 32))
    gen_loss_sum = logits.new_zeros((), dtype=torch.float32)
    # Cross entropy is summed chunk by chunk in float32 (presumably to bound the
    # size of the float32 upcast of the logits), then normalized once.
    for start in range(0, flat_logits.size(0), chunk_size):
        end = min(start + chunk_size, flat_logits.size(0))
        gen_loss_sum = gen_loss_sum + F.cross_entropy(
            flat_logits[start:end].float(),
            flat_labels[start:end],
            ignore_index=-100,
            reduction="sum",
        )
    gen_loss = gen_loss_sum / valid_token_count
    # Element-level auxiliary heads; squeeze(-1) drops each head's last dim.
    evidence_logits = self.evidence_head(elem_tokens).squeeze(-1)
    ui_function_logits = self.ui_function_head(elem_tokens).squeeze(-1)
    search_function_logits = self.search_function_head(elem_tokens).squeeze(-1)
    numeric_logits = self.numeric_head(elem_tokens).squeeze(-1)
    # 1.0 for real elements, 0.0 for padding; used to mask per-element losses.
    valid = (~elem_key_padding).float()
    evidence_loss = F.binary_cross_entropy_with_logits(
        evidence_logits, batch["evidence_labels"], reduction="none"
    )
    evidence_loss = (evidence_loss * valid).sum() / valid.sum().clamp_min(1.0)
    ui_function_loss = F.binary_cross_entropy_with_logits(
        ui_function_logits, batch["ui_function_labels"], reduction="none"
    )
    ui_function_loss = (ui_function_loss * valid).sum() / valid.sum().clamp_min(1.0)
    # pos_weight is only passed when it differs from the neutral value 1.0.
    search_pos_weight = float(getattr(self.args, "search_function_pos_weight", 1.0) or 1.0)
    search_loss_kwargs: Dict[str, Any] = {"reduction": "none"}
    if search_pos_weight != 1.0:
        search_loss_kwargs["pos_weight"] = torch.tensor(search_pos_weight, device=search_function_logits.device)
    search_function_loss = F.binary_cross_entropy_with_logits(
        search_function_logits, batch["search_function_labels"], **search_loss_kwargs
    )
    search_function_loss = (search_function_loss * valid).sum() / valid.sum().clamp_min(1.0)
    numeric_loss = F.binary_cross_entropy_with_logits(
        numeric_logits, batch["numeric_labels"], reduction="none"
    )
    numeric_loss = (numeric_loss * valid).sum() / valid.sum().clamp_min(1.0)
    # Section head reads the masked mean of the memory sequence.
    memory_mask = memory_attention_mask.unsqueeze(-1).float()
    memory_summary = (memory * memory_mask).sum(dim=1) / memory_mask.sum(dim=1).clamp_min(1.0)
    section_logits = self.section_head(memory_summary)
    section_loss = F.binary_cross_entropy_with_logits(section_logits, batch["section_labels"])
    # UI/search function weights default to 0.0 when absent from args; the
    # evidence, section and numeric weights are required attributes.
    total = (
        gen_loss
        + self.args.evidence_loss_weight * evidence_loss
        + float(getattr(self.args, "ui_function_loss_weight", 0.0) or 0.0) * ui_function_loss
        + float(getattr(self.args, "search_function_loss_weight", 0.0) or 0.0) * search_function_loss
        + self.args.section_loss_weight * section_loss
        + self.args.numeric_loss_weight * numeric_loss
    )
    return {
        "loss": total,
        "generation_loss": gen_loss.detach(),
        "evidence_loss": evidence_loss.detach(),
        "ui_function_loss": ui_function_loss.detach(),
        "search_function_loss": search_function_loss.detach(),
        "section_loss": section_loss.detach(),
        "numeric_loss": numeric_loss.detach(),
        "evidence_logits": evidence_logits.detach(),
        "ui_function_logits": ui_function_logits.detach(),
        "search_function_logits": search_function_logits.detach(),
        "section_logits": section_logits.detach(),
    }
|
|
|
@torch.no_grad()
def generate_text(self, batch: Dict[str, torch.Tensor], tokenizer, num_beams: int = 4, max_new_tokens: int = 384) -> List[str]:
    """Generate and decode output text for a batch.

    Autocast is enabled only when ``batch["pixel_values"]`` is a CUDA tensor
    and the resolved precision is ``fp16`` or ``bf16``. ``amp_dtype="auto"``
    resolves to ``fp16`` when ``args.fp16`` is truthy, otherwise ``fp32``.

    Generation runs on the context ids directly when
    ``use_native_context_forward()`` is true, otherwise on the memory from
    ``build_memory``. Extra generation options come from
    ``build_generation_kwargs``.

    Returns:
        Decoded strings with special tokens skipped.
    """
    precision = str(getattr(self.args, "amp_dtype", "auto") or "auto").lower()
    if precision == "auto":
        precision = "fp16" if bool(getattr(self.args, "fp16", False)) else "fp32"

    on_cuda = False
    if torch.is_tensor(batch.get("pixel_values")):
        on_cuda = batch["pixel_values"].is_cuda
    use_amp = on_cuda and precision in {"fp16", "bf16"}
    # Anything other than fp16 maps to bfloat16; it only matters when use_amp is true.
    autocast_dtype = torch.float16 if precision == "fp16" else torch.bfloat16

    with torch.amp.autocast("cuda", enabled=use_amp, dtype=autocast_dtype):
        extra_kwargs = build_generation_kwargs(self.args, tokenizer)
        if self.use_native_context_forward():
            source_kwargs = {
                "input_ids": batch["context_input_ids"],
                "attention_mask": batch["context_attention_mask"],
            }
        else:
            memory, memory_attention_mask, _, _ = self.build_memory(batch)
            source_kwargs = {
                "encoder_outputs": BaseModelOutput(last_hidden_state=memory),
                "attention_mask": memory_attention_mask,
            }
        generated = self.decoder.generate(
            **source_kwargs,
            num_beams=num_beams,
            max_new_tokens=max_new_tokens,
            **extra_kwargs,
        )
    return tokenizer.batch_decode(generated, skip_special_tokens=True)
|
|
|
|
|
def move_batch(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
    """Return a shallow copy of ``batch`` with its tensors moved to ``device``.

    The ``"rows"`` entry is always passed through untouched, as is every
    non-tensor value. Tensors are moved with ``non_blocking=True``.
    """
    return {
        key: value.to(device, non_blocking=True) if key != "rows" and torch.is_tensor(value) else value
        for key, value in batch.items()
    }
|
|
|
|
|
def trainable_parameter_stats(model: nn.Module) -> Dict[str, Any]:
    """Count trainable parameters, split by top-level module prefix.

    Parameters whose name starts with ``"vision."`` or ``"decoder."`` are
    counted in their own bucket; everything else counts as "other". Only
    parameters with ``requires_grad`` set are counted. ``vision_trainable``
    is True when at least one vision parameter is trainable.
    """
    counts = {"vision": 0, "decoder": 0, "other": 0}
    for name, param in model.named_parameters():
        if param.requires_grad:
            if name.startswith("vision."):
                bucket = "vision"
            elif name.startswith("decoder."):
                bucket = "decoder"
            else:
                bucket = "other"
            counts[bucket] += int(param.numel())

    stats: Dict[str, Any] = {
        "trainable_params_total": counts["vision"] + counts["decoder"] + counts["other"],
        "trainable_params_vision": counts["vision"],
        "trainable_params_decoder": counts["decoder"],
        "trainable_params_other": counts["other"],
    }
    stats["vision_trainable"] = counts["vision"] > 0
    return stats
|
|
|
|
|
def reduce_parallel_losses(output: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Collapse non-scalar loss tensors to their mean, in place.

    Only the known loss keys are touched; scalar tensors, non-tensor values,
    missing keys and all other entries are left as they are. The same dict
    object is returned.
    """
    loss_keys = (
        "loss",
        "generation_loss",
        "evidence_loss",
        "ui_function_loss",
        "search_function_loss",
        "section_loss",
        "numeric_loss",
    )
    averaged = {
        key: output[key].mean()
        for key in loss_keys
        if torch.is_tensor(output.get(key)) and output[key].ndim > 0
    }
    output.update(averaged)
    return output
|
|
|
|
|
def build_optimizer(model: nn.Module, args: argparse.Namespace) -> torch.optim.Optimizer:
    """Create the optimizer with per-module learning rates.

    Trainable parameters are split into param groups, in this order:
    everything else (``args.lr_new``), ``vision.*`` (``args.lr_fusion``) and
    ``decoder.*`` (``args.lr_decoder``). When ``args.lr_ui_function_head`` is
    positive, the function/search heads and their signal projections get a
    fourth group at that rate; otherwise they stay in the "everything else"
    group.

    Raises:
        ValueError: if ``args.optimizer_name`` is neither ``adamw`` nor
            ``adafactor``.
    """
    head_lr = float(getattr(args, "lr_ui_function_head", 0.0) or 0.0)
    head_prefixes = (
        "ui_function_head.",
        "search_function_head.",
        "function_signal_proj.",
        "search_signal_proj.",
    )
    buckets: Dict[str, list] = {"other": [], "vision": [], "decoder": [], "head": []}
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        if head_lr > 0 and name.startswith(head_prefixes):
            target = "head"
        elif name.startswith("decoder."):
            target = "decoder"
        elif name.startswith("vision."):
            target = "vision"
        else:
            target = "other"
        buckets[target].append(param)

    # The first three groups are always present, even when empty, so group
    # indices stay stable.
    groups = [
        {"params": buckets["other"], "lr": args.lr_new},
        {"params": buckets["vision"], "lr": args.lr_fusion},
        {"params": buckets["decoder"], "lr": args.lr_decoder},
    ]
    if buckets["head"]:
        groups.append({"params": buckets["head"], "lr": head_lr})

    optimizer_name = str(getattr(args, "optimizer_name", "adamw") or "adamw").lower()
    if optimizer_name == "adafactor":
        return Adafactor(
            groups,
            scale_parameter=False,
            relative_step=False,
            warmup_init=False,
            weight_decay=args.weight_decay,
        )
    if optimizer_name == "adamw":
        return torch.optim.AdamW(groups, weight_decay=args.weight_decay)
    raise ValueError("optimizer_name must be one of: adamw, adafactor")
|
|
|
|
|
def clip_gradients(model: nn.Module, optimizer: torch.optim.Optimizer, args: argparse.Namespace) -> torch.Tensor:
    """Clip gradients in place and return the measured gradient norm.

    ``args.max_grad_norm`` (default 1.0) is the clipping threshold; a value
    that is zero, negative or None disables clipping and returns a zero
    tensor. ``args.grad_clip_strategy`` selects:

    * ``"global"`` (default): one norm over all model parameters.
    * ``"per_group"``: each optimizer param group is clipped on its own and
      the largest pre-clip group norm is returned (zero when no parameter has
      a gradient).

    Raises:
        ValueError: for any other strategy.
    """
    limit = float(getattr(args, "max_grad_norm", 1.0) or 0.0)
    # Used only to create result tensors on the model's device/dtype.
    anchor = next(model.parameters())
    if limit <= 0:
        return anchor.detach().new_tensor(0.0)

    strategy = str(getattr(args, "grad_clip_strategy", "global") or "global").lower()
    if strategy == "global":
        return torch.nn.utils.clip_grad_norm_(model.parameters(), limit, foreach=False)
    if strategy != "per_group":
        raise ValueError("grad_clip_strategy must be one of: global, per_group")

    group_norms: List[torch.Tensor] = []
    for group in optimizer.param_groups:
        with_grad = [param for param in group["params"] if param.grad is not None]
        if with_grad:
            norm = torch.nn.utils.clip_grad_norm_(with_grad, limit, foreach=False)
            group_norms.append(norm.detach().to(device=anchor.device, dtype=torch.float32))
    if group_norms:
        return torch.stack(group_norms).max()
    return anchor.detach().new_tensor(0.0)
|
|
|
|
|
def get_lr_schedule(optimizer, total_steps: int, warmup_ratio: float, scheduler_type: str = "linear"):
    """Build a warmup-then-decay ``LambdaLR`` schedule.

    The multiplier rises linearly during the first
    ``int(total_steps * warmup_ratio)`` steps (never below ``1e-8``), then
    decays to zero over the remaining steps: half-cosine when
    ``scheduler_type`` is ``"cosine"``, linear for any other value.
    """
    warmup_steps = int(total_steps * warmup_ratio)
    kind = str(scheduler_type or "linear").lower()

    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return max(1e-8, step / max(1, warmup_steps))
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        if kind == "cosine":
            factor = 0.5 * (1.0 + math.cos(math.pi * min(1.0, progress)))
        else:
            factor = 1.0 - progress
        return max(0.0, factor)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
|
|
|
def save_checkpoint(
    output_dir: Path,
    name: str,
    model: nn.Module,
    tokenizer,
    image_processor,
    args: argparse.Namespace,
    metrics: Dict[str, Any],
) -> None:
    """Write a self-contained checkpoint directory ``output_dir / name``.

    Contents written:
      * ``pytorch_model.bin``: state dict of the unwrapped model.
      * ``rich_config.json``: ``vars(args)``.
      * ``metrics.json``: the given metrics.
      * ``decoder_tokenizer/``: ``tokenizer.save_pretrained`` output, plus
        ``spiece.model`` / ``config.json`` copied from the tokenizer's
        ``vocab_file`` or from a local ``args.decoder_model`` directory when
        they are not already present.
      * ``image_processor/``: ``image_processor.save_pretrained`` output.
    """
    ckpt_dir = output_dir / name
    ckpt_dir.mkdir(parents=True, exist_ok=True)
    raw_model = unwrap_parallel_model(model)
    torch.save(raw_model.state_dict(), ckpt_dir / "pytorch_model.bin")
    write_json(ckpt_dir / "rich_config.json", vars(args))
    write_json(ckpt_dir / "metrics.json", metrics)

    tokenizer_dir = ckpt_dir / "decoder_tokenizer"
    tokenizer.save_pretrained(tokenizer_dir)
    vocab_file = getattr(tokenizer, "vocab_file", None)
    if vocab_file and Path(vocab_file).exists() and not (tokenizer_dir / "spiece.model").exists():
        shutil.copyfile(vocab_file, tokenizer_dir / "spiece.model")

    decoder_model = str(getattr(args, "decoder_model", "") or "")
    # Path("") is Path("."), which is always a directory. Without this guard an
    # empty decoder_model would copy unrelated spiece.model / config.json files
    # from the current working directory into the checkpoint.
    if decoder_model and Path(decoder_model).is_dir():
        decoder_model_dir = Path(decoder_model)
        for tokenizer_asset in ("spiece.model", "config.json"):
            source_asset = decoder_model_dir / tokenizer_asset
            target_asset = tokenizer_dir / tokenizer_asset
            if source_asset.exists() and not target_asset.exists():
                shutil.copyfile(source_asset, target_asset)

    image_processor.save_pretrained(ckpt_dir / "image_processor")
|
|
|
|
|
def resize_checkpoint_tensor(value: torch.Tensor, target: torch.Tensor) -> Optional[torch.Tensor]:
    """Fit a checkpoint tensor to ``target``'s shape by cropping or tiling.

    Each dimension is handled independently: a dimension that is too long is
    cropped from the start, one that is too short is tiled (repeated) and then
    cropped. The result is a CPU tensor cast to ``target``'s dtype.

    Returns:
        The resized tensor, or ``None`` when ranks differ, the tensor is a
        scalar, or a mismatched dimension has size zero on either side.
    """
    if value.ndim == 0 or value.ndim != target.ndim:
        return None

    result = value.detach().cpu()
    for axis, wanted in enumerate(target.shape):
        have = result.shape[axis]
        if have == wanted:
            continue
        if have <= 0 or wanted <= 0:
            return None
        if have < wanted:
            # Tile along this axis until it is at least as long as required.
            tiling = [1] * result.ndim
            tiling[axis] = math.ceil(wanted / have)
            result = result.repeat(*tiling)
        result = result.narrow(axis, 0, wanted)
    return result.to(dtype=target.dtype)
|
|
|
|
|
def load_compatible_model_state(
    model: nn.Module,
    state: Dict[str, torch.Tensor],
    allow_missing_prefixes: Tuple[str, ...] = (),
    resize_mismatched_non_decoder: bool = False,
) -> Tuple[List[str], List[str], List[str], List[str]]:
    """Load the shape-compatible part of ``state`` into ``model``.

    Entries whose shape differs from the model are skipped, unless
    ``resize_mismatched_non_decoder`` is set and the key is outside
    ``decoder.``, in which case ``resize_checkpoint_tensor`` is tried first.

    Missing keys are tolerated only for the function/search heads and their
    signal projections, for keys skipped because of a shape mismatch, and for
    keys starting with one of ``allow_missing_prefixes``.

    Returns:
        ``(missing_keys, unexpected_keys, skipped_incompatible,
        resized_incompatible)``.

    Raises:
        RuntimeError: if any other key is missing or the checkpoint contains
            keys the model does not have. The compatible weights have already
            been loaded into ``model`` when this is raised.
    """
    expected = model.state_dict()
    loadable: Dict[str, torch.Tensor] = {}
    resized_keys: List[str] = []
    shape_skipped: List[str] = []
    unknown_keys: List[str] = []

    for key, tensor in state.items():
        if key not in expected:
            unknown_keys.append(key)
            continue
        reference = expected[key]
        if tuple(reference.shape) == tuple(tensor.shape):
            loadable[key] = tensor
            continue
        if resize_mismatched_non_decoder and not key.startswith("decoder."):
            candidate = resize_checkpoint_tensor(tensor, reference)
            if candidate is not None and tuple(candidate.shape) == tuple(reference.shape):
                loadable[key] = candidate
                resized_keys.append(key)
                continue
        shape_skipped.append(key)

    result = model.load_state_dict(loadable, strict=False)

    optional_prefixes = (
        "ui_function_head.",
        "search_function_head.",
        "function_signal_proj.",
        "search_signal_proj.",
    )
    bad_missing = [
        key
        for key in result.missing_keys
        if not (
            key.startswith(optional_prefixes)
            or key in shape_skipped
            or key.startswith(allow_missing_prefixes)
        )
    ]
    if bad_missing or result.unexpected_keys or unknown_keys:
        raise RuntimeError(
            f"Checkpoint mismatch. missing={bad_missing}, unexpected={list(result.unexpected_keys) + unknown_keys}"
        )
    return list(result.missing_keys), list(result.unexpected_keys), shape_skipped, resized_keys
|
|
|
|
|
def load_rich_checkpoint(checkpoint: str, device: torch.device) -> Tuple[RichGroundedModel, Any, Any, argparse.Namespace]:
    """Restore a model saved by ``save_checkpoint`` for inference.

    The saved ``rich_config.json`` is layered over ``DEFAULT_CONFIG`` so
    checkpoints written before newer options existed still get defaults.

    Args:
        checkpoint: Path of the checkpoint directory.
        device: Device the model is moved to after loading.

    Returns:
        ``(model, tokenizer, image_processor, args)`` with the model in eval
        mode on ``device``.

    Raises:
        RuntimeError: from ``load_compatible_model_state`` when the weights do
            not match the model built from the config.
    """
    ckpt_dir = Path(checkpoint)
    config = json.loads((ckpt_dir / "rich_config.json").read_text(encoding="utf-8"))
    merged_config = dict(DEFAULT_CONFIG)
    merged_config.update(config)
    args = argparse.Namespace(**merged_config)
    tokenizer = load_seq2seq_tokenizer(str(ckpt_dir / "decoder_tokenizer"), str(args.decoder_model))
    image_processor = AutoImageProcessor.from_pretrained(ckpt_dir / "image_processor")
    model = RichGroundedModel(args)
    # The file is a plain state dict of tensors, so restrict unpickling to
    # tensor data; a tampered checkpoint then cannot execute arbitrary code.
    state = torch.load(ckpt_dir / "pytorch_model.bin", map_location="cpu", weights_only=True)
    load_compatible_model_state(model, state)
    model.to(device)
    model.eval()
    return model, tokenizer, image_processor, args
|
|
|
|
|
def get_eval_max_new_tokens(args: argparse.Namespace) -> int:
    """Return the generation length budget used during evaluation.

    Resolution order: ``args.eval_max_new_tokens``, then
    ``args.max_target_tokens``, then 384. A value that is missing, ``None`` or
    not positive is skipped, so a null or zero in a config file falls back to
    the next default instead of raising ``TypeError`` or requesting a
    zero-token generation.
    """
    for attr in ("eval_max_new_tokens", "max_target_tokens"):
        raw = getattr(args, attr, None)
        if raw is None:
            continue
        value = int(raw)
        if value > 0:
            return value
    return 384
|
|
|
|
|
def build_generation_kwargs(args: argparse.Namespace, tokenizer) -> Dict[str, Any]:
    """Translate ``generation_*`` options on ``args`` into ``generate()`` kwargs.

    Only options that differ from their neutral value are emitted:
    ``no_repeat_ngram_size`` (> 0), ``repetition_penalty`` (!= 1.0),
    ``min_new_tokens`` (> 0), ``bad_words_ids`` (``<extra_id_*>`` tokens found
    among the last 256 vocabulary ids and/or "Title"-style prefixes,
    de-duplicated in first-seen order) and ``forced_bos_token_id`` (only when
    ``"{"`` encodes to exactly one token).
    """
    kwargs: Dict[str, Any] = {}

    ngram_size = int(getattr(args, "generation_no_repeat_ngram_size", 0) or 0)
    if ngram_size > 0:
        kwargs["no_repeat_ngram_size"] = ngram_size

    penalty = float(getattr(args, "generation_repetition_penalty", 1.0) or 1.0)
    if penalty != 1.0:
        kwargs["repetition_penalty"] = penalty

    min_tokens = int(getattr(args, "generation_min_new_tokens", 0) or 0)
    if min_tokens > 0:
        kwargs["min_new_tokens"] = min_tokens

    blocked = []
    if bool(getattr(args, "generation_block_extra_ids", False)):
        vocab_len = len(tokenizer)
        # Only the tail of the vocabulary is scanned for sentinel tokens.
        for token_id in range(max(0, vocab_len - 256), vocab_len):
            token = tokenizer.convert_ids_to_tokens(token_id)
            if isinstance(token, str) and "<extra_id_" in token:
                blocked.append([token_id])
    if bool(getattr(args, "generation_block_title_prefix", False)):
        for phrase in ("Title", "title", "Title:", "title:", "Title :", "title :"):
            phrase_ids = tokenizer.encode(phrase, add_special_tokens=False)
            if phrase_ids:
                blocked.append(phrase_ids)
    if blocked:
        # dict.fromkeys keeps the first occurrence of each id sequence, in order.
        unique = dict.fromkeys(tuple(int(token_id) for token_id in ids) for ids in blocked)
        kwargs["bad_words_ids"] = [list(ids) for ids in unique]

    if bool(getattr(args, "generation_force_json_start", False)):
        brace_ids = tokenizer.encode("{", add_special_tokens=False)
        if len(brace_ids) == 1:
            kwargs["forced_bos_token_id"] = brace_ids[0]
    return kwargs
|
|
|
|
|
| @torch.no_grad()
|
| def evaluate(
|
| model: nn.Module,
|
| loader: DataLoader,
|
| tokenizer,
|
| device: torch.device,
|
| args: argparse.Namespace,
|
| rank: int = 0,
|
| ) -> Dict[str, Any]:
|
| raw_model = unwrap_parallel_model(model)
|
| raw_model.eval()
|
| losses: List[float] = []
|
| rouges: List[float] = []
|
| json_valid = 0
|
| generation_json_valid = 0
|
| generation_json_strict_valid = 0
|
| generation_json_repair_applied = 0
|
| total = 0
|
| evidence_precisions: List[float] = []
|
| ui_function_tp = 0
|
| ui_function_fp = 0
|
| ui_function_fn = 0
|
| ui_function_pred_positive = 0
|
| ui_function_ref_positive = 0
|
| ui_function_valid_elements = 0
|
| generation_char_lengths: List[int] = []
|
| pred_summary_char_lengths: List[int] = []
|
| ref_summary_char_lengths: List[int] = []
|
| empty_summary_count = 0 |
| extra_id_count = 0 |
| title_prefix_count = 0 |
| natural_title_prefix_stripped_count = 0 |
| whitespace_only_count = 0 |
| search_function_tp = 0
|
| search_function_fp = 0
|
| search_function_fn = 0
|
| search_function_pred_positive = 0
|
| search_function_ref_positive = 0
|
| search_tp = 0
|
| search_fp = 0
|
| search_fn = 0
|
| pred_search_count = 0
|
| ref_search_count = 0
|
| pred_function_count = 0
|
| ref_function_count = 0
|
| pred_bare_search_count = 0
|
| predictions: List[Dict[str, Any]] = []
|
| eval_max_new_tokens = get_eval_max_new_tokens(args)
|
| context_summary_repair = bool(getattr(args, "context_summary_repair", False))
|
| structured_mode = str(getattr(args, "structured_function_mode", "decoder") or "decoder").lower()
|
| target_schema = str(getattr(args, "target_schema", "zh") or "zh")
|
| summary_output_mode = target_schema_is_summary(target_schema)
|
| natural_output_mode = target_schema_is_natural_text(target_schema)
|
| function_metric_threshold = 0.5
|
| search_metric_threshold = 0.5
|
| if structured_mode == "heads":
|
| function_metric_threshold = float(getattr(args, "structured_function_threshold", 0.5) or 0.5)
|
| search_metric_threshold = float(getattr(args, "structured_search_threshold", function_metric_threshold) or function_metric_threshold)
|
| repair_count = 0
|
|
|
|
|
|
|
| template_exact_match = 0
|
| template_app_in_pred = 0
|
| template_instruction_in_pred = 0
|
| for batch in tqdm(loader, desc="valid", disable=not is_main(rank)):
|
| rows = batch["rows"]
|
| batch = move_batch(batch, device)
|
| out = raw_model(**{k: v for k, v in batch.items() if k != "rows"})
|
| losses.append(float(out["loss"].detach().cpu()))
|
| valid_elements = batch["element_mask"].bool()
|
| function_labels = batch["ui_function_labels"] > 0.5
|
| function_preds = torch.sigmoid(out["ui_function_logits"]) >= function_metric_threshold
|
| search_function_labels = batch["search_function_labels"] > 0.5
|
| search_function_preds = torch.sigmoid(out["search_function_logits"]) >= search_metric_threshold
|
| evidence_scores_batch = torch.sigmoid(out["evidence_logits"]).detach().cpu()
|
| function_scores_batch = torch.sigmoid(out["ui_function_logits"]).detach().cpu()
|
| search_scores_batch = torch.sigmoid(out["search_function_logits"]).detach().cpu()
|
| ui_function_tp += int((function_preds & function_labels & valid_elements).sum().detach().cpu())
|
| ui_function_fp += int((function_preds & ~function_labels & valid_elements).sum().detach().cpu())
|
| ui_function_fn += int((~function_preds & function_labels & valid_elements).sum().detach().cpu())
|
| ui_function_pred_positive += int((function_preds & valid_elements).sum().detach().cpu())
|
| ui_function_ref_positive += int((function_labels & valid_elements).sum().detach().cpu())
|
| search_function_tp += int((search_function_preds & search_function_labels & valid_elements).sum().detach().cpu())
|
| search_function_fp += int((search_function_preds & ~search_function_labels & valid_elements).sum().detach().cpu())
|
| search_function_fn += int((~search_function_preds & search_function_labels & valid_elements).sum().detach().cpu())
|
| search_function_pred_positive += int((search_function_preds & valid_elements).sum().detach().cpu())
|
| search_function_ref_positive += int((search_function_labels & valid_elements).sum().detach().cpu())
|
| ui_function_valid_elements += int(valid_elements.sum().detach().cpu())
|
| texts = raw_model.generate_text(batch, tokenizer, num_beams=args.num_beams, max_new_tokens=eval_max_new_tokens)
|
| for row_idx, (row, text) in enumerate(zip(rows, texts)):
|
| generation_char_lengths.append(len(text))
|
| if text and not text.strip():
|
| whitespace_only_count += 1
|
| if "<extra_id_" in text:
|
| extra_id_count += 1
|
| if has_natural_title_prefix(text): |
| title_prefix_count += 1 |
| if summary_output_mode:
|
| pred_obj = prediction_from_summary(row, text)
|
| ok = True
|
| elif natural_output_mode: |
| parse_text, stripped_title_prefix = strip_natural_title_prefix(text) |
| natural_title_prefix_stripped_count += int(stripped_title_prefix) |
| pred_obj = natural_prediction_from_text(parse_text) |
| ok = bool(extract_summary(pred_obj)) |
| generation_json_valid += int(ok) |
| else:
|
| pred_obj, ok, repair_applied, strict_ok = safe_json_loads_with_repair(text)
|
| generation_json_valid += int(ok)
|
| generation_json_strict_valid += int(strict_ok)
|
| generation_json_repair_applied += int(repair_applied)
|
| if context_summary_repair and not summary_output_mode:
|
| pred_obj, repaired = repair_prediction_with_context(row, pred_obj)
|
| ok = True
|
| repair_count += int(repaired)
|
| pred_obj = apply_structured_function_predictions(
|
| row,
|
| pred_obj,
|
| function_scores_batch[row_idx],
|
| search_scores_batch[row_idx],
|
| args,
|
| )
|
| pred_obj = apply_structured_evidence_predictions(row, pred_obj, evidence_scores_batch[row_idx], args)
|
| ref_target = row.get("target") or {}
|
| if bool(getattr(args, "canonicalize_targets", False)):
|
| ref_target = canonicalize_target_with_context(
|
| row,
|
| ref_target,
|
| drop_bare_search_functions=bool(getattr(args, "drop_bare_search_functions", False)),
|
| )
|
| ref_obj = json.loads(target_to_text(ref_target, "zh"))
|
| pred_has_search = has_search_function(pred_obj)
|
| ref_has_search = has_search_function(ref_obj)
|
| search_tp += int(pred_has_search and ref_has_search)
|
| search_fp += int(pred_has_search and not ref_has_search)
|
| search_fn += int((not pred_has_search) and ref_has_search)
|
| pred_search_count += count_search_functions(pred_obj)
|
| ref_search_count += count_search_functions(ref_obj)
|
| pred_function_count += len(extract_function_entries(pred_obj))
|
| ref_function_count += len(extract_function_entries(ref_obj))
|
| pred_bare_search_count += count_bare_search_functions(pred_obj)
|
| pred_summary = extract_summary(pred_obj)
|
| ref_summary = extract_summary(ref_obj)
|
| pred_summary_char_lengths.append(len(pred_summary))
|
| ref_summary_char_lengths.append(len(ref_summary))
|
| if not pred_summary:
|
| empty_summary_count += 1
|
| rouges.append(rouge_l_char(pred_summary, ref_summary))
|
|
|
| template_summary = build_context_summary(row)
|
| if pred_summary and pred_summary == template_summary:
|
| template_exact_match += 1
|
| row_app = safe_text(row.get("app"))
|
| row_instruction = safe_text(row.get("instruction"))
|
| if row_app and pred_summary and row_app in pred_summary:
|
| template_app_in_pred += 1
|
| if row_instruction and pred_summary and row_instruction in pred_summary:
|
| template_instruction_in_pred += 1
|
| json_valid += int(ok)
|
| total += 1
|
| pred_evidence = set(extract_evidence_ids(pred_obj))
|
| ref_evidence = set(ref_target.get("key_ui_clues", []) or row.get("weak_evidence_ids", []))
|
| if pred_evidence:
|
| evidence_precisions.append(len(pred_evidence & ref_evidence) / len(pred_evidence))
|
| else:
|
| evidence_precisions.append(0.0)
|
| if len(predictions) < 50:
|
| predictions.append(
|
| {
|
| "screen_id": row.get("screen_id"),
|
| "prediction_raw": text,
|
| "prediction": pred_obj,
|
| "reference": ref_obj,
|
| }
|
| )
|
| metrics = { |
| "loss": float(np.mean(losses)) if losses else 0.0,
|
| "rouge_l_char": float(np.mean(rouges)) if rouges else 0.0,
|
| "json_valid_rate": json_valid / max(1, total),
|
| "generation_json_valid_rate": generation_json_valid / max(1, total),
|
| "generation_json_strict_valid_rate": generation_json_strict_valid / max(1, total),
|
| "generation_json_repair_rate": generation_json_repair_applied / max(1, total),
|
| "generation_char_len_mean": float(np.mean(generation_char_lengths)) if generation_char_lengths else 0.0,
|
| "generation_char_len_max": int(max(generation_char_lengths)) if generation_char_lengths else 0,
|
| "pred_summary_char_len_mean": float(np.mean(pred_summary_char_lengths)) if pred_summary_char_lengths else 0.0,
|
| "ref_summary_char_len_mean": float(np.mean(ref_summary_char_lengths)) if ref_summary_char_lengths else 0.0,
|
| "empty_summary_rate": empty_summary_count / max(1, total),
|
| "whitespace_only_rate": whitespace_only_count / max(1, total), |
| "extra_id_rate": extra_id_count / max(1, total), |
| "title_prefix_rate": title_prefix_count / max(1, total), |
| "natural_title_prefix_stripped_rate": natural_title_prefix_stripped_count / max(1, total), |
| "evidence_precision": float(np.mean(evidence_precisions)) if evidence_precisions else 0.0, |
| "ui_function_precision": ui_function_tp / max(1, ui_function_tp + ui_function_fp),
|
| "ui_function_recall": ui_function_tp / max(1, ui_function_tp + ui_function_fn),
|
| "ui_function_f1": (2 * ui_function_tp) / max(1, 2 * ui_function_tp + ui_function_fp + ui_function_fn),
|
| "ui_function_pred_positive_rate": ui_function_pred_positive / max(1, ui_function_valid_elements),
|
| "ui_function_ref_positive_rate": ui_function_ref_positive / max(1, ui_function_valid_elements),
|
| "search_function_precision": search_function_tp / max(1, search_function_tp + search_function_fp),
|
| "search_function_recall": search_function_tp / max(1, search_function_tp + search_function_fn),
|
| "search_function_f1": (2 * search_function_tp) / max(1, 2 * search_function_tp + search_function_fp + search_function_fn),
|
| "search_function_pred_positive_rate": search_function_pred_positive / max(1, ui_function_valid_elements),
|
| "search_function_ref_positive_rate": search_function_ref_positive / max(1, ui_function_valid_elements),
|
| "search_precision": search_tp / max(1, search_tp + search_fp),
|
| "search_recall": search_tp / max(1, search_tp + search_fn),
|
| "search_f1": (2 * search_tp) / max(1, 2 * search_tp + search_fp + search_fn),
|
| "search_tp": search_tp,
|
| "search_fp": search_fp,
|
| "search_fn": search_fn,
|
| "pred_search_count": pred_search_count, |
| "ref_search_count": ref_search_count, |
| "pred_function_count": pred_function_count, |
| "ref_function_count": ref_function_count, |
| "function_overgen_rate": max(0, pred_function_count - ref_function_count) / max(1, ref_function_count), |
| "function_count_ratio": pred_function_count / max(1, ref_function_count), |
| "search_overgen_rate": max(0, pred_search_count - ref_search_count) / max(1, total), |
| "pred_bare_search_count": pred_bare_search_count, |
| "bare_search_rate": pred_bare_search_count / max(1, pred_function_count), |
| "max_target_tokens": int(args.max_target_tokens),
|
| "eval_max_new_tokens": eval_max_new_tokens,
|
| "num_beams": int(args.num_beams),
|
| "scheduler_epochs": int(getattr(args, "scheduler_epochs", 0) or 0), |
| "lr_scheduler_type": str(getattr(args, "lr_scheduler_type", "linear") or "linear"), |
| "context_mode": str(getattr(args, "context_mode", "mean") or "mean"), |
| "context_text_format": str(getattr(args, "context_text_format", "rich") or "rich"), |
| "context_include_screen_text": bool(getattr(args, "context_include_screen_text", False)), |
| "context_screen_text_items": int(getattr(args, "context_screen_text_items", 32) or 32), |
| "context_screen_text_dropout_rate": float(getattr(args, "context_screen_text_dropout_rate", 0.0) or 0.0), |
| "generation_block_extra_ids": bool(getattr(args, "generation_block_extra_ids", False)), |
| "generation_block_title_prefix": bool(getattr(args, "generation_block_title_prefix", False)), |
| "generation_no_repeat_ngram_size": int(getattr(args, "generation_no_repeat_ngram_size", 0) or 0), |
| "generation_repetition_penalty": float(getattr(args, "generation_repetition_penalty", 1.0) or 1.0),
|
| "generation_min_new_tokens": int(getattr(args, "generation_min_new_tokens", 0) or 0),
|
| "generation_force_json_start": bool(getattr(args, "generation_force_json_start", False)),
|
| "context_summary_repair": context_summary_repair,
|
| "context_repair_applied_rate": repair_count / max(1, total),
|
| "canonicalize_targets": bool(getattr(args, "canonicalize_targets", False)),
|
| "target_schema": str(getattr(args, "target_schema", "zh") or "zh"),
|
| "decoder_output_mode": "summary" if summary_output_mode else ("natural_text" if natural_output_mode else "json"),
|
| "task_intent_context": bool(getattr(args, "task_intent_context", False)),
|
| "drop_bare_search_functions": bool(getattr(args, "drop_bare_search_functions", False)),
|
| "structured_function_mode": str(getattr(args, "structured_function_mode", "decoder") or "decoder"),
|
| "structured_function_threshold": float(getattr(args, "structured_function_threshold", 0.5) or 0.5),
|
| "structured_search_threshold": float(getattr(args, "structured_search_threshold", 0.5) or 0.5),
|
| "structured_max_functions": int(getattr(args, "structured_max_functions", 12) or 12),
|
| "structured_strict_search_candidates": bool(getattr(args, "structured_strict_search_candidates", False)),
|
| "structured_evidence_mode": str(getattr(args, "structured_evidence_mode", "decoder") or "decoder"),
|
| "structured_evidence_threshold": float(getattr(args, "structured_evidence_threshold", 0.5) or 0.5),
|
| "structured_max_evidence": int(getattr(args, "structured_max_evidence", 8) or 8),
|
| "structured_evidence_fallback_top1": bool(getattr(args, "structured_evidence_fallback_top1", True)),
|
| "ui_function_loss_weight": float(getattr(args, "ui_function_loss_weight", 0.0) or 0.0),
|
| "search_function_loss_weight": float(getattr(args, "search_function_loss_weight", 0.0) or 0.0),
|
| "search_function_pos_weight": float(getattr(args, "search_function_pos_weight", 1.0) or 1.0),
|
| "lr_ui_function_head": float(getattr(args, "lr_ui_function_head", 0.0) or 0.0),
|
| "max_visual_tokens": int(getattr(args, "max_visual_tokens", 0) or 0),
|
| "model_variant": str(getattr(args, "model_variant", "")),
|
| "vision_enabled": not bool(getattr(args, "disable_vision", False)) and str(getattr(args, "model_variant", "")) != "annotation_only",
|
| "image_size": int(getattr(args, "image_size", 0) or 0),
|
| "direct_visual_tokens": bool(getattr(args, "direct_visual_tokens", False)),
|
| "direct_element_tokens": bool(getattr(args, "direct_element_tokens", False)),
|
| "direct_context_passthrough": bool(getattr(args, "direct_context_passthrough", False)),
|
| "include_pooled_memory": bool(getattr(args, "include_pooled_memory", True)),
|
| "native_context_forward": bool(getattr(args, "native_context_forward", False)), |
| "disable_vision": bool(getattr(args, "disable_vision", False)), |
| "freeze_decoder": bool(getattr(args, "freeze_decoder", False)), |
| "init_resize_mismatched_non_decoder": bool(getattr(args, "init_resize_mismatched_non_decoder", False)), |
| "grad_clip_strategy": str(getattr(args, "grad_clip_strategy", "global") or "global"),
|
| "max_grad_norm": float(getattr(args, "max_grad_norm", 1.0) or 0.0),
|
| "function_signal_to_decoder": bool(getattr(args, "function_signal_to_decoder", False)), |
| "function_signal_scale": float(getattr(args, "function_signal_scale", 1.0) or 1.0), |
| "search_signal_to_decoder": bool(getattr(args, "search_signal_to_decoder", False)), |
| "search_signal_scale": float(getattr(args, "search_signal_scale", 1.0) or 1.0), |
| "visual_memory_scale": float(getattr(args, "visual_memory_scale", 1.0) or 1.0), |
| "element_memory_scale": float(getattr(args, "element_memory_scale", 1.0) or 1.0), |
| "pooled_memory_scale": float(getattr(args, "pooled_memory_scale", 1.0) or 1.0), |
| "decoder_gradient_checkpointing": bool(getattr(args, "decoder_gradient_checkpointing", False)), |
| "vision_gradient_checkpointing": bool(getattr(args, "vision_gradient_checkpointing", False)),
|
| "cuda_memory_fraction": float(getattr(args, "cuda_memory_fraction", 0.0) or 0.0),
|
| "max_train_samples": int(getattr(args, "max_train_samples", 0) or 0),
|
| }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| metrics["summary_quality_score"] = float(metrics["rouge_l_char"])
|
|
|
|
|
|
|
|
|
|
|
| metrics["summary_template_match_rate"] = template_exact_match / max(1, total)
|
| metrics["summary_app_recall_rate"] = template_app_in_pred / max(1, total)
|
| metrics["summary_instruction_recall_rate"] = template_instruction_in_pred / max(1, total)
|
| if summary_output_mode:
|
| metrics["rich_quality_score"] = float(metrics["rouge_l_char"])
|
| else:
|
| metrics["rich_quality_score"] = (
|
| 0.45 * metrics["rouge_l_char"]
|
| + 0.25 * metrics["json_valid_rate"]
|
| + 0.30 * metrics["evidence_precision"]
|
| )
|
| metrics["rich_function_score"] = ( |
| 0.70 * metrics["rich_quality_score"] |
| + 0.20 * metrics["search_f1"] |
| + 0.10 * metrics["ui_function_f1"] |
| ) |
| overgen_penalty = min(0.12, 0.04 * float(metrics["function_overgen_rate"])) |
| search_penalty = min(0.06, 0.02 * float(metrics["search_overgen_rate"])) |
| bare_search_penalty = min(0.03, 0.03 * float(metrics["bare_search_rate"])) |
| metrics["grounded_quality_score"] = max( |
| 0.0, |
| float(metrics["summary_quality_score"]) + 0.05 * float(metrics["evidence_precision"]) - overgen_penalty - search_penalty - bare_search_penalty, |
| ) |
| metrics["grounded_overgen_penalty"] = overgen_penalty + search_penalty + bare_search_penalty |
| if is_main(rank):
|
| write_json(Path(args.output_dir) / "val_preview.json", {"metrics": metrics, "predictions": predictions})
|
| raw_model.train()
|
| return metrics
|
|
|
|
|
def selection_metric_name(args: argparse.Namespace) -> str:
    """Return the name of the validation metric used for model selection.

    Reads ``args.model_selection_metric``. A missing attribute or any falsy
    value (``None``, ``""``) falls back to ``"rich_quality_score"``.
    """
    configured = getattr(args, "model_selection_metric", "rich_quality_score")
    if not configured:
        configured = "rich_quality_score"
    return str(configured)
|
|
|
|
|
def selection_metric_value(metrics: Dict[str, Any], args: argparse.Namespace) -> float:
    """Return the score that model selection should compare.

    Looks up the metric named by ``selection_metric_name(args)`` in
    ``metrics``. When that entry is absent (or ``None``) the value stored
    under ``"rich_quality_score"`` is used instead, defaulting to ``0.0``.
    The result is always converted with ``float``.
    """
    chosen = metrics.get(selection_metric_name(args))
    if chosen is not None:
        return float(chosen)
    return float(metrics.get("rich_quality_score", 0.0))
|
|
|
|
|
def metric_is_better(current: float, best: float, args: argparse.Namespace) -> bool:
    """Tell whether ``current`` improves on ``best`` by more than the configured margin.

    ``args.model_selection_mode`` picks the direction: ``"min"`` (case
    insensitive) means lower is better, anything else means higher is better.
    ``args.early_stopping_min_delta`` is the minimum margin the new score
    must clear; a missing or falsy value counts as ``0.0``.
    """
    direction = str(getattr(args, "model_selection_mode", "max") or "max").lower()
    margin = float(getattr(args, "early_stopping_min_delta", 0.0) or 0.0)
    if direction != "min":
        # Maximizing: must exceed the best score by more than the margin.
        return current > best + margin
    # Minimizing: must undercut the best score by more than the margin.
    return current < best - margin
|
|
|
|
|
def train(args: argparse.Namespace) -> None:
    """Train a ``RichGroundedModel`` end to end and write logs and checkpoints.

    The function seeds the run, initializes (optional) distributed state,
    validates the train/valid datasets and their token lengths, builds the
    data loaders, model, optimizer and LR schedule, and then runs the
    epoch loop with gradient accumulation, optional mixed precision,
    non-finite loss/gradient skipping, periodic step-level evaluation,
    per-epoch evaluation, checkpointing and early stopping.

    Artifacts written under ``args.output_dir`` by the main rank:
    ``config.json``, ``data_diagnostics.json``, ``metrics.jsonl`` and, when
    ``args.save_checkpoints`` is set, ``checkpoint-last`` / ``checkpoint-best``.

    Args:
        args: Parsed training configuration (see ``parse_args``).

    Raises:
        ValueError: If ``cuda_memory_fraction`` is outside ``[0, 1]`` or
            ``amp_dtype`` is not one of ``auto``, ``fp32``, ``fp16``, ``bf16``.
    """
    set_seed(args.seed)
    distributed, rank, world_size, local_rank = init_distributed()
    device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
    if device.type == "cuda":
        cuda_memory_fraction = float(getattr(args, "cuda_memory_fraction", 0.0) or 0.0)
        if cuda_memory_fraction < 0.0 or cuda_memory_fraction > 1.0:
            raise ValueError("cuda_memory_fraction must be in [0, 1]. Use 0 to disable the limit.")
        if cuda_memory_fraction > 0.0:
            torch.cuda.set_per_process_memory_fraction(cuda_memory_fraction, device=device)
            if is_main(rank):
                total_gb = torch.cuda.get_device_properties(device).total_memory / 1024**3
                print(f"CUDA memory fraction limit: {cuda_memory_fraction:.3f} (~{total_gb * cuda_memory_fraction:.2f}GB of {total_gb:.2f}GB)")
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True
        if hasattr(torch, "set_float32_matmul_precision"):
            torch.set_float32_matmul_precision("high")
    output_dir = Path(args.output_dir)
    if is_main(rank):
        output_dir.mkdir(parents=True, exist_ok=True)
        write_json(output_dir / "config.json", vars(args))

    # ---- Data: load, sanity-check, and record diagnostics -----------------
    max_train_samples = int(getattr(args, "max_train_samples", 0) or 0)
    train_dataset = RichScreenshotDataset(args.train_file, max_samples=max_train_samples, sample_seed=args.seed if max_train_samples else None)
    valid_dataset = RichScreenshotDataset(args.valid_file, max_samples=args.max_valid_samples)
    train_data_diagnostics = dataset_diagnostics(train_dataset.rows)
    valid_data_diagnostics = dataset_diagnostics(valid_dataset.rows)
    validate_dataset_for_training("train", train_data_diagnostics, args)
    validate_dataset_for_training("valid", valid_data_diagnostics, args)
    tokenizer = load_seq2seq_tokenizer(args.decoder_model)
    train_token_diagnostics = tokenizer_diagnostics(train_dataset.rows, tokenizer, args)
    valid_token_diagnostics = tokenizer_diagnostics(valid_dataset.rows, tokenizer, args)
    validate_token_lengths("train", train_token_diagnostics, args)
    validate_token_lengths("valid", valid_token_diagnostics, args)
    if is_main(rank):
        diagnostics_payload = {
            "train_file": args.train_file,
            "valid_file": args.valid_file,
            "strict_data_checks": bool(getattr(args, "strict_data_checks", True)),
            "train": train_data_diagnostics,
            "valid": valid_data_diagnostics,
            "tokenizer": args.decoder_model,
            "max_target_tokens": int(getattr(args, "max_target_tokens", 0) or 0),
            "eval_max_new_tokens": int(getattr(args, "eval_max_new_tokens", 0) or 0),
            "max_target_truncation_rate": float(getattr(args, "max_target_truncation_rate", 0.01) or 0.0),
            "train_token_lengths": train_token_diagnostics,
            "valid_token_lengths": valid_token_diagnostics,
        }
        write_json(output_dir / "data_diagnostics.json", diagnostics_payload)
        print("Data diagnostics:")
        print(json.dumps(diagnostics_payload, ensure_ascii=False, indent=2))
    image_processor = AutoImageProcessor.from_pretrained(args.vision_model)
    train_collator = RichCollator(tokenizer, image_processor, args, is_training=True)
    valid_collator = RichCollator(tokenizer, image_processor, args, is_training=False)
    train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=args.seed) if distributed else None
    valid_sampler = DistributedSampler(valid_dataset, shuffle=False) if distributed else None
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=train_sampler is None,
        sampler=train_sampler,
        collate_fn=train_collator,
        num_workers=args.num_workers,
        pin_memory=torch.cuda.is_available(),
    )
    valid_loader = DataLoader(
        valid_dataset,
        # eval_batch_size == 0 means "derive from the training batch size".
        batch_size=args.eval_batch_size if getattr(args, "eval_batch_size", 0) else max(1, args.batch_size // 2),
        shuffle=False,
        sampler=valid_sampler,
        collate_fn=valid_collator,
        num_workers=args.num_workers,
        pin_memory=torch.cuda.is_available(),
    )

    # ---- Model: build, optionally warm-start, wrap for parallelism --------
    model = RichGroundedModel(args)
    init_checkpoint = str(getattr(args, "init_checkpoint", "") or "")
    if init_checkpoint:
        checkpoint_decoder_model = ""
        init_config_path = Path(init_checkpoint) / "rich_config.json"
        init_config: Dict[str, Any] = {}
        if init_config_path.exists():
            init_config = json.loads(init_config_path.read_text(encoding="utf-8"))
            checkpoint_decoder_model = str(init_config.get("decoder_model", "") or "")
        # Weights under these prefixes may legitimately be absent from the
        # checkpoint: a different decoder was used, or vision was disabled.
        allow_missing_prefixes_list: List[str] = []
        if checkpoint_decoder_model and checkpoint_decoder_model != str(getattr(args, "decoder_model", "") or ""):
            allow_missing_prefixes_list.append("decoder.")
        if bool(init_config.get("disable_vision", False)) and not bool(getattr(args, "disable_vision", False)):
            allow_missing_prefixes_list.append("vision.")
        # NOTE(review): torch.load unpickles the file, so init_checkpoint must
        # come from a trusted source; consider weights_only=True if the
        # checkpoint is a plain tensor state dict.
        state = torch.load(Path(init_checkpoint) / "pytorch_model.bin", map_location="cpu")
        missing_keys, _, skipped_incompatible, resized_incompatible = load_compatible_model_state(
            model,
            state,
            allow_missing_prefixes=tuple(allow_missing_prefixes_list),
            resize_mismatched_non_decoder=bool(getattr(args, "init_resize_mismatched_non_decoder", False)),
        )
        if is_main(rank):
            missing_preview = missing_keys[:12]
            skipped_preview = skipped_incompatible[:12]
            resized_preview = resized_incompatible[:12]
            print(
                f"Loaded init checkpoint: {init_checkpoint}; "
                f"missing_count={len(missing_keys)} preview={missing_preview}; "
                f"skipped_incompatible_count={len(skipped_incompatible)} preview={skipped_preview}; "
                f"resized_incompatible_count={len(resized_incompatible)} preview={resized_preview}"
            )
    model = model.to(device)
    trainable_stats = trainable_parameter_stats(model)
    if is_main(rank):
        print(
            "Trainable parameters: "
            f"total={trainable_stats['trainable_params_total']:,}, "
            f"vision={trainable_stats['trainable_params_vision']:,}, "
            f"decoder={trainable_stats['trainable_params_decoder']:,}, "
            f"other={trainable_stats['trainable_params_other']:,}",
            flush=True,
        )
    if distributed:
        model = DistributedDataParallel(model, device_ids=[local_rank], find_unused_parameters=True)
    elif bool(getattr(args, "data_parallel", False)) and device.type == "cuda" and torch.cuda.device_count() > 1:
        model = nn.DataParallel(model)
        if is_main(rank):
            print(f"Using nn.DataParallel on {torch.cuda.device_count()} CUDA devices.", flush=True)

    # ---- Optimizer, LR schedule, mixed precision --------------------------
    optimizer = build_optimizer(unwrap_parallel_model(model), args)
    # scheduler_epochs == 0 means "use args.epochs"; the schedule never spans
    # fewer epochs than are actually trained.
    scheduler_epochs = int(getattr(args, "scheduler_epochs", 0) or args.epochs)
    scheduler_epochs = max(int(args.epochs), scheduler_epochs)
    steps_per_epoch = len(train_loader)
    total_update_steps = math.ceil(steps_per_epoch / args.grad_accum) * scheduler_epochs
    scheduler = get_lr_schedule(optimizer, total_update_steps, args.warmup_ratio, getattr(args, "lr_scheduler_type", "linear"))
    amp_dtype_name = str(getattr(args, "amp_dtype", "auto") or "auto").lower()
    if amp_dtype_name == "auto":
        amp_dtype_name = "fp16" if bool(getattr(args, "fp16", False)) else "fp32"
    if amp_dtype_name not in {"fp32", "fp16", "bf16"}:
        raise ValueError("amp_dtype must be one of: auto, fp32, fp16, bf16")
    amp_enabled = device.type == "cuda" and amp_dtype_name in {"fp16", "bf16"}
    # For fp32 the dtype below is unused because autocast is disabled.
    amp_dtype = torch.float16 if amp_dtype_name == "fp16" else torch.bfloat16
    # Loss scaling is only active for fp16.
    scaler = torch.amp.GradScaler("cuda", enabled=amp_enabled and amp_dtype_name == "fp16")

    # ---- Loop state -------------------------------------------------------
    model_selection_mode = str(getattr(args, "model_selection_mode", "max") or "max").lower()
    best_score = math.inf if model_selection_mode == "min" else -math.inf
    epochs_without_improvement = 0
    global_step = 0
    train_start_time = time.time()
    recent_losses: deque[float] = deque(maxlen=100)
    # Loop-invariant settings, read once.
    empty_cache_steps = int(getattr(args, "cuda_empty_cache_steps", 0) or 0)
    patience = int(getattr(args, "early_stopping_patience", 0) or 0)
    optimizer.zero_grad(set_to_none=True)
    if device.type == "cuda":
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(device)
    # Set when any micro-batch of the current accumulation window produced a
    # non-finite loss; the whole window is then dropped.
    skip_optimizer_window = False
    skipped_nonfinite_windows = 0

    for epoch in range(args.epochs):
        if train_sampler is not None:
            train_sampler.set_epoch(epoch)
        progress = tqdm(train_loader, desc=f"epoch {epoch + 1}/{args.epochs}", disable=not is_main(rank))
        for step, batch in enumerate(progress, start=1):
            batch = move_batch(batch, device)
            with torch.amp.autocast("cuda", enabled=amp_enabled, dtype=amp_dtype):
                out = model(**{k: v for k, v in batch.items() if k != "rows"})
                out = reduce_parallel_losses(out)
                loss = out["loss"] / args.grad_accum
            loss_is_finite = bool(torch.isfinite(loss.detach()).all().item())
            # NOTE(review): in distributed runs every rank makes this decision
            # from its own loss; a rank skipping backward while others do not
            # may desynchronize DDP gradient reduction -- confirm.
            if not loss_is_finite:
                skip_optimizer_window = True
                optimizer.zero_grad(set_to_none=True)
                if is_main(rank):
                    append_jsonl(
                        output_dir / "metrics.jsonl",
                        {
                            "event": "skip_nonfinite_microbatch",
                            "epoch": epoch + 1,
                            "micro_step": step,
                            "loss": float(out["loss"].detach().cpu()),
                            "generation_loss": float(out["generation_loss"].detach().cpu()),
                            "evidence_loss": float(out["evidence_loss"].detach().cpu()),
                            "ui_function_loss": float(out["ui_function_loss"].detach().cpu()),
                            "search_function_loss": float(out["search_function_loss"].detach().cpu()),
                            "section_loss": float(out["section_loss"].detach().cpu()),
                            "numeric_loss": float(out["numeric_loss"].detach().cpu()),
                        },
                    )
            elif not skip_optimizer_window:
                scaler.scale(loss).backward()
            step_log = None
            # An accumulation window closes every grad_accum micro-batches and
            # also on the last micro-batch of the epoch. Without the second
            # condition a trailing partial window would never be applied (its
            # gradients and skip flag would leak into the next epoch), and the
            # number of optimizer steps would fall short of total_update_steps,
            # which is computed with ceil().
            window_closed = step % args.grad_accum == 0 or step == steps_per_epoch
            if window_closed:
                if skip_optimizer_window:
                    skipped_nonfinite_windows += 1
                    optimizer.zero_grad(set_to_none=True)
                    skip_optimizer_window = False
                    if is_main(rank):
                        append_jsonl(
                            output_dir / "metrics.jsonl",
                            {
                                "event": "skip_nonfinite_optimizer_window",
                                "epoch": epoch + 1,
                                "micro_step": step,
                                "skipped_nonfinite_windows": skipped_nonfinite_windows,
                                "lr": max(scheduler.get_last_lr()),
                            },
                        )
                else:
                    # Capture the LRs this update uses before the scheduler advances.
                    lrs_used = list(scheduler.get_last_lr())
                    scaler.unscale_(optimizer)
                    grad_norm = clip_gradients(model, optimizer, args)
                    grad_is_finite = bool(torch.isfinite(grad_norm.detach()).all().item()) if torch.is_tensor(grad_norm) else math.isfinite(float(grad_norm))
                    if not grad_is_finite:
                        skipped_nonfinite_windows += 1
                        optimizer.zero_grad(set_to_none=True)
                        # unscale_ was called above, so the scaler still needs
                        # its update() to close out this iteration.
                        if scaler.is_enabled():
                            scaler.update()
                        if is_main(rank):
                            append_jsonl(
                                output_dir / "metrics.jsonl",
                                {
                                    "event": "skip_nonfinite_grad",
                                    "epoch": epoch + 1,
                                    "micro_step": step,
                                    "grad_norm": float(grad_norm.detach().cpu()) if torch.is_tensor(grad_norm) else float(grad_norm),
                                    "skipped_nonfinite_windows": skipped_nonfinite_windows,
                                    "lr": max(scheduler.get_last_lr()),
                                },
                            )
                    else:
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad(set_to_none=True)
                        scheduler.step()
                        global_step += 1
                        if is_main(rank):
                            loss_value = float(out["loss"].detach().cpu())
                            recent_losses.append(loss_value)
                            label_counts = (batch["labels"] != -100).sum(dim=1).detach().float().cpu()
                            context_counts = batch["context_attention_mask"].sum(dim=1).detach().float().cpu()
                            element_counts = batch["element_mask"].sum(dim=1).detach().float().cpu()
                            loss_window = list(recent_losses)
                            loss_window_20 = loss_window[-20:]
                            elapsed_seconds = max(time.time() - train_start_time, 1e-6)
                            next_lrs = list(scheduler.get_last_lr())
                            step_log = {
                                "step": global_step,
                                "epoch": epoch + 1,
                                "loss": loss_value,
                                "loss_ma20": float(np.mean(loss_window_20)) if loss_window_20 else loss_value,
                                "loss_ma100": float(np.mean(loss_window)) if loss_window else loss_value,
                                "generation_loss": float(out["generation_loss"].cpu()),
                                "evidence_loss": float(out["evidence_loss"].cpu()),
                                "ui_function_loss": float(out["ui_function_loss"].cpu()),
                                "search_function_loss": float(out["search_function_loss"].cpu()),
                                "section_loss": float(out["section_loss"].cpu()),
                                "numeric_loss": float(out["numeric_loss"].cpu()),
                                "lr": max(lrs_used) if lrs_used else 0.0,
                                # NOTE(review): the index -> group mapping below
                                # presumably mirrors build_optimizer's param
                                # group order -- confirm there.
                                "lr_other": lrs_used[0] if len(lrs_used) > 0 else 0.0,
                                "lr_vision": lrs_used[1] if len(lrs_used) > 1 else 0.0,
                                "lr_decoder": lrs_used[2] if len(lrs_used) > 2 else 0.0,
                                "lr_ui_function_head": lrs_used[3] if len(lrs_used) > 3 else 0.0,
                                "lr_next": max(next_lrs) if next_lrs else 0.0,
                                "lr_scheduler_type": str(getattr(args, "lr_scheduler_type", "linear") or "linear"),
                                "grad_norm": float(grad_norm.detach().cpu()) if torch.is_tensor(grad_norm) else float(grad_norm),
                                "skipped_nonfinite_windows": skipped_nonfinite_windows,
                                "optimizer_steps_total": total_update_steps,
                                "train_progress_pct": 100.0 * global_step / max(1, total_update_steps),
                                "elapsed_seconds": elapsed_seconds,
                                "optimizer_steps_per_sec": global_step / elapsed_seconds,
                                "target_tokens_mean": float(label_counts.mean().item()) if label_counts.numel() else 0.0,
                                "target_tokens_max": int(label_counts.max().item()) if label_counts.numel() else 0,
                                "target_at_max_rate": float((label_counts >= int(args.max_target_tokens)).float().mean().item()) if label_counts.numel() else 0.0,
                                "context_tokens_mean": float(context_counts.mean().item()) if context_counts.numel() else 0.0,
                                "context_tokens_max": int(context_counts.max().item()) if context_counts.numel() else 0,
                                "context_at_max_rate": float((context_counts >= int(args.max_context_tokens)).float().mean().item()) if context_counts.numel() else 0.0,
                                "element_count_mean": float(element_counts.mean().item()) if element_counts.numel() else 0.0,
                                "element_count_max": int(element_counts.max().item()) if element_counts.numel() else 0,
                                "model_variant": str(getattr(args, "model_variant", "")),
                                "vision_enabled": not bool(getattr(args, "disable_vision", False)) and str(getattr(args, "model_variant", "")) != "annotation_only",
                                "native_context_forward": bool(getattr(args, "native_context_forward", False)),
                                "freeze_decoder": bool(getattr(args, "freeze_decoder", False)),
                                "image_size": int(getattr(args, "image_size", 0) or 0),
                                "image_crops": int(batch["pixel_values"].shape[1]) if torch.is_tensor(batch.get("pixel_values")) and batch["pixel_values"].ndim >= 5 else 0,
                                "max_visual_tokens": int(getattr(args, "max_visual_tokens", 0) or 0),
                                "direct_visual_tokens": bool(getattr(args, "direct_visual_tokens", False)),
                                "direct_element_tokens": bool(getattr(args, "direct_element_tokens", False)),
                                "visual_memory_scale": float(getattr(args, "visual_memory_scale", 1.0) or 1.0),
                                "element_memory_scale": float(getattr(args, "element_memory_scale", 1.0) or 1.0),
                                "pooled_memory_scale": float(getattr(args, "pooled_memory_scale", 1.0) or 1.0),
                                "cuda_memory_fraction": float(getattr(args, "cuda_memory_fraction", 0.0) or 0.0),
                                "data_parallel": isinstance(model, nn.DataParallel),
                                "distributed": bool(distributed),
                                **trainable_stats,
                            }
            # Drop references to this micro-batch's tensors before any cache
            # flush, logging or evaluation below.
            del loss
            del out
            del batch
            if (
                device.type == "cuda"
                and empty_cache_steps > 0
                and window_closed
                and global_step % empty_cache_steps == 0
            ):
                gc.collect()
                torch.cuda.empty_cache()
            if step_log is not None and is_main(rank):
                mem = {}
                if torch.cuda.is_available():
                    mem = {
                        "gpu_allocated_gb": round(torch.cuda.memory_allocated(device) / 1024**3, 3),
                        "gpu_reserved_gb": round(torch.cuda.memory_reserved(device) / 1024**3, 3),
                        "gpu_peak_allocated_gb": round(torch.cuda.max_memory_allocated(device) / 1024**3, 3),
                        "gpu_peak_reserved_gb": round(torch.cuda.max_memory_reserved(device) / 1024**3, 3),
                    }
                log = {
                    **step_log,
                    **mem,
                }
                append_jsonl(output_dir / "metrics.jsonl", log)
                if device.type == "cuda":
                    # Peak stats are reported per logged optimizer step.
                    torch.cuda.reset_peak_memory_stats(device)
                progress.set_postfix(loss=f"{log['loss']:.3f}", ma20=f"{log['loss_ma20']:.3f}", pct=f"{log['train_progress_pct']:.1f}%")
            if (
                step_log is not None
                and args.save_checkpoints
                and args.save_every_steps
                and global_step > 0
                and global_step % args.save_every_steps == 0
                and is_main(rank)
            ):
                save_checkpoint(output_dir, "checkpoint-last", model, tokenizer, image_processor, args, {"step": global_step})
            # step_log is only populated on the main rank, so step-level
            # evaluation runs there only.
            if step_log is not None and args.eval_every_steps and global_step > 0 and global_step % args.eval_every_steps == 0:
                if is_main(rank) and args.save_checkpoints:
                    # Persist the weights first so a failure during evaluation
                    # does not lose training progress.
                    save_checkpoint(output_dir, "checkpoint-last", model, tokenizer, image_processor, args, {"step": global_step, "pre_eval": True})
                metrics = evaluate(model, valid_loader, tokenizer, device, args, rank=rank)
                current_score = selection_metric_value(metrics, args)
                metrics["selection_metric"] = selection_metric_name(args)
                metrics["selection_score"] = current_score
                if metric_is_better(current_score, best_score, args):
                    best_score = current_score
                    if is_main(rank) and args.save_checkpoints:
                        save_checkpoint(output_dir, "checkpoint-best", model, tokenizer, image_processor, args, metrics)
                if device.type == "cuda" and empty_cache_steps > 0:
                    gc.collect()
                    torch.cuda.empty_cache()
                if is_main(rank):
                    append_jsonl(output_dir / "metrics.jsonl", {"step_eval": global_step, **metrics})

        # ---- End of epoch: evaluate, checkpoint, early stopping -----------
        if is_main(rank) and args.save_checkpoints:
            save_checkpoint(output_dir, "checkpoint-last", model, tokenizer, image_processor, args, {"epoch": epoch + 1, "step": global_step, "pre_epoch_eval": True})
        metrics = evaluate(model, valid_loader, tokenizer, device, args, rank=rank)
        current_score = selection_metric_value(metrics, args)
        metrics["selection_metric"] = selection_metric_name(args)
        metrics["selection_score"] = current_score
        improved = metric_is_better(current_score, best_score, args)
        if improved:
            best_score = current_score
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if is_main(rank):
            append_jsonl(output_dir / "metrics.jsonl", {"epoch_eval": epoch + 1, **metrics})
            if args.save_checkpoints:
                save_checkpoint(output_dir, "checkpoint-last", model, tokenizer, image_processor, args, metrics)
            if args.save_checkpoints and improved:
                save_checkpoint(output_dir, "checkpoint-best", model, tokenizer, image_processor, args, metrics)
            if patience > 0:
                print(
                    f"selection_metric={metrics['selection_metric']} score={current_score:.6f} "
                    f"best={best_score:.6f} no_improve={epochs_without_improvement}/{patience}"
                )
        # NOTE(review): each rank decides from the score it computed itself;
        # if scores can differ between ranks, ranks may leave the loop at
        # different epochs -- confirm whether evaluate() synchronizes metrics.
        if patience > 0 and epochs_without_improvement >= patience:
            if is_main(rank):
                print(f"Early stopping at epoch {epoch + 1}: no improvement for {patience} evals.")
            break
    if distributed:
        dist.destroy_process_group()
|
|
|
|
|
def _parse_bool_flag(text: str) -> bool:
    """Return True when *text* is "1", "true" or "yes" (case-insensitive).

    Every other spelling, including typos, maps to False.
    """
    return str(text).lower() in {"1", "true", "yes"}


def _normalize_choice(
    args: argparse.Namespace,
    name: str,
    default: str,
    allowed: Iterable[str],
    message: str,
) -> None:
    """Lower-case ``args.<name>`` in place and check it against *allowed*.

    A missing or falsy attribute falls back to *default* before the check.

    Raises:
        ValueError: with *message* when the normalized value is not allowed.
    """
    value = str(getattr(args, name, default) or default).lower()
    setattr(args, name, value)
    if value not in allowed:
        raise ValueError(message)


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Build the training CLI from ``DEFAULT_CONFIG`` and parse it.

    One ``--<key>`` option is registered per ``DEFAULT_CONFIG`` entry, typed
    after its default value; boolean defaults use a text-to-bool converter.
    ``--bottleneck_queries`` is registered separately. After parsing, the
    model references are passed through ``normalize_model_reference`` and the
    enumerated options are lower-cased and validated.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the behaviour of the original
            zero-argument call.

    Returns:
        The populated, normalized ``argparse.Namespace``.

    Raises:
        ValueError: if an enumerated option holds an unsupported value.
    """
    parser = argparse.ArgumentParser(description="Train rich CMGUI screenshot summarization models.")
    for key, value in DEFAULT_CONFIG.items():
        if isinstance(value, bool):
            # bool("false") is True, so booleans need an explicit text converter.
            parser.add_argument(f"--{key}", type=_parse_bool_flag, default=value)
        else:
            # NOTE(review): a default of None would yield type NoneType, which
            # argparse cannot call on a string -- confirm DEFAULT_CONFIG has none.
            parser.add_argument(f"--{key}", type=type(value), default=value)
    parser.add_argument("--bottleneck_queries", type=int, default=64)

    args = parser.parse_args(argv)
    args.decoder_model = normalize_model_reference(args.decoder_model)
    args.vision_model = normalize_model_reference(args.vision_model)

    # model_variant is matched as given (no lower-casing), unlike the options below.
    if args.model_variant not in {"annotation_only", "image_only", "late_fusion", "full"}:
        raise ValueError("model_variant must be one of annotation_only, image_only, late_fusion, full")

    # (option name, fallback, accepted values, error message); the order fixes
    # which error is reported first when several options are invalid.
    choice_specs = [
        (
            "context_mode",
            "mean",
            {"mean", "tokens", "tokens_direct", "tokens_encoder", "tokens_direct_encoder"},
            "context_mode must be one of: mean, tokens, tokens_direct, tokens_encoder, tokens_direct_encoder",
        ),
        (
            "lr_scheduler_type",
            "linear",
            {"linear", "cosine"},
            "lr_scheduler_type must be one of: linear, cosine",
        ),
        (
            "target_schema",
            "zh",
            {
                "zh",
                "alias",
                "aliases",
                "en",
                "english",
                "summary",
                "summary_zh",
                "summary-only",
                "summary_only",
                "natural_zh",
                "rich_text_zh",
                "zh_text",
                "text_zh",
                "summary_visible_zh",
                "natural_summary_visible_zh",
            },
            "target_schema must be one of: zh, aliases, summary_zh, natural_zh, summary_visible_zh",
        ),
        (
            "grad_clip_strategy",
            "global",
            {"global", "per_group"},
            "grad_clip_strategy must be one of: global, per_group",
        ),
        (
            "structured_function_mode",
            "decoder",
            {"decoder", "heads"},
            "structured_function_mode must be one of: decoder, heads",
        ),
        (
            "structured_evidence_mode",
            "decoder",
            {"decoder", "heads"},
            "structured_evidence_mode must be one of: decoder, heads",
        ),
    ]
    for name, default, allowed, message in choice_specs:
        _normalize_choice(args, name, default, allowed, message)
    return args
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: resolve the command-line options first, then hand
    # the resulting namespace to the training routine.
    cli_args = parse_args()
    train(cli_args)
|
|
|