|
|
| """Inference entrypoint for rich Chinese CMGUI screenshot summarization.
|
|
|
| The script supports two modes:
|
|
|
| 1. Model inference from a trained rich checkpoint.
|
| 2. OCR-template baseline without a checkpoint.
|
|
|
| The template mode is intentionally simple. It exists for debugging and as a
|
| report baseline, not as the final system.
|
| """
|
|
|
| from __future__ import annotations
|
|
|
| import argparse
|
| import json
|
| import re
|
| import sys
|
| from pathlib import Path
|
| from typing import Any, Dict, Iterable, List, Optional, Tuple
|
|
|
| import numpy as np
|
| import torch
|
| from torch.utils.data import DataLoader
|
| from tqdm import tqdm
|
|
|
# Absolute directory containing this script.
SCRIPT_DIR = Path(__file__).resolve().parent

# Put this script's directory at the front of sys.path so that `train_rich`
# (presumably a sibling module next to this file) can be imported below even when
# the process was started from another working directory. This must run before
# the `from train_rich import ...` statement.
if str(SCRIPT_DIR) not in sys.path:
    sys.path.insert(0, str(SCRIPT_DIR))
|
|
|
| from train_rich import ( |
| RichCollator,
|
| RichScreenshotDataset,
|
| apply_structured_evidence_predictions,
|
| apply_structured_function_predictions,
|
| extract_summary,
|
| load_rich_checkpoint,
|
| move_batch,
|
| natural_prediction_from_text,
|
| prediction_from_summary,
|
| repair_prediction_with_context,
|
| safe_json_loads,
|
| target_schema_is_natural_text,
|
| target_schema_is_summary,
|
| ) |
|
|
|
|
# Default input / output locations used by the CLI (relative paths).
DEFAULT_TEST_FILE = "data/rich_cmgui/processed/test_rich_teacher500_natural_qwen8000.jsonl"
DEFAULT_OUTPUT_JSONL = "outputs/rich_infer_current/summary.jsonl"
DEFAULT_OUTPUT_MD = "outputs/rich_infer_current/summary.md"

# Default thresholds and item caps for the structured (head-based) predictions;
# used as CLI defaults and forwarded onto the checkpoint args at inference time.
DEFAULT_FUNCTION_THRESHOLD = DEFAULT_SEARCH_THRESHOLD = 0.20
DEFAULT_EVIDENCE_THRESHOLD = 0.50
DEFAULT_MAX_STRUCTURED_ITEMS = 8
|
|
|
|
def str_to_bool(value: str) -> bool:
    """Interpret a command-line flag value as a boolean.

    "1", "true" and "yes" (case-insensitive) are true; every other value,
    including typos, is false.
    """
    normalized = str(value).lower()
    return normalized in ("1", "true", "yes")
|
|
|
|
| def find_latest_rich_checkpoint(runs_dir: str | Path = "runs") -> Optional[Path]: |
| """Find the newest usable stage3/stage4 rich checkpoint for demo/inference.""" |
| root = Path(runs_dir) |
| if not root.exists(): |
| return None |
| candidates: List[Tuple[float, int, Path]] = [] |
| for checkpoint in root.rglob("checkpoint-best"): |
| config_path = checkpoint / "rich_config.json" |
| if not config_path.exists(): |
| continue |
| normalized = checkpoint.as_posix().lower() |
| if "stage3_" in normalized: |
| stage_rank = 3 |
| elif "stage4_" in normalized: |
| stage_rank = 4 |
| else: |
| continue |
| try: |
| mtime = max(checkpoint.stat().st_mtime, config_path.stat().st_mtime) |
| except OSError: |
| continue |
| candidates.append((mtime, stage_rank, checkpoint)) |
| if not candidates: |
| return None |
| candidates.sort(key=lambda item: (item[0], item[1]), reverse=True) |
| return candidates[0][2] |
|
|
|
|
# Accepted key aliases for each prediction field; the first non-empty alias wins.
SUMMARY_KEYS = ["画面总结", "summary_zh", "summary"]
VISIBLE_KEYS = ["可见文字", "visible_text"]
INTERACTION_KEYS = ["互动数据", "interaction_data"]
FUNCTION_KEYS = ["功能入口", "ui_functions"]
EVIDENCE_KEYS = ["关键证据", "key_ui_clues", "evidence"]

