#!/usr/bin/env python """Local web GUI for rich CMGUI screenshot summarization.""" from __future__ import annotations import argparse import json import mimetypes import re import sys import threading import time import webbrowser from http import HTTPStatus from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer from io import BytesIO from pathlib import Path from typing import Any, Dict, List, Optional, Tuple from urllib.parse import unquote, urlparse import torch from PIL import Image SCRIPT_DIR = Path(__file__).resolve().parent if str(SCRIPT_DIR) not in sys.path: sys.path.insert(0, str(SCRIPT_DIR)) from enrich_rich_ocr_evidence import build_ocr_ui_items, filter_ocr_items # noqa: E402 from infer_rich import find_latest_rich_checkpoint, row_result, template_prediction # noqa: E402 from prepare_rich_data import load_ocr_items, safe_text, sha256_file # noqa: E402 from train_rich import ( # noqa: E402 RichCollator, apply_structured_evidence_predictions, apply_structured_function_predictions, load_rich_checkpoint, move_batch, natural_prediction_from_text, prediction_from_summary, repair_prediction_with_context, safe_json_loads, target_schema_is_natural_text, target_schema_is_summary, ) DEFAULT_CHECKPOINT = "" DEFAULT_FUNCTION_THRESHOLD = 0.20 DEFAULT_SEARCH_THRESHOLD = 0.20 DEFAULT_EVIDENCE_THRESHOLD = 0.50 DEFAULT_MAX_STRUCTURED_ITEMS = 8 ALLOWED_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".webp", ".bmp"} GENERIC_APP_NAMES = {"", "移动应用", "手机应用", "应用", "App", "APP", "app"} SUMMARY_SKIP_TEXTS = { "ADB Keyboard {ON}", "反馈", "共享中", "分享中", } GUI_FUNCTION_KEYWORDS = ["确定", "取消", "删除", "分享", "评论", "收藏", "点赞", "关注", "返回", "首页", "消息", "购物车", "购买", "下单", "关闭"] PASSIVE_SEARCH_TERMS = ["历史搜索", "搜索历史", "热门搜索", "热搜", "猜你想搜", "搜索推荐", "搜索记录"] HTML_PAGE = r"""
class RichGuiPredictor:
    """Build a summary result dict for one uploaded screenshot.

    A single instance is attached to the HTTP server and used by every request
    thread. The checkpoint is loaded lazily on the first model request, and the
    model branch of ``summarize`` runs under ``self.lock``.
    """

    def __init__(self, args: argparse.Namespace):
        """Store the CLI namespace and pick the torch device; nothing is loaded yet."""
        self.args = args
        self.device = torch.device(args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu"))
        self.model = None
        self.tokenizer = None
        self.image_processor = None
        self.ckpt_args = None
        self.collator = None
        self.lock = threading.Lock()

    def load_model(self) -> None:
        """Load the checkpoint once; no-op in template-only mode or when already loaded.

        Raises:
            FileNotFoundError: if no checkpoint is configured or the path does not exist.
        """
        if self.args.template_only or self.model is not None:
            return
        if not self.args.checkpoint:
            raise FileNotFoundError("Checkpoint not set. Train a natural multimodal checkpoint first or pass --checkpoint.")
        checkpoint = Path(self.args.checkpoint)
        if not checkpoint.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")

        # Build everything in locals first. ``self.model`` doubles as the
        # "already loaded" sentinel above, so it must only be set once the
        # collator exists too; otherwise a failure half-way through would leave
        # the predictor permanently stuck with ``collator=None``.
        model, tokenizer, image_processor, ckpt_args = load_rich_checkpoint(str(checkpoint), self.device)
        self.apply_runtime_args(ckpt_args)
        collator = RichCollator(tokenizer, image_processor, ckpt_args)

        self.tokenizer = tokenizer
        self.image_processor = image_processor
        self.ckpt_args = ckpt_args
        self.collator = collator
        self.model = model

    def apply_runtime_args(self, ckpt_args: argparse.Namespace) -> None:
        """Copy explicitly set CLI options onto the checkpoint's namespace (in place).

        Options whose CLI value is ``None`` or ``""`` are left as stored in
        ``ckpt_args``. ``num_workers`` is always forced to 0.
        """
        optional_names = [
            "generation_no_repeat_ngram_size",
            "generation_repetition_penalty",
            "generation_block_extra_ids",
            "generation_block_title_prefix",
            "generation_force_json_start",
            "context_summary_repair",
            "canonicalize_targets",
            "drop_bare_search_functions",
            "structured_function_threshold",
            "structured_search_threshold",
            "structured_max_functions",
            "structured_strict_search_candidates",
            "structured_evidence_threshold",
            "structured_max_evidence",
            "structured_evidence_fallback_top1",
        ]
        for name in optional_names:
            value = getattr(self.args, name, None)
            # ``False`` and ``0`` are real overrides; only None / "" mean "unset".
            if value is not None and value != "":
                setattr(ckpt_args, name, value)
        for name in ["structured_function_mode", "structured_evidence_mode"]:
            value = getattr(self.args, name, "")
            if value:
                setattr(ckpt_args, name, value)
        ckpt_args.num_workers = 0

    def save_image(self, upload: Dict[str, Any]) -> Path:
        """Decode an uploaded image and store it as an RGB PNG in ``args.upload_dir``.

        Args:
            upload: dict with ``"content"`` (raw bytes) and optionally ``"filename"``.

        Returns:
            Path of the PNG that was written.

        Raises:
            ValueError: empty upload, disallowed extension, or undecodable image data.
        """
        raw = upload.get("content") or b""
        if not raw:
            raise ValueError("Uploaded image is empty.")
        filename = safe_upload_name(upload.get("filename", "upload.png"))
        suffix = Path(filename).suffix.lower()
        # A name without any extension is accepted; the bytes are validated by decoding below.
        if suffix and suffix not in ALLOWED_IMAGE_EXTS:
            raise ValueError(f"Unsupported image type: {suffix}")
        upload_dir = Path(self.args.upload_dir)
        upload_dir.mkdir(parents=True, exist_ok=True)
        try:
            with Image.open(BytesIO(raw)) as img:
                image = img.convert("RGB")
        except OSError as exc:
            # Pillow reports unreadable/truncated data as OSError (including
            # UnidentifiedImageError), whose message embeds a BytesIO repr.
            # Surface a clean client-facing error instead.
            raise ValueError("Uploaded file is not a readable image.") from exc
        stamp = time.strftime("%Y%m%d_%H%M%S")
        upload_path = upload_dir / f"{stamp}_{int(time.time() * 1000) % 1000:03d}_{Path(filename).stem}.png"
        image.save(upload_path)
        return upload_path

    def build_row(self, image_path: Path, app: str, focus: str, ocr_engine: str) -> Tuple[Dict[str, Any], Optional[str]]:
        """Assemble the input row for one image, optionally enriched with OCR items.

        Returns:
            ``(row, ocr_error)`` where ``ocr_error`` is the stringified exception
            if OCR failed, else ``None``. OCR failure never aborts the request.
        """
        # The focus text is only fed to the model when --use_focus_in_model is on;
        # it is always kept under "display_focus".
        model_instruction = safe_text(focus) if self.args.use_focus_in_model else ""
        row: Dict[str, Any] = {
            "screen_id": image_path.stem,
            "image_path": str(image_path),
            "app": safe_text(app) or "移动应用",
            "instruction": model_instruction,
            "display_focus": safe_text(focus),
            "target": {
                "summary_zh": "",
                "visible_text": [],
                "interaction_data": [],
                "ui_functions": [],
                "key_ui_clues": [],
            },
            "ocr_items": [],
            "ui_items": [],
            "weak_evidence_ids": [],
        }
        ocr_error = None
        if ocr_engine != "none":
            try:
                cache_args = argparse.Namespace(ocr_engine=ocr_engine, ocr_lang=self.args.ocr_lang)
                image_sha = sha256_file(image_path)
                ocr_items = load_ocr_items(image_path, image_sha, Path(self.args.ocr_cache_dir), cache_args)
                ocr_items = filter_ocr_items(ocr_items, self.args.min_ocr_conf)[: self.args.max_ocr_items]
                row["ocr_items"] = ocr_items
                row["ui_items"] = build_ocr_ui_items(row, ocr_items, self.args.max_ui_items)
                row["weak_evidence_ids"] = [item.get("id") for item in row["ui_items"][:8] if item.get("id")]
            except Exception as exc:
                # OCR should not prevent image-only inference.
                ocr_error = str(exc)
        return row, ocr_error

    @torch.no_grad()
    def summarize(self, image_path: Path, app: str, focus: str, ocr_engine: str) -> Dict[str, Any]:
        """Run template or model inference for one saved image and return the result dict.

        The dict produced by ``row_result`` is extended with OCR counts, the OCR
        error (if any), the image URL, the focus text, output/display modes and
        the elapsed wall-clock seconds.
        """
        start = time.perf_counter()
        row, ocr_error = self.build_row(image_path, app, focus, ocr_engine)
        if self.args.template_only:
            result = self._summarize_with_template(row)
        else:
            result = self._summarize_with_model(row)

        result["ocr_count"] = len(row.get("ocr_items", []) or [])
        result["ui_item_count"] = len(row.get("ui_items", []) or [])
        result["ocr_error"] = ocr_error
        result["image_url"] = f"/uploads/{image_path.name}"
        result["focus"] = safe_text(focus)
        result["model_output_mode"] = self._model_output_mode()
        result["display_mode"] = self.args.gui_summary_mode if row.get("ocr_items") else "model"
        result["elapsed_sec"] = round(time.perf_counter() - start, 3)
        return result

    def _summarize_with_template(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Produce a result without the neural model (``--template_only``)."""
        pred = template_prediction(row, max_visible=self.args.max_visible_text)
        pred = gui_ground_prediction(row, pred, self.args)
        return row_result(
            row=row,
            raw_text=json.dumps(pred, ensure_ascii=False),
            pred_obj=pred,
            json_valid=True,
            evidence_scores=None,
            allow_template_fallback=False,
            source="template",
        )

    def _summarize_with_model(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Generate text with the checkpoint, parse it, and apply the structured heads."""
        # NOTE(review): the source formatting was collapsed, so the exact extent
        # of the locked region is reconstructed; the whole model branch is kept
        # under the lock, which is the conservative reading.
        with self.lock:
            self.load_model()
            batch = self.collator([row])
            batch = move_batch(batch, self.device)
            text = self.model.generate_text(
                batch,
                self.tokenizer,
                num_beams=self.args.num_beams,
                max_new_tokens=self.args.max_new_tokens,
            )[0]
            _, _, elem_tokens, elem_key_padding = self.model.build_memory(batch)
            # True where an element slot is real (not padding), for the single row in the batch.
            masks = (~elem_key_padding)[0].detach().cpu().numpy()
            evidence_head = torch.sigmoid(self.model.evidence_head(elem_tokens).squeeze(-1))[0].detach().cpu()
            function_head = torch.sigmoid(self.model.ui_function_head(elem_tokens).squeeze(-1))[0].detach().cpu()
            search_head = torch.sigmoid(self.model.search_function_head(elem_tokens).squeeze(-1))[0].detach().cpu()

            schema = getattr(self.ckpt_args, "target_schema", "zh")
            output_is_summary = target_schema_is_summary(schema)
            output_is_natural = target_schema_is_natural_text(schema)
            if output_is_summary:
                pred_obj = prediction_from_summary(row, text)
                ok = True
            elif output_is_natural:
                pred_obj = natural_prediction_from_text(text)
                ok = bool(pred_obj.get("画面总结"))
            else:
                pred_obj, ok = safe_json_loads(text)

            elements = row.get("ui_items", []) or []
            if pred_obj is None:
                # Unparseable output: fall back to the highest-scoring evidence ids.
                pred_obj = {"关键证据": self._top_evidence_ids(elements, masks, evidence_head)}
            # NOTE(review): nesting reconstructed from collapsed source — repair is
            # assumed to apply to every prediction and to mark it valid; confirm.
            # NOTE(review): this reads the CLI flag directly (its CLI default is
            # None), so a value stored in the checkpoint is not consulted here.
            if self.args.context_summary_repair:
                pred_obj, _ = repair_prediction_with_context(row, pred_obj)
                ok = True

            pred_obj = apply_structured_function_predictions(row, pred_obj, function_head, search_head, self.ckpt_args)
            pred_obj = apply_structured_evidence_predictions(row, pred_obj, evidence_head, self.ckpt_args)
            pred_obj = gui_ground_prediction(row, pred_obj, self.args)

            evidence_scores: Dict[str, float] = {}
            for idx, elem in enumerate(elements):
                if idx < len(masks) and masks[idx]:
                    evidence_id = safe_text(elem.get("id") or elem.get("ocr_id"))
                    if evidence_id:
                        evidence_scores[evidence_id] = float(evidence_head[idx])

            return row_result(
                row=row,
                raw_text=text,
                pred_obj=pred_obj,
                json_valid=ok,
                evidence_scores=evidence_scores,
                allow_template_fallback=self.args.allow_template_fallback,
                source="model",
            )

    def _top_evidence_ids(self, elements: List[Dict[str, Any]], masks: Any, evidence_head: Any) -> List[str]:
        """Return ids of unmasked elements in descending evidence score, up to ``top_k_clues``."""
        top_ids: List[str] = []
        for idx in torch.argsort(evidence_head, descending=True).tolist():
            if idx < len(elements) and idx < len(masks) and masks[idx]:
                evidence_id = safe_text(elements[idx].get("id") or elements[idx].get("ocr_id"))
                if evidence_id:
                    top_ids.append(evidence_id)
            if len(top_ids) >= self.args.top_k_clues:
                break
        return top_ids

    def _model_output_mode(self) -> str:
        """Label how the output text was interpreted: template / summary / natural_text / json."""
        if self.args.template_only:
            return "template"
        if self.ckpt_args is not None:
            schema = getattr(self.ckpt_args, "target_schema", "zh")
            if target_schema_is_summary(schema):
                return "summary"
            if target_schema_is_natural_text(schema):
                return "natural_text"
        return "json"
class RichGuiHandler(BaseHTTPRequestHandler):
    """Request handler for the local GUI.

    Routes:
        GET  /, /index.html   the HTML page
        GET  /api/health      JSON with the active inference settings
        GET  /uploads/<name>  a previously saved upload
        POST /api/summarize   multipart upload -> JSON result
    Everything else answers with a JSON 404.
    """

    server_version = "RichGui/1.0"

    def send_payload(self, status: int, body: bytes, content_type: str) -> None:
        """Write a complete response with caching disabled."""
        self.send_response(status)
        self.send_header("Content-Type", content_type)
        self.send_header("Cache-Control", "no-store, no-cache, must-revalidate, max-age=0")
        self.send_header("Pragma", "no-cache")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _send_json(self, payload: Dict[str, Any], status: Optional[int] = None) -> None:
        """Serialize ``payload`` with ``json_bytes`` and send it."""
        # When no status is given, ``json_bytes`` is called with the payload
        # only, exactly as the health endpoint always did.
        if status is None:
            code, body, content_type = json_bytes(payload)
        else:
            code, body, content_type = json_bytes(payload, status)
        self.send_payload(code, body, content_type)

    def _declared_body_length(self) -> int:
        """Return the request's Content-Length after validating it.

        Raises:
            ValueError: header is not an integer, is negative, or exceeds
                ``--max_upload_mb``.
        """
        length = int(self.headers.get("Content-Length", "0"))
        if length < 0:
            # rfile.read(-1) would block until the client closes the socket.
            raise ValueError("Invalid Content-Length header.")
        limit_mb = self.server.predictor.args.max_upload_mb
        if length > limit_mb * 1024 * 1024:
            raise ValueError(f"Image is larger than {limit_mb} MB.")
        return length

    def do_GET(self) -> None:  # noqa: N802
        """Serve the page, the health report, or a stored upload."""
        parsed = urlparse(self.path)
        if parsed.path in {"/", "/index.html"}:
            self.send_payload(HTTPStatus.OK, HTML_PAGE.encode("utf-8"), "text/html; charset=utf-8")
            return

        args = self.server.predictor.args
        if parsed.path == "/api/health":
            self._send_json(
                {
                    "ok": True,
                    "template_only": args.template_only,
                    "checkpoint": None if args.template_only else args.checkpoint,
                    "structured_function_mode": args.structured_function_mode,
                    "structured_function_threshold": args.structured_function_threshold,
                    "structured_search_threshold": args.structured_search_threshold,
                    "structured_evidence_mode": args.structured_evidence_mode,
                    "structured_evidence_threshold": args.structured_evidence_threshold,
                }
            )
            return

        if parsed.path.startswith("/uploads/"):
            # The requested name is passed through safe_upload_name before it is
            # joined to the upload directory.
            filename = safe_upload_name(unquote(parsed.path.removeprefix("/uploads/")))
            path = Path(args.upload_dir) / filename
            if path.is_file():
                content_type = mimetypes.guess_type(path.name)[0] or "application/octet-stream"
                self.send_payload(HTTPStatus.OK, path.read_bytes(), content_type)
            else:
                self._send_json({"error": "Not found"}, HTTPStatus.NOT_FOUND)
            return

        self._send_json({"error": "Not found"}, HTTPStatus.NOT_FOUND)

    def do_POST(self) -> None:  # noqa: N802
        """Handle ``/api/summarize``: save the uploaded image and return its summary.

        Any failure while handling the request is reported as a JSON
        ``{"error": ...}`` body with status 400.
        """
        parsed = urlparse(self.path)
        if parsed.path != "/api/summarize":
            self._send_json({"error": "Not found"}, HTTPStatus.NOT_FOUND)
            return

        predictor = self.server.predictor
        try:
            length = self._declared_body_length()
            fields = parse_multipart(self.headers, self.rfile.read(length))
            upload = fields.get("image")
            if not isinstance(upload, dict):
                raise ValueError("Missing image field.")
            image_path = predictor.save_image(upload)
            app = safe_text(fields.get("app")) or "移动应用"
            focus = safe_text(fields.get("focus")) or safe_text(fields.get("instruction"))
            ocr_engine = safe_text(fields.get("ocr_engine")) or predictor.args.ocr_engine
            # Unknown engine names from the form fall back to the server default.
            if ocr_engine not in {"none", "paddleocr"}:
                ocr_engine = predictor.args.ocr_engine
            result = predictor.summarize(image_path, app, focus, ocr_engine)
            status, body, content_type = json_bytes(result, HTTPStatus.OK)
        except Exception as exc:
            status, body, content_type = json_bytes({"error": str(exc)}, HTTPStatus.BAD_REQUEST)
        self.send_payload(status, body, content_type)

    def log_message(self, fmt: str, *args: Any) -> None:
        """Suppress per-request logging when the server was started with ``--quiet``."""
        if not self.server.predictor.args.quiet:
            super().log_message(fmt, *args)
class RichGuiServer(ThreadingHTTPServer):
    """Threading HTTP server that exposes the shared predictor as ``self.predictor``."""

    def __init__(self, server_address: Tuple[str, int], handler_class: Any, predictor: "RichGuiPredictor"):
        """Bind to ``server_address`` and attach ``predictor`` for the handlers to use."""
        super().__init__(server_address, handler_class)
        self.predictor = predictor


def parse_args(argv: Optional[List[str]] = None) -> argparse.Namespace:
    """Parse the GUI's command-line options.

    Args:
        argv: argument list to parse; ``None`` (the default) reads ``sys.argv``
            as before, so existing callers are unaffected.

    Returns:
        The parsed namespace. Unless ``--template_only`` is given, an empty
        ``--checkpoint`` is replaced by the result of
        ``find_latest_rich_checkpoint()``.

    Raises:
        FileNotFoundError: no checkpoint was given and none could be found.
    """
    parser = argparse.ArgumentParser(
        description="Start a local GUI for rich screenshot summarization.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Server / runtime.
    parser.add_argument("--checkpoint", default=DEFAULT_CHECKPOINT)
    parser.add_argument("--template_only", action="store_true")
    parser.add_argument("--host", default="127.0.0.1")
    parser.add_argument("--port", type=int, default=7860)
    parser.add_argument("--open_browser", type=str_to_bool, default=True)
    parser.add_argument("--quiet", action="store_true")
    parser.add_argument("--upload_dir", default="outputs/rich_gui/uploads")

    # OCR enrichment.
    parser.add_argument("--ocr_cache_dir", default="data/rich_cmgui/cache/gui_ocr")
    parser.add_argument("--ocr_engine", choices=["none", "paddleocr"], default="paddleocr")
    parser.add_argument("--ocr_lang", default="ch")
    parser.add_argument("--min_ocr_conf", type=float, default=0.5)
    parser.add_argument("--max_ocr_items", type=int, default=120)
    parser.add_argument("--max_ui_items", type=int, default=48)
    parser.add_argument("--max_upload_mb", type=int, default=20)

    # Inference.
    parser.add_argument("--device", default="")
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--max_new_tokens", type=int, default=256)
    parser.add_argument("--top_k_clues", type=int, default=5)
    parser.add_argument("--max_visible_text", type=int, default=12)
    parser.add_argument("--gui_summary_mode", choices=["model", "ocr", "auto"], default="model")
    parser.add_argument("--merge_ocr_functions", type=str_to_bool, default=True)
    parser.add_argument("--use_focus_in_model", type=str_to_bool, default=False)

    # Generation overrides. A default of None means the option is not copied
    # onto the checkpoint's stored arguments.
    parser.add_argument("--generation_no_repeat_ngram_size", type=int, default=3)
    parser.add_argument("--generation_repetition_penalty", type=float, default=1.1)
    parser.add_argument("--generation_block_extra_ids", type=str_to_bool, default=True)
    parser.add_argument("--generation_block_title_prefix", type=str_to_bool, default=True)
    parser.add_argument("--generation_force_json_start", type=str_to_bool, default=False)
    parser.add_argument("--context_summary_repair", type=str_to_bool, default=None)
    parser.add_argument("--canonicalize_targets", type=str_to_bool, default=None)
    parser.add_argument("--drop_bare_search_functions", type=str_to_bool, default=None)

    # Structured heads.
    parser.add_argument("--structured_function_mode", choices=["", "decoder", "heads"], default="heads")
    parser.add_argument("--structured_function_threshold", type=float, default=DEFAULT_FUNCTION_THRESHOLD)
    parser.add_argument("--structured_search_threshold", type=float, default=DEFAULT_SEARCH_THRESHOLD)
    parser.add_argument("--structured_max_functions", type=int, default=DEFAULT_MAX_STRUCTURED_ITEMS)
    parser.add_argument("--structured_strict_search_candidates", type=str_to_bool, default=None)
    parser.add_argument("--structured_evidence_mode", choices=["", "decoder", "heads"], default="heads")
    parser.add_argument("--structured_evidence_threshold", type=float, default=DEFAULT_EVIDENCE_THRESHOLD)
    parser.add_argument("--structured_max_evidence", type=int, default=DEFAULT_MAX_STRUCTURED_ITEMS)
    parser.add_argument("--structured_evidence_fallback_top1", type=str_to_bool, default=False)
    parser.add_argument("--allow_template_fallback", type=str_to_bool, default=False)

    args = parser.parse_args(argv)
    if not args.template_only and not args.checkpoint:
        checkpoint = find_latest_rich_checkpoint()
        if checkpoint is None:
            raise FileNotFoundError("No stage3/stage4 rich checkpoint found. Train first or pass --checkpoint.")
        args.checkpoint = str(checkpoint)
    return args
def main() -> None:
    """Start the GUI server on the first free port and serve until interrupted.

    Up to 20 consecutive ports starting at ``--port`` are tried. The chosen URL
    and the main inference settings are printed as one JSON line, the browser is
    optionally opened, and the listening socket is closed on exit.

    Raises:
        RuntimeError: none of the candidate ports could be bound.
    """
    args = parse_args()
    predictor = RichGuiPredictor(args)

    last_error: Optional[OSError] = None
    for port in range(args.port, args.port + 20):
        try:
            server = RichGuiServer((args.host, port), RichGuiHandler, predictor)
            break
        except OSError as exc:
            last_error = exc
    else:
        raise RuntimeError(f"Could not bind a local port starting at {args.port}: {last_error}") from last_error

    actual_port = server.server_address[1]
    url = f"http://{args.host}:{actual_port}/"
    print(
        json.dumps(
            {
                "url": url,
                "template_only": args.template_only,
                "checkpoint": None if args.template_only else args.checkpoint,
                "num_beams": args.num_beams,
                "max_new_tokens": args.max_new_tokens,
                "structured_function_mode": args.structured_function_mode,
                "structured_function_threshold": args.structured_function_threshold,
                "structured_search_threshold": args.structured_search_threshold,
                "structured_evidence_mode": args.structured_evidence_mode,
                "structured_evidence_threshold": args.structured_evidence_threshold,
            },
            ensure_ascii=False,
        )
    )

    if args.open_browser:
        # Short delay so the server is already accepting connections; daemon so a
        # pending timer never keeps the process alive after shutdown.
        opener = threading.Timer(0.5, webbrowser.open, args=(url,))
        opener.daemon = True
        opener.start()

    try:
        server.serve_forever()
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop a local server; exit without a traceback.
        pass
    finally:
        server.server_close()


if __name__ == "__main__":
    main()