#!/usr/bin/env python """Local web GUI for rich CMGUI screenshot summarization.""" from __future__ import annotations import argparse import json import mimetypes import re import sys import threading import time import webbrowser from http import HTTPStatus from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from io import BytesIO from pathlib import Path from typing import Any, Dict, List, Optional, Tuple from urllib.parse import unquote, urlparse import torch from PIL import Image SCRIPT_DIR = Path(__file__).resolve().parent if str(SCRIPT_DIR) not in sys.path: sys.path.insert(0, str(SCRIPT_DIR)) from enrich_rich_ocr_evidence import build_ocr_ui_items, filter_ocr_items # noqa: E402 from infer_rich import find_latest_rich_checkpoint, row_result, template_prediction # noqa: E402 from prepare_rich_data import load_ocr_items, safe_text, sha256_file # noqa: E402 from train_rich import ( # noqa: E402 RichCollator, apply_structured_evidence_predictions, apply_structured_function_predictions, load_rich_checkpoint, move_batch, natural_prediction_from_text, prediction_from_summary, repair_prediction_with_context, safe_json_loads, target_schema_is_natural_text, target_schema_is_summary, ) DEFAULT_CHECKPOINT = "" DEFAULT_FUNCTION_THRESHOLD = 0.20 DEFAULT_SEARCH_THRESHOLD = 0.20 DEFAULT_EVIDENCE_THRESHOLD = 0.50 DEFAULT_MAX_STRUCTURED_ITEMS = 8 ALLOWED_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".webp", ".bmp"} GENERIC_APP_NAMES = {"", "移动应用", "手机应用", "应用", "App", "APP", "app"} SUMMARY_SKIP_TEXTS = { "ADB Keyboard {ON}", "反馈", "共享中", "分享中", } GUI_FUNCTION_KEYWORDS = ["确定", "取消", "删除", "分享", "评论", "收藏", "点赞", "关注", "返回", "首页", "消息", "购物车", "购买", "下单", "关闭"] PASSIVE_SEARCH_TERMS = ["历史搜索", "搜索历史", "热门搜索", "热搜", "猜你想搜", "搜索推荐", "搜索记录"] HTML_PAGE = r""" Rich Screenshot Summarizer

Rich Screenshot Summarizer