# Keywords that mark a UI text as a function entry (search, share, comment, like,
# favorite, follow, back, home, messages, "same style", cart, buy, order, expand,
# close, save). List order is the match priority used by collect_functions.
FUNCTION_KEYWORDS = [
    "搜索",
    "分享",
    "评论",
    "点赞",
    "收藏",
    "关注",
    "返回",
    "首页",
    "消息",
    "拍同款",
    "购物车",
    "购买",
    "下单",
    "展开",
    "关闭",
    "保存",
]

# Shared tail of every interaction pattern: an optional ASCII or full-width colon
# (Chinese UIs usually render "："), then a count such as "12", "1.2万", "3.5w",
# "8k" or "2亿". The magnitude suffix is part of group(1) so "1.2w" is not
# truncated to "1.2".
_COUNT_VALUE = r"\s*[:：]?\s*([0-9]+(?:\.[0-9]+)?[万亿wWkK]?)"

# (label, compiled pattern) pairs; group(1) of each pattern is the count string.
INTERACTION_PATTERNS = [
    ("点赞数", re.compile(r"(?:点赞|赞|like)" + _COUNT_VALUE)),
    ("评论数", re.compile(r"(?:评论|comment)" + _COUNT_VALUE)),
    ("收藏数", re.compile(r"(?:收藏|star|favorite)" + _COUNT_VALUE)),
    ("分享数", re.compile(r"(?:分享|share)" + _COUNT_VALUE)),
    ("未读消息", re.compile(r"(?:消息|未读)" + _COUNT_VALUE)),
]

# Short OCR fragments treated as noise when collecting visible text.
COMMON_NOISE = {"Q", *"0123456789"}
|
|
|
|
|
def read_jsonl(path: Path) -> Iterable[Dict[str, Any]]:
    """Yield one parsed JSON value per non-blank line of *path*.

    The file is decoded as UTF-8; a leading byte-order mark is accepted and
    ignored. Blank (whitespace-only) lines are skipped.

    Raises:
        ValueError: a line is not valid JSON; the message names the file and the
            1-based line number, and the JSONDecodeError is chained as the cause.
    """
    # "utf-8-sig" drops a leading BOM (common in files saved by Windows tools);
    # with plain "utf-8" the BOM stays in line 1 and json.loads rejects it.
    with path.open("r", encoding="utf-8-sig") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as exc:
                raise ValueError(f"Bad JSON at {path}:{line_no}: {exc}") from exc
            # Yield outside the try so errors raised by the consumer are not
            # mistaken for parse errors.
            yield record
|
|
|
|
|
def write_jsonl(path: Path, rows: Iterable[Dict[str, Any]]) -> int:
    """Write *rows* to *path* as UTF-8 JSON Lines and return how many were written.

    Parent directories are created as needed, the file is overwritten, non-ASCII
    characters are written verbatim (``ensure_ascii=False``) and lines end in LF.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    written = 0
    with path.open("w", encoding="utf-8", newline="\n") as handle:
        for written, record in enumerate(rows, start=1):
            handle.write(json.dumps(record, ensure_ascii=False) + "\n")
    return written
|
|
|
|
|
def safe_text(value: Any) -> str:
    """Return *value* as a single-line string.

    None becomes ""; anything else is converted with ``str``, every whitespace run
    is collapsed to one space, and the ends are stripped.
    """
    return "" if value is None else re.sub(r"\s+", " ", str(value)).strip()
|
|
|
|
|
def first_value(obj: Dict[str, Any], keys: List[str], default: Any = None) -> Any:
    """Return ``obj[k]`` for the first key *k* in *keys* whose value is neither None nor "".

    Other falsy values such as 0 or [] count as present. Returns *default* when no
    key qualifies.
    """
    return next(
        (obj[key] for key in keys if key in obj and obj[key] not in (None, "")),
        default,
    )
|
|
|
|
|
def normalize_prediction(obj: Optional[Dict[str, Any]]) -> Optional[Dict[str, Any]]:
    """Map a raw prediction dict onto the five canonical Chinese field names.

    Each field is read from the first non-empty alias listed in SUMMARY_KEYS,
    VISIBLE_KEYS, INTERACTION_KEYS, FUNCTION_KEYS and EVIDENCE_KEYS. The summary is
    whitespace-normalized with safe_text. The other four fields are always lists:
    missing or falsy values become [], tuples become lists, and any other
    non-list value (e.g. a bare string or dict) becomes a one-element list.

    Returns:
        The normalized dict, or None when *obj* is not a dict.
    """
    if not isinstance(obj, dict):
        return None

    def as_list(value: Any) -> List[Any]:
        # Model output is not schema-checked. A bare string here would otherwise
        # be iterated character by character downstream (row_result iterates
        # "关键证据"), and a dict would be iterated by its keys.
        if not value:
            return []
        if isinstance(value, list):
            return value
        if isinstance(value, tuple):
            return list(value)
        return [value]

    return {
        "画面总结": safe_text(first_value(obj, SUMMARY_KEYS, "")),
        "可见文字": as_list(first_value(obj, VISIBLE_KEYS, [])),
        "互动数据": as_list(first_value(obj, INTERACTION_KEYS, [])),
        "功能入口": as_list(first_value(obj, FUNCTION_KEYS, [])),
        "关键证据": as_list(first_value(obj, EVIDENCE_KEYS, [])),
    }
|
|
|
|
|
def item_id(item: Dict[str, Any]) -> str:
    """Return the item's ``id`` (or, when that is falsy, its ``ocr_id``) as clean text.

    Yields "" when neither key holds a usable value.
    """
    raw = item.get("id")
    if not raw:
        raw = item.get("ocr_id")
    return safe_text(raw)
|
|
|
|
|
def build_item_index(row: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Map element ids (as strings) to their UI / OCR item dicts for one row.

    Every ``ui_items`` entry is registered under both its ``id`` and ``ocr_id``
    (when truthy); a later UI item overwrites an earlier one with the same id.
    ``ocr_items`` are then added by ``id`` only where that id is still unused, so
    UI items take precedence.
    """
    index: Dict[str, Dict[str, Any]] = {}
    for ui_item in row.get("ui_items") or []:
        for key in ("id", "ocr_id"):
            if ui_item.get(key):
                index[str(ui_item[key])] = ui_item
    for ocr_item in row.get("ocr_items") or []:
        ocr_key = ocr_item.get("id")
        if ocr_key:
            index.setdefault(str(ocr_key), ocr_item)
    return index
|
|
|
|
|
def is_noise_text(text: str) -> bool:
    """Return True for OCR text that is not worth showing as visible text.

    Noise is:
    - empty text;
    - any COMMON_NOISE entry;
    - a short (<= 6 chars) run of digits, colons, slashes, dots, dashes and
      whitespace, such as a clock or counter;
    - any other single character that is not a CJK ideograph or an ASCII letter.
    """
    if not text or text in COMMON_NOISE:
        return True
    if len(text) <= 6 and re.fullmatch(r"[0-9::/.\-\s]+", text):
        return True
    return len(text) == 1 and re.search(r"[\u4e00-\u9fffA-Za-z]", text) is None
|
|
|
|
|
def collect_visible_text(row: Dict[str, Any], max_items: int) -> List[str]:
    """Pick up to *max_items* distinct, non-noise OCR strings from ``row["ocr_items"]``.

    Items whose OCR confidence is below 0.55 are dropped. Confidence comes from
    ``conf``, falling back to ``ocr_conf``; a missing, null or unparseable value
    counts as 1.0.

    Remaining texts are ranked by ``confidence + min(len, 30) / 100``, plus 0.15 when
    they contain a FUNCTION_KEYWORDS entry. Texts are truncated to 80 characters and
    de-duplicated, keeping the highest-ranked occurrence.

    Returns [] when ``max_items <= 0``.
    """
    if max_items <= 0:
        return []
    candidates: List[Tuple[float, str]] = []
    for item in row.get("ocr_items", []) or []:
        text = safe_text(item.get("text"))
        if is_noise_text(text):
            continue
        # A null "conf" used to crash float(); "... or 1.0" also silently promoted
        # a genuine 0.0 confidence to 1.0. Only absent/unparseable values default.
        raw_conf = item.get("conf")
        if raw_conf is None:
            raw_conf = item.get("ocr_conf")
        try:
            conf = 1.0 if raw_conf is None else float(raw_conf)
        except (TypeError, ValueError):
            conf = 1.0
        if not conf >= 0.55:  # written this way so NaN is rejected too
            continue
        score = conf + min(len(text), 30) / 100.0
        if any(k in text for k in FUNCTION_KEYWORDS):
            score += 0.15
        candidates.append((score, text[:80]))

    seen = set()
    visible: List[str] = []
    # sorted() is stable, so equal scores keep their OCR order.
    for _, text in sorted(candidates, key=lambda x: x[0], reverse=True):
        if text in seen:
            continue
        visible.append(text)
        seen.add(text)
        if len(visible) >= max_items:
            break
    return visible