就绪
未选择图片
结果会显示在这里
""" def str_to_bool(value: Any) -> bool: return str(value).lower() in {"1", "true", "yes", "y"} def json_bytes(obj: Dict[str, Any], status: int = 200) -> Tuple[int, bytes, str]: return status, json.dumps(obj, ensure_ascii=False).encode("utf-8"), "application/json; charset=utf-8" def safe_upload_name(filename: str) -> str: name = Path(filename or "upload").name name = re.sub(r"[^A-Za-z0-9._-]+", "_", name).strip("._") return name or "upload" def display_app_name(app: str) -> str: text = safe_text(app) return "" if text in GENERIC_APP_NAMES else text def text_bbox(item: Dict[str, Any]) -> Tuple[float, float]: bbox = item.get("bbox") or [] if isinstance(bbox, list) and len(bbox) >= 4: try: return float(bbox[1]), float(bbox[0]) except (TypeError, ValueError): return 0.0, 0.0 return 0.0, 0.0 def skip_summary_text(text: str) -> bool: if not text: return True if text in SUMMARY_SKIP_TEXTS: return True if re.fullmatch(r"[0-9::/ ._\-]+", text): return True if re.fullmatch(r"\d{1,2}:\d{2}", text): return True if re.fullmatch(r"[A-Za-z0-9%+\- ]{1,5}", text): return True if len(text) == 1 and text not in {"搜"}: return True if len(text) == 1 and not re.search(r"[\u4e00-\u9fffA-Za-z]", text): return True return False def center_dialog_texts(row: Dict[str, Any]) -> List[str]: items: List[Tuple[float, float, str]] = [] for item in row.get("ocr_items", []) or []: text = safe_text(item.get("text"))[:80] if not text or skip_summary_text(text): continue bbox = item.get("bbox") or [] if not (isinstance(bbox, list) and len(bbox) == 4): continue try: x1, y1, x2, y2 = [float(value) for value in bbox] except (TypeError, ValueError): continue cx = (x1 + x2) / 2.0 cy = (y1 + y2) / 2.0 if 0.18 <= cx <= 0.82 and 0.32 <= cy <= 0.68: items.append((y1, x1, text)) texts = [] seen = set() for _, _, text in sorted(items, key=lambda value: (value[0], value[1])): if text not in seen: texts.append(text) seen.add(text) joined = " ".join(texts) has_dialog_action = "取消" in joined and "确定" in joined has_dialog_prompt = any(term in joined for term in ["确认", "是否", "删除", "提示", "商品吗"]) return texts if has_dialog_action and has_dialog_prompt else [] def ordered_ocr_texts(row: Dict[str, Any], max_items: int = 14) -> List[str]: dialog_texts = center_dialog_texts(row) candidates: List[Tuple[float, float, str]] = [] for item in row.get("ocr_items", []) or []: text = safe_text(item.get("text"))[:80] if skip_summary_text(text): continue conf = float(item.get("conf", item.get("ocr_conf", 1.0)) or 1.0) if conf < 0.5: continue y_pos, x_pos = text_bbox(item) candidates.append((y_pos, x_pos, text)) seen = set() texts: List[str] = [] for text in dialog_texts: if text not in seen: texts.append(text) seen.add(text) for _, _, text in sorted(candidates, key=lambda value: (value[0], value[1])): if text in seen: continue seen.add(text) texts.append(text) if len(texts) >= max_items: break return texts def gui_item_index(row: Dict[str, Any]) -> Dict[str, Dict[str, Any]]: index: Dict[str, Dict[str, Any]] = {} for item in list(row.get("ui_items", []) or []) + list(row.get("ocr_items", []) or []): for key in ["id", "ocr_id"]: item_id = safe_text(item.get(key)) if item_id and item_id not in index: index[item_id] = item return index def is_passive_search_text(text: str) -> bool: return any(term in text for term in PASSIVE_SEARCH_TERMS) def clean_function_list(row: Dict[str, Any], funcs: List[Any], max_items: int) -> List[Dict[str, Any]]: index = gui_item_index(row) cleaned: List[Dict[str, Any]] = [] seen = set() for item in funcs or []: if not isinstance(item, dict): name = safe_text(item) evidence_ids: List[str] = [] else: name = safe_text(item.get("name")) evidence_ids = [safe_text(value) for value in item.get("evidence_ids", []) if safe_text(value)] if not name or name in seen: continue evidence_texts = [safe_text(index.get(value, {}).get("text")) for value in evidence_ids] if any(skip_summary_text(text) for text in evidence_texts if text): continue if "搜索" in name and evidence_texts and all(is_passive_search_text(text) for text in evidence_texts): continue cleaned.append({"name": name, "evidence_ids": evidence_ids}) seen.add(name) if len(cleaned) >= max_items: break return cleaned def collect_gui_functions(row: Dict[str, Any], max_items: int) -> List[Dict[str, Any]]: funcs: List[Dict[str, Any]] = [] seen = set() for item in row.get("ui_items", []) or []: text = safe_text(item.get("text")) evidence_id = safe_text(item.get("id") or item.get("ocr_id")) if not text or not evidence_id or skip_summary_text(text): continue name = "" if (text in {"搜索", "搜"} or text.startswith("搜索")) and not is_passive_search_text(text): name = "搜索" else: for keyword in GUI_FUNCTION_KEYWORDS: if keyword in text: name = keyword break if not name or name in seen: continue funcs.append({"name": name, "evidence_ids": [evidence_id]}) seen.add(name) if len(funcs) >= max_items: break return funcs def clean_evidence_ids(row: Dict[str, Any], evidence_ids: List[Any], max_items: int) -> List[str]: index = gui_item_index(row) cleaned: List[str] = [] seen = set() for value in evidence_ids or []: evidence_id = safe_text(value) if not evidence_id or evidence_id in seen: continue text = safe_text(index.get(evidence_id, {}).get("text")) if text and skip_summary_text(text): continue cleaned.append(evidence_id) seen.add(evidence_id) if len(cleaned) >= max_items: break return cleaned def infer_page_type(texts: List[str]) -> str: joined = " ".join(texts) if "历史搜索" in joined or "搜索" in joined: if "热榜" in joined or "热搜" in joined or "猜你喜欢" in joined: return "搜索和推荐内容页面" return "搜索相关页面" if any(term in joined for term in ["购物车", "下单", "购买", "商品", "价格", "优惠"]): return "购物页面" if any(term in joined for term in ["热榜", "榜单", "猜你喜欢", "推荐"]): return "推荐内容页面" if any(term in joined for term in ["评论", "点赞", "收藏", "分享"]): return "内容互动页面" if any(term in joined for term in ["我的", "订单", "设置", "账号", "会员"]): return "个人或设置页面" return "移动应用页面" def quoted_join(values: List[str], limit: int) -> str: return "、".join(f"“{value}”" for value in values[:limit]) def build_grounded_summary(row: Dict[str, Any]) -> str: texts = ordered_ocr_texts(row, max_items=14) dialog_texts = center_dialog_texts(row) app = display_app_name(str(row.get("app") or "")) page_type = infer_page_type(texts) subject = f"{app}的{page_type}" if app else page_type if not texts: return f"这张截图展示的是{subject},但当前没有识别到足够清晰的屏幕文字。" if dialog_texts: prompt = next((text for text in dialog_texts if any(term in text for term in ["确认", "是否", "删除", "商品吗"])), dialog_texts[0]) actions = [text for text in dialog_texts if text in {"取消", "确定", "删除", "关闭"}] action_text = f",提供{quoted_join(actions, 4)}等按钮" if actions else "" return f"这张截图显示{subject}上弹出确认对话框,提示“{prompt}”{action_text};背景是购物车商品列表。" primary = quoted_join(texts, 5) summary = f"这张截图主要是{subject},屏幕上能看到{primary}等文字。" extra = texts[5:10] if extra: summary += f" 下方还出现{quoted_join(extra, 5)}等条目。" return summary def merge_items(primary: List[Any], secondary: List[Any], max_items: int) -> List[Any]: merged: List[Any] = [] seen = set() for item in list(primary or []) + list(secondary or []): if isinstance(item, dict): key = json.dumps(item, ensure_ascii=False, sort_keys=True) else: key = safe_text(item) if not key or key in seen: continue seen.add(key) merged.append(item) if len(merged) >= max_items: break return merged def gui_ground_prediction(row: Dict[str, Any], pred_obj: Optional[Dict[str, Any]], args: argparse.Namespace) -> Dict[str, Any]: pred = pred_obj if isinstance(pred_obj, dict) else {} template = template_prediction(row, max_visible=args.max_visible_text) grounded = { "画面总结": safe_text(pred.get("画面总结") or pred.get("summary_zh") or pred.get("summary")), "可见文字": pred.get("可见文字") or pred.get("visible_text") or [], "互动数据": pred.get("互动数据") or pred.get("interaction_data") or [], "功能入口": pred.get("功能入口") or pred.get("ui_functions") or [], "关键证据": pred.get("关键证据") or pred.get("key_ui_clues") or pred.get("evidence") or [], } model_summary = safe_text(grounded.get("画面总结")) if row.get("ocr_items") and (args.gui_summary_mode == "ocr" or (args.gui_summary_mode == "auto" and not model_summary)): grounded["画面总结"] = build_grounded_summary(row) if row.get("ocr_items"): grounded["可见文字"] = ordered_ocr_texts(row, max_items=args.max_visible_text) max_functions = int(getattr(args, "structured_max_functions", None) or 12) gui_functions = collect_gui_functions(row, max_functions) model_functions = clean_function_list(row, grounded.get("功能入口", []), max_functions) if args.merge_ocr_functions: grounded["功能入口"] = merge_items( model_functions, gui_functions, max_functions, ) elif not grounded.get("功能入口"): grounded["功能入口"] = gui_functions else: grounded["功能入口"] = model_functions if not grounded.get("互动数据"): grounded["互动数据"] = template.get("互动数据", []) function_evidence = [] for function in grounded.get("功能入口", []) or []: if isinstance(function, dict): function_evidence.extend(function.get("evidence_ids", []) or []) evidence_candidates = function_evidence + list(grounded.get("关键证据", []) or []) if not evidence_candidates: evidence_candidates = list(template.get("关键证据", []) or []) grounded["关键证据"] = clean_evidence_ids(row, evidence_candidates, 8) if not grounded.get("画面总结"): grounded["画面总结"] = template.get("画面总结", "") return grounded def parse_multipart(headers: Any, body: bytes) -> Dict[str, Any]: content_type = headers.get("Content-Type", "") match = re.search(r"boundary=(?P\"?)([^\";]+)(?P=q)", content_type) if not match: raise ValueError("Missing multipart boundary.") boundary = ("--" + match.group(2)).encode("utf-8") fields: Dict[str, Any] = {} for part in body.split(boundary): part = part.strip(b"\r\n") if not part or part == b"--": continue header_blob, _, value = part.partition(b"\r\n\r\n") if not header_blob: continue header_text = header_blob.decode("utf-8", errors="replace") disposition = "" for line in header_text.split("\r\n"): if line.lower().startswith("content-disposition:"): disposition = line.split(":", 1)[1] break name_match = re.search(r'name="([^"]+)"', disposition) if not name_match: continue field_name = name_match.group(1) filename_match = re.search(r'filename="([^"]*)"', disposition) if filename_match: fields[field_name] = { "filename": filename_match.group(1), "content": value.rstrip(b"\r\n"), } else: fields[field_name] = value.rstrip(b"\r\n").decode("utf-8", errors="replace") return fields class RichGuiPredictor: def __init__(self, args: argparse.Namespace): self.args = args self.device = torch.device(args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu")) self.model = None self.tokenizer = None self.image_processor = None self.ckpt_args = None self.collator = None self.lock = threading.Lock() def load_model(self) -> None: if self.args.template_only or self.model is not None: return if not self.args.checkpoint: raise FileNotFoundError("Checkpoint not set. Train a natural multimodal checkpoint first or pass --checkpoint.") checkpoint = Path(self.args.checkpoint) if not checkpoint.exists(): raise FileNotFoundError(f"Checkpoint not found: {checkpoint}") self.model, self.tokenizer, self.image_processor, self.ckpt_args = load_rich_checkpoint(str(checkpoint), self.device) self.apply_runtime_args(self.ckpt_args) self.collator = RichCollator(self.tokenizer, self.image_processor, self.ckpt_args) def apply_runtime_args(self, ckpt_args: argparse.Namespace) -> None: optional_names = [ "generation_no_repeat_ngram_size", "generation_repetition_penalty", "generation_block_extra_ids", "generation_block_title_prefix", "generation_force_json_start", "context_summary_repair", "canonicalize_targets", "drop_bare_search_functions", "structured_function_threshold", "structured_search_threshold", "structured_max_functions", "structured_strict_search_candidates", "structured_evidence_threshold", "structured_max_evidence", "structured_evidence_fallback_top1", ] for name in optional_names: value = getattr(self.args, name, None) if value is not None and value != "": setattr(ckpt_args, name, value) for name in ["structured_function_mode", "structured_evidence_mode"]: value = getattr(self.args, name, "") if value: setattr(ckpt_args, name, value) ckpt_args.num_workers = 0 def save_image(self, upload: Dict[str, Any]) -> Path: raw = upload.get("content") or b"" if not raw: raise ValueError("Uploaded image is empty.") filename = safe_upload_name(upload.get("filename", "upload.png")) suffix = Path(filename).suffix.lower() if suffix and suffix not in ALLOWED_IMAGE_EXTS: raise ValueError(f"Unsupported image type: {suffix}") upload_dir = Path(self.args.upload_dir) upload_dir.mkdir(parents=True, exist_ok=True) with Image.open(BytesIO(raw)) as img: image = img.convert("RGB") stamp = time.strftime("%Y%m%d_%H%M%S") upload_path = upload_dir / f"{stamp}_{int(time.time() * 1000) % 1000:03d}_{Path(filename).stem}.png" image.save(upload_path) return upload_path def build_row(self, image_path: Path, app: str, focus: str, ocr_engine: str) -> Tuple[Dict[str, Any], Optional[str]]: model_instruction = safe_text(focus) if self.args.use_focus_in_model else "" row: Dict[str, Any] = { "screen_id": image_path.stem, "image_path": str(image_path), "app": safe_text(app) or "移动应用", "instruction": model_instruction, "display_focus": safe_text(focus), "target": { "summary_zh": "", "visible_text": [], "interaction_data": [], "ui_functions": [], "key_ui_clues": [], }, "ocr_items": [], "ui_items": [], "weak_evidence_ids": [], } ocr_error = None if ocr_engine != "none": try: cache_args = argparse.Namespace(ocr_engine=ocr_engine, ocr_lang=self.args.ocr_lang) image_sha = sha256_file(image_path) ocr_items = load_ocr_items(image_path, image_sha, Path(self.args.ocr_cache_dir), cache_args) ocr_items = filter_ocr_items(ocr_items, self.args.min_ocr_conf)[: self.args.max_ocr_items] row["ocr_items"] = ocr_items row["ui_items"] = build_ocr_ui_items(row, ocr_items, self.args.max_ui_items) row["weak_evidence_ids"] = [item.get("id") for item in row["ui_items"][:8] if item.get("id")] except Exception as exc: # OCR should not prevent image-only inference. ocr_error = str(exc) return row, ocr_error @torch.no_grad() def summarize(self, image_path: Path, app: str, focus: str, ocr_engine: str) -> Dict[str, Any]: start = time.perf_counter() row, ocr_error = self.build_row(image_path, app, focus, ocr_engine) if self.args.template_only: pred = template_prediction(row, max_visible=self.args.max_visible_text) pred = gui_ground_prediction(row, pred, self.args) result = row_result( row=row, raw_text=json.dumps(pred, ensure_ascii=False), pred_obj=pred, json_valid=True, evidence_scores=None, allow_template_fallback=False, source="template", ) else: with self.lock: self.load_model() batch = self.collator([row]) batch = move_batch(batch, self.device) text = self.model.generate_text( batch, self.tokenizer, num_beams=self.args.num_beams, max_new_tokens=self.args.max_new_tokens, )[0] _, _, elem_tokens, elem_key_padding = self.model.build_memory(batch) masks = (~elem_key_padding)[0].detach().cpu().numpy() evidence_head = torch.sigmoid(self.model.evidence_head(elem_tokens).squeeze(-1))[0].detach().cpu() function_head = torch.sigmoid(self.model.ui_function_head(elem_tokens).squeeze(-1))[0].detach().cpu() search_head = torch.sigmoid(self.model.search_function_head(elem_tokens).squeeze(-1))[0].detach().cpu() output_is_summary = target_schema_is_summary(getattr(self.ckpt_args, "target_schema", "zh")) output_is_natural = target_schema_is_natural_text(getattr(self.ckpt_args, "target_schema", "zh")) if output_is_summary: pred_obj = prediction_from_summary(row, text) ok = True elif output_is_natural: pred_obj = natural_prediction_from_text(text) ok = bool(pred_obj.get("画面总结")) else: pred_obj, ok = safe_json_loads(text) if pred_obj is None: elements = row.get("ui_items", []) or [] ranked = torch.argsort(evidence_head, descending=True).tolist() top_ids = [] for idx in ranked: if idx < len(elements) and idx < len(masks) and masks[idx]: evidence_id = safe_text(elements[idx].get("id") or elements[idx].get("ocr_id")) if evidence_id: top_ids.append(evidence_id) if len(top_ids) >= self.args.top_k_clues: break pred_obj = {"关键证据": top_ids} if self.args.context_summary_repair: pred_obj, _ = repair_prediction_with_context(row, pred_obj) ok = True pred_obj = apply_structured_function_predictions(row, pred_obj, function_head, search_head, self.ckpt_args) pred_obj = apply_structured_evidence_predictions(row, pred_obj, evidence_head, self.ckpt_args) pred_obj = gui_ground_prediction(row, pred_obj, self.args) evidence_scores: Dict[str, float] = {} for idx, elem in enumerate(row.get("ui_items", []) or []): if idx < len(masks) and masks[idx]: evidence_id = safe_text(elem.get("id") or elem.get("ocr_id")) if evidence_id: evidence_scores[evidence_id] = float(evidence_head[idx]) result = row_result( row=row, raw_text=text, pred_obj=pred_obj, json_valid=ok, evidence_scores=evidence_scores, allow_template_fallback=self.args.allow_template_fallback, source="model", ) result["ocr_count"] = len(row.get("ocr_items", []) or []) result["ui_item_count"] = len(row.get("ui_items", []) or []) result["ocr_error"] = ocr_error result["image_url"] = f"/uploads/{image_path.name}" result["focus"] = safe_text(focus) if self.args.template_only: result["model_output_mode"] = "template" elif self.ckpt_args is not None and target_schema_is_summary(getattr(self.ckpt_args, "target_schema", "zh")): result["model_output_mode"] = "summary" elif self.ckpt_args is not None and target_schema_is_natural_text(getattr(self.ckpt_args, "target_schema", "zh")): result["model_output_mode"] = "natural_text" else: result["model_output_mode"] = "json" result["display_mode"] = self.args.gui_summary_mode if row.get("ocr_items") else "model" result["elapsed_sec"] = round(time.perf_counter() - start, 3) return result class RichGuiHandler(BaseHTTPRequestHandler): server_version = "RichGui/1.0" def send_payload(self, status: int, body: bytes, content_type: str) -> None: self.send_response(status) self.send_header("Content-Type", content_type) self.send_header("Cache-Control", "no-store, no-cache, must-revalidate, max-age=0") self.send_header("Pragma", "no-cache") self.send_header("Content-Length", str(len(body))) self.end_headers() self.wfile.write(body) def do_GET(self) -> None: # noqa: N802 parsed = urlparse(self.path) if parsed.path in {"/", "/index.html"}: self.send_payload(HTTPStatus.OK, HTML_PAGE.encode("utf-8"), "text/html; charset=utf-8") return if parsed.path == "/api/health": status, body, content_type = json_bytes( { "ok": True, "template_only": self.server.predictor.args.template_only, "checkpoint": None if self.server.predictor.args.template_only else self.server.predictor.args.checkpoint, "structured_function_mode": self.server.predictor.args.structured_function_mode, "structured_function_threshold": self.server.predictor.args.structured_function_threshold, "structured_search_threshold": self.server.predictor.args.structured_search_threshold, "structured_evidence_mode": self.server.predictor.args.structured_evidence_mode, "structured_evidence_threshold": self.server.predictor.args.structured_evidence_threshold, } ) self.send_payload(status, body, content_type) return if parsed.path.startswith("/uploads/"): filename = safe_upload_name(unquote(parsed.path.removeprefix("/uploads/"))) path = Path(self.server.predictor.args.upload_dir) / filename if path.exists() and path.is_file(): content_type = mimetypes.guess_type(path.name)[0] or "application/octet-stream" self.send_payload(HTTPStatus.OK, path.read_bytes(), content_type) else: status, body, content_type = json_bytes({"error": "Not found"}, HTTPStatus.NOT_FOUND) self.send_payload(status, body, content_type) return status, body, content_type = json_bytes({"error": "Not found"}, HTTPStatus.NOT_FOUND) self.send_payload(status, body, content_type) def do_POST(self) -> None: # noqa: N802 parsed = urlparse(self.path) if parsed.path != "/api/summarize": status, body, content_type = json_bytes({"error": "Not found"}, HTTPStatus.NOT_FOUND) self.send_payload(status, body, content_type) return try: length = int(self.headers.get("Content-Length", "0")) if length > self.server.predictor.args.max_upload_mb * 1024 * 1024: raise ValueError(f"Image is larger than {self.server.predictor.args.max_upload_mb} MB.") fields = parse_multipart(self.headers, self.rfile.read(length)) upload = fields.get("image") if not isinstance(upload, dict): raise ValueError("Missing image field.") image_path = self.server.predictor.save_image(upload) app = safe_text(fields.get("app")) or "移动应用" focus = safe_text(fields.get("focus")) or safe_text(fields.get("instruction")) ocr_engine = safe_text(fields.get("ocr_engine")) or self.server.predictor.args.ocr_engine if ocr_engine not in {"none", "paddleocr"}: ocr_engine = self.server.predictor.args.ocr_engine result = self.server.predictor.summarize(image_path, app, focus, ocr_engine) status, body, content_type = json_bytes(result, HTTPStatus.OK) except Exception as exc: status, body, content_type = json_bytes({"error": str(exc)}, HTTPStatus.BAD_REQUEST) self.send_payload(status, body, content_type) def log_message(self, fmt: str, *args: Any) -> None: if not self.server.predictor.args.quiet: super().log_message(fmt, *args) class RichGuiServer(ThreadingHTTPServer): def __init__(self, server_address: Tuple[str, int], handler_class: Any, predictor: RichGuiPredictor): super().__init__(server_address, handler_class) self.predictor = predictor def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser( description="Start a local GUI for rich screenshot summarization.", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) parser.add_argument("--checkpoint", default=DEFAULT_CHECKPOINT) parser.add_argument("--template_only", action="store_true") parser.add_argument("--host", default="127.0.0.1") parser.add_argument("--port", type=int, default=7860) parser.add_argument("--open_browser", type=str_to_bool, default=True) parser.add_argument("--quiet", action="store_true") parser.add_argument("--upload_dir", default="outputs/rich_gui/uploads") parser.add_argument("--ocr_cache_dir", default="data/rich_cmgui/cache/gui_ocr") parser.add_argument("--ocr_engine", choices=["none", "paddleocr"], default="paddleocr") parser.add_argument("--ocr_lang", default="ch") parser.add_argument("--min_ocr_conf", type=float, default=0.5) parser.add_argument("--max_ocr_items", type=int, default=120) parser.add_argument("--max_ui_items", type=int, default=48) parser.add_argument("--max_upload_mb", type=int, default=20) parser.add_argument("--device", default="") parser.add_argument("--num_beams", type=int, default=1) parser.add_argument("--max_new_tokens", type=int, default=256) parser.add_argument("--top_k_clues", type=int, default=5) parser.add_argument("--max_visible_text", type=int, default=12) parser.add_argument("--gui_summary_mode", choices=["model", "ocr", "auto"], default="model") parser.add_argument("--merge_ocr_functions", type=str_to_bool, default=True) parser.add_argument("--use_focus_in_model", type=str_to_bool, default=False) parser.add_argument("--generation_no_repeat_ngram_size", type=int, default=3) parser.add_argument("--generation_repetition_penalty", type=float, default=1.1) parser.add_argument("--generation_block_extra_ids", type=str_to_bool, default=True) parser.add_argument("--generation_block_title_prefix", type=str_to_bool, default=True) parser.add_argument("--generation_force_json_start", type=str_to_bool, default=False) parser.add_argument("--context_summary_repair", type=str_to_bool, default=None) parser.add_argument("--canonicalize_targets", type=str_to_bool, default=None) parser.add_argument("--drop_bare_search_functions", type=str_to_bool, default=None) parser.add_argument("--structured_function_mode", choices=["", "decoder", "heads"], default="heads") parser.add_argument("--structured_function_threshold", type=float, default=DEFAULT_FUNCTION_THRESHOLD) parser.add_argument("--structured_search_threshold", type=float, default=DEFAULT_SEARCH_THRESHOLD) parser.add_argument("--structured_max_functions", type=int, default=DEFAULT_MAX_STRUCTURED_ITEMS) parser.add_argument("--structured_strict_search_candidates", type=str_to_bool, default=None) parser.add_argument("--structured_evidence_mode", choices=["", "decoder", "heads"], default="heads") parser.add_argument("--structured_evidence_threshold", type=float, default=DEFAULT_EVIDENCE_THRESHOLD) parser.add_argument("--structured_max_evidence", type=int, default=DEFAULT_MAX_STRUCTURED_ITEMS) parser.add_argument("--structured_evidence_fallback_top1", type=str_to_bool, default=False) parser.add_argument("--allow_template_fallback", type=str_to_bool, default=False) args = parser.parse_args() if not args.template_only and not args.checkpoint: checkpoint = find_latest_rich_checkpoint() if checkpoint is None: raise FileNotFoundError("No stage3/stage4 rich checkpoint found. Train first or pass --checkpoint.") args.checkpoint = str(checkpoint) return args def main() -> None: args = parse_args() predictor = RichGuiPredictor(args) last_error: Optional[OSError] = None server: Optional[RichGuiServer] = None for port in range(args.port, args.port + 20): try: server = RichGuiServer((args.host, port), RichGuiHandler, predictor) break except OSError as exc: last_error = exc if server is None: raise RuntimeError(f"Could not bind a local port starting at {args.port}: {last_error}") actual_port = server.server_address[1] url = f"http://{args.host}:{actual_port}/" print( json.dumps( { "url": url, "template_only": args.template_only, "checkpoint": None if args.template_only else args.checkpoint, "num_beams": args.num_beams, "max_new_tokens": args.max_new_tokens, "structured_function_mode": args.structured_function_mode, "structured_function_threshold": args.structured_function_threshold, "structured_search_threshold": args.structured_search_threshold, "structured_evidence_mode": args.structured_evidence_mode, "structured_evidence_threshold": args.structured_evidence_threshold, }, ensure_ascii=False, ) ) if args.open_browser: threading.Timer(0.5, lambda: webbrowser.open(url)).start() server.serve_forever() if __name__ == "__main__": main()