|
|
|
|
|
def collect_functions(row: Dict[str, Any], max_items: int) -> List[Dict[str, Any]]:
    """Detect function entries (share, comment, like, ...) in ``row["ui_items"]`` texts.

    Each UI item contributes at most one new keyword: the first FUNCTION_KEYWORDS
    entry that its text contains and that has not been reported yet. Entries look
    like ``{"name": keyword, "evidence_ids": [item id]}``; the id list is empty
    when the item has no id.

    At most *max_items* entries are returned; ``max_items <= 0`` yields [].
    """
    if max_items <= 0:
        return []
    funcs: List[Dict[str, Any]] = []
    seen = set()
    for item in row.get("ui_items", []) or []:
        text = safe_text(item.get("text"))
        if not text:
            continue
        keyword = next((k for k in FUNCTION_KEYWORDS if k in text and k not in seen), None)
        if keyword is None:
            continue
        evidence_id = item_id(item)
        funcs.append({"name": keyword, "evidence_ids": [evidence_id] if evidence_id else []})
        seen.add(keyword)
        if len(funcs) >= max_items:
            break
    return funcs
|
|
|
|
|
def collect_interactions(row: Dict[str, Any], max_items: int) -> List[Dict[str, Any]]:
    """Extract interaction counts (likes, comments, ...) from ``row["ui_items"]`` texts.

    For each UI item, INTERACTION_PATTERNS are tried in order. The first match
    whose ``(label, value)`` pair has not been reported yet becomes an entry:
    ``{"name": label, "value": count string, "evidence_ids": [item id]}``.

    At most *max_items* entries are returned; ``max_items <= 0`` yields [].
    """
    if max_items <= 0:
        return []
    interactions: List[Dict[str, Any]] = []
    seen = set()
    for item in row.get("ui_items", []) or []:
        text = safe_text(item.get("text"))
        if not text:
            continue
        for name, pattern in INTERACTION_PATTERNS:
            match = pattern.search(text)
            if match is None:
                continue
            key = (name, match.group(1))
            if key in seen:
                # Already reported; let a later pattern try this item.
                continue
            evidence_id = item_id(item)
            interactions.append(
                {
                    "name": name,
                    "value": match.group(1),
                    "evidence_ids": [evidence_id] if evidence_id else [],
                }
            )
            seen.add(key)
            break
        if len(interactions) >= max_items:
            break
    return interactions
|
|
|
|
|
def template_prediction(row: Dict[str, Any], max_visible: int = 12) -> Dict[str, Any]:
    """Build a rule-based prediction for *row* from its OCR / UI items (no model).

    The summary sentence is filled from:
    - the app name (default "移动应用");
    - the first three visible OCR strings;
    - up to four detected function keywords;
    - the row's ``instruction``, when present.

    Key evidence ids are collected without duplicates or empty values, in this
    order, and capped at 8:
    1. ``weak_evidence_ids``;
    2. the evidence ids of detected functions;
    3. ids of ``ui_items`` in order, until the cap is reached.
    """
    max_evidence = 8
    app = safe_text(row.get("app")) or "移动应用"
    visible = collect_visible_text(row, max_visible)
    funcs = collect_functions(row, max_items=8)
    interactions = collect_interactions(row, max_items=8)
    content_hint = "、".join(visible[:3]) if visible else "当前页面内容"
    func_hint = "、".join(f["name"] for f in funcs[:4]) if funcs else "查看和操作"
    instruction = safe_text(row.get("instruction"))
    if instruction:
        summary = f"这是一个{app}页面,屏幕主要展示{content_hint}等内容,可通过{func_hint}等入口继续操作;当前任务语境是{instruction}。"
    else:
        summary = f"这是一个{app}页面,屏幕主要展示{content_hint}等内容,可通过{func_hint}等入口继续操作。"

    evidence: List[Any] = []

    def add_evidence(value: Any) -> None:
        # Keep first-seen order; skip empty ids and duplicates.
        if value and value not in evidence:
            evidence.append(value)

    for value in row.get("weak_evidence_ids") or []:
        add_evidence(value)
    for func in funcs:
        for value in func.get("evidence_ids") or []:
            add_evidence(value)
    # Scan every ui_item (not a fixed-size prefix) so duplicate ids cannot leave
    # the list short of the cap; "or []" also tolerates a null ui_items field.
    for item in row.get("ui_items") or []:
        if len(evidence) >= max_evidence:
            break
        add_evidence(item_id(item))

    return {
        "画面总结": summary,
        "可见文字": visible,
        "互动数据": interactions,
        "功能入口": funcs,
        "关键证据": evidence[:max_evidence],
    }
|
|
|
|
|
def clue_details(row: Dict[str, Any], clue_ids: List[str], score_map: Optional[Dict[str, float]] = None) -> List[Dict[str, Any]]:
    """Resolve evidence ids to descriptions of the UI / OCR items they refer to.

    Each id is normalized with safe_text and looked up in ``build_item_index(row)``.
    Ids that do not resolve (or resolve to an empty item) are skipped; duplicates
    are kept. ``score`` is the id's value from *score_map*, rounded to 4 decimals
    (0.0 when the id is absent from the map), or None when *score_map* is None or
    empty.
    """
    lookup = build_item_index(row)
    resolved: List[Dict[str, Any]] = []
    for key in (safe_text(raw_id) for raw_id in clue_ids):
        match = lookup.get(key)
        if not match:
            continue
        score = round(float(score_map.get(key, 0.0)), 4) if score_map else None
        resolved.append(
            dict(
                element_id=key,
                type=match.get("type"),
                text=match.get("text", ""),
                bbox=match.get("bbox"),
                location=match.get("location"),
                source=match.get("source"),
                score=score,
            )
        )
    return resolved
|
|
|
|
|
def row_result(
    row: Dict[str, Any],
    raw_text: str,
    pred_obj: Optional[Dict[str, Any]],
    json_valid: bool,
    evidence_scores: Optional[Dict[str, float]],
    allow_template_fallback: bool,
    source: str,
) -> Dict[str, Any]:
    """Assemble the output record for one screenshot.

    The prediction is normalized first. If it is missing or has an empty summary,
    it is replaced by the template prediction (when *allow_template_fallback*) or
    by an all-empty prediction otherwise.

    ``json_valid`` is reported False whenever the template fallback was used, and
    ``source`` becomes "template" in that case. ``key_ui_clues`` resolves the
    prediction's key evidence ids against the row's items.
    """
    structured = normalize_prediction(pred_obj)
    missing_summary = structured is None or not structured.get("画面总结")
    used_fallback = bool(missing_summary and allow_template_fallback)
    if used_fallback:
        structured = template_prediction(row)
    elif missing_summary:
        structured = {
            "画面总结": "",
            "可见文字": [],
            "互动数据": [],
            "功能入口": [],
            "关键证据": [],
        }

    clue_ids: List[str] = []
    for raw_id in structured.get("关键证据", []):
        cleaned = safe_text(raw_id)
        if cleaned:
            clue_ids.append(cleaned)

    return {
        "screen_id": row.get("screen_id"),
        "image_path": row.get("image_path"),
        "app": row.get("app"),
        "summary": structured.get("画面总结", ""),
        "prediction": structured,
        "prediction_raw": raw_text,
        "json_valid": bool(json_valid and not used_fallback),
        "used_template_fallback": used_fallback,
        "source": "template" if used_fallback else source,
        "key_ui_clues": clue_details(row, clue_ids, evidence_scores),
    }
|
|
|
|
|
def run_template(args: argparse.Namespace) -> List[Dict[str, Any]]:
    """Produce OCR-template predictions for every row of ``args.input_file``.

    Honors ``args.max_samples`` (0 keeps all rows) and ``args.max_visible_text``.
    Each template prediction is also serialized as its own raw JSON text and marked
    as valid JSON with source "template".
    """
    samples = list(read_jsonl(Path(args.input_file)))
    limit = args.max_samples
    if limit:
        samples = samples[:limit]

    def to_result(sample: Dict[str, Any]) -> Dict[str, Any]:
        template = template_prediction(sample, max_visible=args.max_visible_text)
        return row_result(
            row=sample,
            raw_text=json.dumps(template, ensure_ascii=False),
            pred_obj=template,
            json_valid=True,
            evidence_scores=None,
            allow_template_fallback=False,
            source="template",
        )

    return [to_result(sample) for sample in tqdm(samples, desc="template infer")]
|
|
|
|
|
@torch.no_grad()
def run_model(args: argparse.Namespace) -> List[Dict[str, Any]]:
    """Run checkpoint inference over ``args.input_file``; return one result dict per row.

    Steps:
    1. Load the rich checkpoint.
    2. Let selected CLI options override settings stored in the checkpoint's
       args (``ckpt_args``).
    3. Generate text for every batch.
    4. Score each UI element with the model's evidence / UI-function / search heads.
    5. Turn the generated text into a prediction dict according to the
       checkpoint's ``target_schema``.
    6. Optionally apply context repair.
    7. Overlay the head-based function and evidence predictions.
    8. Wrap everything with ``row_result``.

    The ``torch.no_grad`` decorator disables gradient tracking for the whole call.
    """
    # An explicit --device wins; otherwise use CUDA when available, else CPU.
    device = torch.device(args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu"))
    model, tokenizer, image_processor, ckpt_args = load_rich_checkpoint(args.checkpoint, device)
    ckpt_args.num_workers = args.num_workers

    # Generation settings: any non-None CLI value replaces the checkpoint's value.
    # NOTE(review): parse_args gives every one of these a non-None default, so as
    # currently wired the checkpoint's own generation settings are always overridden.
    if args.generation_no_repeat_ngram_size is not None:
        ckpt_args.generation_no_repeat_ngram_size = args.generation_no_repeat_ngram_size
    if args.generation_repetition_penalty is not None:
        ckpt_args.generation_repetition_penalty = args.generation_repetition_penalty
    if args.generation_block_extra_ids is not None:
        ckpt_args.generation_block_extra_ids = args.generation_block_extra_ids
    if args.generation_block_title_prefix is not None:
        ckpt_args.generation_block_title_prefix = args.generation_block_title_prefix
    if args.generation_force_json_start is not None:
        ckpt_args.generation_force_json_start = args.generation_force_json_start

    # Post-processing options copied onto ckpt_args when given on the CLI;
    # None or "" means "keep whatever the checkpoint stored".
    runtime_names = [
        "canonicalize_targets",
        "drop_bare_search_functions",
        "structured_function_mode",
        "structured_function_threshold",
        "structured_search_threshold",
        "structured_max_functions",
        "structured_strict_search_candidates",
        "structured_evidence_mode",
        "structured_evidence_threshold",
        "structured_max_evidence",
        "structured_evidence_fallback_top1",
    ]
    for name in runtime_names:
        value = getattr(args, name, None)
        if value is not None and value != "":
            setattr(ckpt_args, name, value)

    dataset = RichScreenshotDataset(args.input_file, max_samples=args.max_samples)
    collator = RichCollator(tokenizer, image_processor, ckpt_args)
    # shuffle=False keeps results in input-file order.
    loader = DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        collate_fn=collator,
        num_workers=args.num_workers,
        pin_memory=device.type == "cuda",
    )

    results: List[Dict[str, Any]] = []
    model.eval()
    for batch in tqdm(loader, desc="model infer"):
        # Keep the original row dicts; they are needed for post-processing below.
        rows = batch["rows"]
        batch = move_batch(batch, device)
        texts = model.generate_text(batch, tokenizer, num_beams=args.num_beams, max_new_tokens=args.max_new_tokens)

        # Per-element token features and padding mask feeding the auxiliary heads.
        _, _, elem_tokens, elem_key_padding = model.build_memory(batch)
        # Sigmoid probabilities from the three per-element heads, moved to CPU.
        evidence_scores_tensor = torch.sigmoid(model.evidence_head(elem_tokens).squeeze(-1)).detach().cpu()
        function_scores_tensor = torch.sigmoid(model.ui_function_head(elem_tokens).squeeze(-1)).detach().cpu()
        search_scores_tensor = torch.sigmoid(model.search_function_head(elem_tokens).squeeze(-1)).detach().cpu()
        scores = evidence_scores_tensor.numpy()
        # Invert the key-padding mask so True marks a real element.
        # NOTE(review): assumes elem_key_padding is boolean with True = padding; confirm.
        masks = (~elem_key_padding).detach().cpu().numpy()

        for row_idx, (row, text) in enumerate(zip(rows, texts)):
            # Decode the generated text according to the checkpoint's target schema.
            if target_schema_is_summary(getattr(ckpt_args, "target_schema", "zh")):
                pred_obj = prediction_from_summary(row, text)
                ok = True
            elif target_schema_is_natural_text(getattr(ckpt_args, "target_schema", "zh")):
                pred_obj = natural_prediction_from_text(text)
                ok = bool(extract_summary(pred_obj))
            else:
                pred_obj, ok = safe_json_loads(text)

            # Evidence probability per ui_items id, for non-padded positions only.
            # NOTE(review): assumes element positions line up with row["ui_items"]
            # order (both are indexed with elem_idx); confirm against RichCollator.
            evidence_scores: Dict[str, float] = {}
            elements = row.get("ui_items", []) or []
            for elem_idx, elem in enumerate(elements[: scores.shape[1]]):
                if masks[row_idx, elem_idx]:
                    elem_id = item_id(elem)
                    if elem_id:
                        evidence_scores[elem_id] = float(scores[row_idx, elem_idx])

            # Nothing could be decoded: keep only the top_k_clues highest-scoring
            # real elements (that have an id) as key evidence.
            if pred_obj is None:
                top_ids = [
                    item_id(elements[i])
                    for i in np.argsort(scores[row_idx])[::-1]
                    if i < len(elements) and masks[row_idx, i] and item_id(elements[i])
                ][: args.top_k_clues]
                pred_obj = {"关键证据": top_ids}

            repair_applied = False
            if args.context_summary_repair:
                pred_obj, repair_applied = repair_prediction_with_context(row, pred_obj)
                # After context repair the prediction is reported as valid.
                ok = True

            # Overlay head-based function / evidence predictions (presumably driven
            # by the structured_* settings copied onto ckpt_args above).
            pred_obj = apply_structured_function_predictions(
                row,
                pred_obj,
                function_scores_tensor[row_idx],
                search_scores_tensor[row_idx],
                ckpt_args,
            )
            pred_obj = apply_structured_evidence_predictions(row, pred_obj, evidence_scores_tensor[row_idx], ckpt_args)
            results.append(
                row_result(
                    row=row,
                    raw_text=text,
                    pred_obj=pred_obj,
                    json_valid=ok,
                    evidence_scores=evidence_scores,
                    allow_template_fallback=args.allow_template_fallback,
                    source="model+context_repair" if repair_applied else "model",
                )
            )
    return results
|
|
|
|
|
def markdown_for_result(row: Dict[str, Any]) -> str:
    """Render one inference result (as built by row_result) as a Markdown section.

    Sections:
    - summary: always present;
    - visible text: first 16 entries;
    - interaction data;
    - function entries, with their evidence ids;
    - key UI clues: first 8, with scores when available.

    Each optional section is omitted when empty. The returned text ends with
    exactly one newline.
    """

    def as_list(value: Any) -> List[Any]:
        # Tolerate malformed predictions: a bare string/dict is shown as one entry
        # instead of being iterated character by character or key by key.
        if not value:
            return []
        if isinstance(value, (list, tuple)):
            return list(value)
        return [value]

    pred = row.get("prediction") or {}
    lines = [
        f"## {row.get('screen_id')}",
        "",
        "### 画面总结",
        safe_text(pred.get("画面总结")),
        "",
    ]

    visible = as_list(pred.get("可见文字"))
    if visible:
        lines.append("### 可见文字 / 文案")
        lines.extend(f"- {safe_text(x)}" for x in visible[:16])
        lines.append("")

    interactions = as_list(pred.get("互动数据"))
    if interactions:
        lines.append("### 互动数据")
        for item in interactions:
            if isinstance(item, dict):
                value = safe_text(item.get("value"))
                suffix = f": {value}" if value else ""
                lines.append(f"- {safe_text(item.get('name'))}{suffix}")
            elif safe_text(item):
                # Plain-string entries are shown as-is rather than dropped.
                lines.append(f"- {safe_text(item)}")
        lines.append("")

    funcs = as_list(pred.get("功能入口"))
    if funcs:
        lines.append("### 功能入口")
        for item in funcs:
            if isinstance(item, dict):
                # Ids may be non-strings in model output; str.join would raise.
                ids = [safe_text(x) for x in as_list(item.get("evidence_ids"))]
                evidence = ", ".join(x for x in ids if x)
                suffix = f" ({evidence})" if evidence else ""
                lines.append(f"- {safe_text(item.get('name'))}{suffix}")
            elif safe_text(item):
                lines.append(f"- {safe_text(item)}")
        lines.append("")

    clues = row.get("key_ui_clues") or []
    if clues:
        lines.append("### 关键 UI 证据")
        for clue in clues[:8]:
            text = safe_text(clue.get("text"))
            score = clue.get("score")
            score_text = f", score={score:.3f}" if isinstance(score, (float, int)) else ""
            lines.append(f"- {clue.get('element_id')}: {text}{score_text}")
        lines.append("")

    return "\n".join(lines).rstrip() + "\n"
|
|
|
|
|
def write_markdown(path: Path, rows: List[Dict[str, Any]]) -> None:
    """Write all results to a single UTF-8 Markdown report at *path* (LF line endings).

    Parent directories are created as needed and an existing file is overwritten.
    The report starts with a "# Rich Screenshot Summaries" heading followed by one
    markdown_for_result section per row.
    """
    path.parent.mkdir(parents=True, exist_ok=True)
    content = ["# Rich Screenshot Summaries", ""]
    content.extend(markdown_for_result(row) for row in rows)
    # Path.write_text() only accepts ``newline`` on Python 3.10+; opening the file
    # explicitly keeps LF endings on older interpreters too (same as write_jsonl).
    with path.open("w", encoding="utf-8", newline="\n") as handle:
        handle.write("\n".join(content))
|
|
|
|
|
def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the command-line options of this script.

    Boolean-valued options go through str_to_bool: "1"/"true"/"yes"
    (case-insensitive) mean true, anything else false.

    Args:
        argv: Argument list to parse. None (the default) reads ``sys.argv[1:]`` as
            before; an explicit list lets tests and other scripts call this.

    Returns:
        The parsed namespace. Unless ``--template_only`` is set, an empty
        ``checkpoint`` is replaced by the newest stage3/stage4 checkpoint found
        under ``runs/``.

    Raises:
        ValueError: no checkpoint was given and none could be discovered.
    """
    parser = argparse.ArgumentParser(
        description="Run rich screenshot summarization inference.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--input_file", default=DEFAULT_TEST_FILE)
    parser.add_argument("--checkpoint", default="")
    parser.add_argument("--template_only", action="store_true")
    parser.add_argument("--output_jsonl", default=DEFAULT_OUTPUT_JSONL)
    parser.add_argument("--output_md", default=DEFAULT_OUTPUT_MD)
    parser.add_argument("--max_samples", type=int, default=0)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--num_workers", type=int, default=0)
    parser.add_argument("--device", default="")
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--max_new_tokens", type=int, default=384)
    parser.add_argument("--generation_no_repeat_ngram_size", type=int, default=3)
    parser.add_argument("--generation_repetition_penalty", type=float, default=1.1)
    parser.add_argument("--generation_block_extra_ids", type=str_to_bool, default=True)
    parser.add_argument("--generation_block_title_prefix", type=str_to_bool, default=True)
    parser.add_argument("--generation_force_json_start", type=str_to_bool, default=False)
    # Same parsing rule as the original inline lambda, now shared via str_to_bool.
    parser.add_argument("--context_summary_repair", type=str_to_bool, default=False)
    parser.add_argument("--canonicalize_targets", type=str_to_bool, default=None)
    parser.add_argument("--drop_bare_search_functions", type=str_to_bool, default=None)
    parser.add_argument("--structured_function_mode", choices=["", "decoder", "heads"], default="heads")
    parser.add_argument("--structured_function_threshold", type=float, default=DEFAULT_FUNCTION_THRESHOLD)
    parser.add_argument("--structured_search_threshold", type=float, default=DEFAULT_SEARCH_THRESHOLD)
    parser.add_argument("--structured_max_functions", type=int, default=DEFAULT_MAX_STRUCTURED_ITEMS)
    parser.add_argument("--structured_strict_search_candidates", type=str_to_bool, default=None)
    parser.add_argument("--structured_evidence_mode", choices=["", "decoder", "heads"], default="heads")
    parser.add_argument("--structured_evidence_threshold", type=float, default=DEFAULT_EVIDENCE_THRESHOLD)
    parser.add_argument("--structured_max_evidence", type=int, default=DEFAULT_MAX_STRUCTURED_ITEMS)
    parser.add_argument("--structured_evidence_fallback_top1", type=str_to_bool, default=False)
    parser.add_argument("--top_k_clues", type=int, default=5)
    parser.add_argument("--max_visible_text", type=int, default=12)
    parser.add_argument("--allow_template_fallback", type=str_to_bool, default=False)
    args = parser.parse_args(argv)
    if not args.template_only and not args.checkpoint:
        checkpoint = find_latest_rich_checkpoint()
        if checkpoint is None:
            raise ValueError("No stage3/stage4 rich checkpoint found. Provide --checkpoint or use --template_only.")
        args.checkpoint = str(checkpoint)
    return args
|
|
|
|
def main() -> None:
    """CLI entry point: run inference, write the outputs and print a run summary.

    Uses run_template with ``--template_only`` and run_model otherwise. Results
    always go to ``--output_jsonl``; the Markdown report is written only when
    ``--output_md`` is non-empty. The summary is printed as indented JSON.
    """
    args = parse_args()
    infer = run_template if args.template_only else run_model
    results = infer(args)

    count = write_jsonl(Path(args.output_jsonl), results)
    if args.output_md:
        write_markdown(Path(args.output_md), results)

    run_summary = {
        "input_file": args.input_file,
        "checkpoint": args.checkpoint or None,
        "template_only": args.template_only,
        "output_jsonl": args.output_jsonl,
        "output_md": args.output_md,
        "count": count,
    }
    print(json.dumps(run_summary, ensure_ascii=False, indent=2))


if __name__ == "__main__":
    main()
|
|
|