|
|
| """Local web GUI for rich CMGUI screenshot summarization."""
|
|
|
| from __future__ import annotations
|
|
|
| import argparse
|
| import json
|
| import mimetypes
|
| import re
|
| import sys
|
| import threading
|
| import time
|
| import webbrowser
|
| from http import HTTPStatus
|
| from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
|
| from io import BytesIO
|
| from pathlib import Path
|
| from typing import Any, Dict, List, Optional, Tuple
|
| from urllib.parse import unquote, urlparse
|
|
|
| import torch
|
| from PIL import Image
|
|
|
# Directory holding this script. It is put at the front of sys.path (once) so the
# sibling modules imported right below (enrich_rich_ocr_evidence, infer_rich,
# prepare_rich_data, train_rich) resolve even when launched from another cwd.
SCRIPT_DIR = Path(__file__).resolve().parent
if all(entry != str(SCRIPT_DIR) for entry in sys.path):
    sys.path[:0] = [str(SCRIPT_DIR)]
|
|
|
| from enrich_rich_ocr_evidence import build_ocr_ui_items, filter_ocr_items
|
| from infer_rich import find_latest_rich_checkpoint, row_result, template_prediction |
| from prepare_rich_data import load_ocr_items, safe_text, sha256_file
|
| from train_rich import ( |
| RichCollator, |
| apply_structured_evidence_predictions, |
| apply_structured_function_predictions, |
| load_rich_checkpoint,
|
| move_batch,
|
| natural_prediction_from_text,
|
| prediction_from_summary,
|
| repair_prediction_with_context,
|
| safe_json_loads,
|
| target_schema_is_natural_text,
|
| target_schema_is_summary,
|
| )
|
|
|
|
|
# Default values; presumably wired into the CLI options of matching names
# (the argument parser is not part of this view).
DEFAULT_CHECKPOINT = ""
DEFAULT_FUNCTION_THRESHOLD = 0.20
DEFAULT_SEARCH_THRESHOLD = 0.20
DEFAULT_EVIDENCE_THRESHOLD = 0.50
DEFAULT_MAX_STRUCTURED_ITEMS = 8

# File extensions accepted for uploads (compared after lower-casing).
ALLOWED_IMAGE_EXTS = {".bmp", ".jpeg", ".jpg", ".png", ".webp"}

# App names too generic to mention in a generated summary.
GENERIC_APP_NAMES = {"", "App", "APP", "app", "应用", "手机应用", "移动应用"}

# OCR strings that are always ignored when building summaries and evidence.
SUMMARY_SKIP_TEXTS = {"ADB Keyboard {ON}", "反馈", "共享中", "分享中"}

# Keywords recognised as function entries. Order matters: the first keyword
# contained in a UI text decides the entry's name.
GUI_FUNCTION_KEYWORDS = [
    "确定", "取消", "删除", "分享", "评论", "收藏", "点赞", "关注",
    "返回", "首页", "消息", "购物车", "购买", "下单", "关闭",
]

# Phrases marking a passive search panel (history / suggestions) rather than a
# search action the user can trigger.
PASSIVE_SEARCH_TERMS = [
    "历史搜索", "搜索历史", "热门搜索", "热搜", "猜你想搜", "搜索推荐", "搜索记录",
]
|
|
|
|
|
# Single-page UI served at "/". It posts the form to /api/summarize and renders
# the JSON result. All server/model-provided text is passed through escapeHtml
# before it is inserted via innerHTML.
HTML_PAGE = r"""
<!doctype html>
<html lang="zh-CN">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Rich Screenshot Summarizer</title>
<style>
  :root {
    color-scheme: light;
    --bg: #f6f7f9;
    --panel: #ffffff;
    --line: #d8dde6;
    --text: #1d2433;
    --muted: #5e6a7d;
    --accent: #166a5a;
    --accent-2: #b35c16;
    --danger: #b3261e;
    --shadow: 0 10px 30px rgba(29, 36, 51, 0.08);
  }
  * { box-sizing: border-box; }
  body {
    margin: 0;
    min-height: 100vh;
    background: var(--bg);
    color: var(--text);
    font-family: "Segoe UI", "Microsoft YaHei", Arial, sans-serif;
    letter-spacing: 0;
  }
  header {
    border-bottom: 1px solid var(--line);
    background: rgba(255, 255, 255, 0.92);
    position: sticky;
    top: 0;
    z-index: 10;
    backdrop-filter: blur(10px);
  }
  .bar {
    max-width: 1180px;
    margin: 0 auto;
    padding: 14px 20px;
    display: flex;
    align-items: center;
    justify-content: space-between;
    gap: 16px;
  }
  h1 {
    font-size: 18px;
    line-height: 1.25;
    margin: 0;
    font-weight: 650;
  }
  .status {
    color: var(--muted);
    font-size: 13px;
    white-space: nowrap;
  }
  main {
    max-width: 1180px;
    margin: 0 auto;
    padding: 22px 20px 32px;
    display: grid;
    grid-template-columns: 360px minmax(0, 1fr);
    gap: 18px;
  }
  section {
    background: var(--panel);
    border: 1px solid var(--line);
    border-radius: 8px;
    box-shadow: var(--shadow);
  }
  .controls { padding: 18px; }
  .preview { padding: 14px; }
  label {
    display: block;
    color: var(--muted);
    font-size: 13px;
    margin: 0 0 7px;
  }
  input[type="file"], input[type="text"], textarea, select {
    width: 100%;
    border: 1px solid var(--line);
    border-radius: 6px;
    padding: 10px 11px;
    font: inherit;
    font-size: 14px;
    background: #fff;
    color: var(--text);
    min-height: 40px;
  }
  textarea { resize: vertical; min-height: 76px; }
  .field { margin-bottom: 14px; }
  .row { display: grid; grid-template-columns: 1fr 1fr; gap: 10px; }
  button {
    width: 100%;
    border: 1px solid #10584a;
    border-radius: 6px;
    background: var(--accent);
    color: #fff;
    padding: 11px 14px;
    font: inherit;
    font-size: 14px;
    font-weight: 650;
    cursor: pointer;
    min-height: 42px;
  }
  button:disabled { cursor: wait; opacity: 0.72; }
  .image-box {
    border: 1px solid var(--line);
    border-radius: 8px;
    background: #eef1f5;
    overflow: hidden;
    min-height: 260px;
    display: grid;
    place-items: center;
  }
  .image-box img {
    max-width: 100%;
    max-height: 72vh;
    display: block;
    object-fit: contain;
  }
  .placeholder { color: var(--muted); font-size: 14px; padding: 30px; text-align: center; }
  .result { padding: 18px; display: grid; gap: 16px; }
  .summary {
    font-size: 17px;
    line-height: 1.65;
    padding-bottom: 8px;
    border-bottom: 1px solid var(--line);
  }
  .grid {
    display: grid;
    grid-template-columns: repeat(2, minmax(0, 1fr));
    gap: 14px;
  }
  .block h2 {
    margin: 0 0 8px;
    font-size: 14px;
    line-height: 1.35;
    color: var(--accent);
  }
  ul { margin: 0; padding-left: 18px; }
  li { margin: 4px 0; line-height: 1.45; overflow-wrap: anywhere; }
  .meta {
    display: flex;
    flex-wrap: wrap;
    gap: 8px;
    color: var(--muted);
    font-size: 12px;
  }
  .pill {
    border: 1px solid var(--line);
    border-radius: 999px;
    padding: 3px 8px;
    background: #fafbfc;
  }
  details {
    border-top: 1px solid var(--line);
    padding-top: 12px;
  }
  summary { cursor: pointer; color: var(--accent-2); font-size: 13px; }
  pre {
    white-space: pre-wrap;
    overflow-wrap: anywhere;
    background: #f3f5f7;
    border: 1px solid var(--line);
    border-radius: 6px;
    padding: 12px;
    max-height: 360px;
    overflow: auto;
    font-size: 12px;
  }
  .error {
    color: var(--danger);
    border: 1px solid rgba(179, 38, 30, 0.35);
    background: rgba(179, 38, 30, 0.06);
    border-radius: 6px;
    padding: 11px;
    line-height: 1.45;
  }
  @media (max-width: 860px) {
    main { grid-template-columns: 1fr; }
    .grid { grid-template-columns: 1fr; }
    .bar { align-items: flex-start; flex-direction: column; }
    .status { white-space: normal; }
  }
</style>
</head>
<body>
<header>
  <div class="bar">
    <h1>Rich Screenshot Summarizer</h1>
    <div id="status" class="status">就绪</div>
  </div>
</header>
<main>
  <section class="controls">
    <form id="form">
      <div class="field">
        <label for="image">图片</label>
        <input id="image" name="image" type="file" accept="image/*" required />
      </div>
      <div class="row">
        <div class="field">
          <label for="app">应用名/类型(可选)</label>
          <input id="app" name="app" type="text" placeholder="可留空" />
        </div>
        <div class="field">
          <label for="ocr_engine">OCR</label>
          <select id="ocr_engine" name="ocr_engine">
            <option value="paddleocr" selected>paddleocr</option>
            <option value="none">none</option>
          </select>
        </div>
      </div>
      <div class="field">
        <label for="focus">关注点(可选)</label>
        <textarea id="focus" name="focus" placeholder="可留空,例如:搜索结果、价格信息、按钮入口"></textarea>
      </div>
      <button id="submit" type="submit">生成总结</button>
    </form>
  </section>
  <section class="preview">
    <div class="image-box" id="imageBox"><div class="placeholder">未选择图片</div></div>
  </section>
  <section class="result" style="grid-column: 1 / -1;">
    <div id="output" class="placeholder">结果会显示在这里</div>
  </section>
</main>
<script>
  const form = document.getElementById('form');
  const imageInput = document.getElementById('image');
  const imageBox = document.getElementById('imageBox');
  const output = document.getElementById('output');
  const statusEl = document.getElementById('status');
  const submit = document.getElementById('submit');

  imageInput.addEventListener('change', () => {
    const file = imageInput.files && imageInput.files[0];
    if (!file) {
      imageBox.innerHTML = '<div class="placeholder">未选择图片</div>';
      return;
    }
    const url = URL.createObjectURL(file);
    imageBox.innerHTML = '';
    const img = document.createElement('img');
    img.src = url;
    img.onload = () => URL.revokeObjectURL(url);
    imageBox.appendChild(img);
  });

  // Escape text for element content and quoted attribute values. Numeric
  // character references are built at runtime so the mapping cannot collapse
  // into a no-op if the source text ever passes through an entity decoder.
  function escapeHtml(value) {
    return String(value ?? '').replace(/[&<>'"]/g, ch => '&#' + ch.charCodeAt(0) + ';');
  }

  function listItems(items, mapFn) {
    if (!Array.isArray(items) || items.length === 0) return '<span class="placeholder">无</span>';
    return '<ul>' + items.map(item => '<li>' + mapFn(item) + '</li>').join('') + '</ul>';
  }

  // Read a field from an object item; plain string items are used as-is.
  function field(item, key) {
    if (item && typeof item === 'object') return item[key] ?? '';
    return item ?? '';
  }

  function renderResult(data) {
    const pred = data.prediction || {};
    const outputMode = data.model_output_mode || 'json';
    const parseLabel = outputMode === 'summary' ? '输出 summary' : (outputMode === 'natural_text' ? ('自然文本 ' + (data.json_valid ? 'parsed' : 'fallback')) : (outputMode === 'template' ? '模板输出' : ('JSON ' + (data.json_valid ? 'valid' : 'fallback'))));
    const visible = listItems(pred['可见文字'], x => escapeHtml(x));
    const funcs = listItems(pred['功能入口'], item => {
      const ids = item && Array.isArray(item.evidence_ids) ? item.evidence_ids.join(', ') : '';
      return escapeHtml(field(item, 'name')) + (ids ? ' <span class="pill">' + escapeHtml(ids) + '</span>' : '');
    });
    const interactions = listItems(pred['互动数据'], item => {
      const value = item && item.value ? ':' + item.value : '';
      return escapeHtml(field(item, 'name') + value);
    });
    const clues = listItems(data.key_ui_clues, item => {
      const score = item && typeof item.score === 'number' ? ' score=' + item.score.toFixed(3) : '';
      const text = item && item.text ? ' ' + item.text : '';
      return '<span class="pill">' + escapeHtml(field(item, 'element_id')) + '</span>' + escapeHtml(text + score);
    });
    const ocrError = data.ocr_error ? '<div class="error">OCR 失败: ' + escapeHtml(data.ocr_error) + '</div>' : '';
    output.className = '';
    output.innerHTML = `
      <div class="summary">${escapeHtml(data.summary || pred['画面总结'] || '')}</div>
      ${ocrError}
      <div class="meta">
        <span class="pill">${escapeHtml(data.source || '')}</span>
        <span class="pill">${escapeHtml(outputMode)}</span>
        <span class="pill">${escapeHtml(data.display_mode || 'model')}</span>
        <span class="pill">${escapeHtml(parseLabel)}</span>
        <span class="pill">OCR ${escapeHtml(data.ocr_count ?? 0)}</span>
        <span class="pill">${Number(data.elapsed_sec || 0).toFixed(2)}s</span>
      </div>
      <div class="grid">
        <div class="block"><h2>可见文字</h2>${visible}</div>
        <div class="block"><h2>功能入口</h2>${funcs}</div>
        <div class="block"><h2>互动数据</h2>${interactions}</div>
        <div class="block"><h2>关键证据</h2>${clues}</div>
      </div>
      <details><summary>原始 JSON</summary><pre>${escapeHtml(JSON.stringify(data, null, 2))}</pre></details>
    `;
    if (data.image_url) {
      imageBox.innerHTML = `<img src="${escapeHtml(data.image_url)}" alt="uploaded screenshot" />`;
    }
  }

  form.addEventListener('submit', async event => {
    event.preventDefault();
    const formData = new FormData(form);
    submit.disabled = true;
    statusEl.textContent = '生成中';
    output.className = 'placeholder';
    output.textContent = '正在处理图片';
    try {
      const response = await fetch('/api/summarize', { method: 'POST', body: formData });
      // An error page (e.g. a proxy 502) may not be JSON; keep the HTTP error path usable.
      let data = {};
      try {
        data = await response.json();
      } catch (parseError) {
        if (response.ok) throw parseError;
      }
      if (!response.ok) throw new Error(data.error || '请求失败');
      renderResult(data);
      statusEl.textContent = '完成';
    } catch (error) {
      output.className = 'error';
      output.textContent = error.message || String(error);
      statusEl.textContent = '出错';
    } finally {
      submit.disabled = false;
    }
  });
</script>
</body>
</html>
"""
|
|
|
|
|
def str_to_bool(value: Any) -> bool:
    """Interpret *value* as a boolean flag.

    The value is stringified and lower-cased; only "1", "true", "yes" and "y"
    count as true. Everything else (including surrounding whitespace) is false.
    """
    normalized = str(value).lower()
    return normalized in ("1", "true", "yes", "y")
|
|
|
|
|
def json_bytes(obj: Dict[str, Any], status: int = 200) -> Tuple[int, bytes, str]:
    """Serialise *obj* as a UTF-8 JSON HTTP payload.

    Returns:
        ``(status, body, content_type)``, ready to hand to ``send_payload``.
        Non-ASCII characters are written verbatim rather than ``\\u``-escaped.
    """
    content_type = "application/json; charset=utf-8"
    payload = json.dumps(obj, ensure_ascii=False)
    return status, payload.encode("utf-8"), content_type
|
|
|
|
|
def safe_upload_name(filename: str) -> str:
    """Reduce a client-supplied filename to a safe basename.

    Directory components are dropped, and every run of characters outside
    ``[A-Za-z0-9._-]`` becomes ``_``. Leading/trailing dots and underscores are
    trimmed, so ``..`` cannot survive. Falls back to ``"upload"`` when nothing
    is left.
    """
    base = Path(filename or "upload").name
    sanitized = re.sub(r"[^A-Za-z0-9._-]+", "_", base).strip("._")
    if sanitized:
        return sanitized
    return "upload"
|
|
|
|
|
def display_app_name(app: str) -> str:
    """Return the app name worth mentioning in a summary.

    Returns an empty string when the (``safe_text``-normalised) name is one of
    the placeholder names in ``GENERIC_APP_NAMES``.
    """
    normalized = safe_text(app)
    if normalized in GENERIC_APP_NAMES:
        return ""
    return normalized
|
|
|
|
|
def text_bbox(item: Dict[str, Any]) -> Tuple[float, float]:
    """Return ``(y, x)`` of an OCR item's first corner, for reading-order sorting.

    The bbox is read as ``[x1, y1, x2, y2, ...]`` and must be a list with at
    least four entries. Anything missing, shorter, or non-numeric sorts as
    ``(0.0, 0.0)``.
    """
    bbox = item.get("bbox") or []
    if not isinstance(bbox, list) or len(bbox) < 4:
        return 0.0, 0.0
    try:
        top, left = float(bbox[1]), float(bbox[0])
    except (TypeError, ValueError):
        return 0.0, 0.0
    return top, left
|
|
|
|
|
def skip_summary_text(text: str) -> bool:
    """Return True when an OCR string is noise that should not appear in summaries.

    Skipped:
      * empty strings and the fixed strings in ``SUMMARY_SKIP_TEXTS``;
      * strings made only of ASCII digits and separators, including ASCII and
        full-width colons (times, dates, counters);
      * ``H:MM`` / ``HH:MM`` clock strings (``\\d`` also covers non-ASCII digits);
      * Latin/digit fragments of 1-5 characters (``[A-Za-z0-9%+- ]``);
      * any single character except the search glyph "搜".
    """
    if not text or text in SUMMARY_SKIP_TEXTS:
        return True
    # The character class previously listed the ASCII colon twice. The full-width
    # colon "：" (common in Chinese OCR output) was clearly the intent.
    if re.fullmatch(r"[0-9:：/ ._\-]+", text):
        return True
    if re.fullmatch(r"\d{1,2}:\d{2}", text):
        return True
    if re.fullmatch(r"[A-Za-z0-9%+\- ]{1,5}", text):
        return True
    # A former second single-character check was unreachable: every 1-char text
    # other than "搜" already returns True here, and "搜" is a CJK character.
    return len(text) == 1 and text != "搜"
|
|
|
|
|
def center_dialog_texts(row: Dict[str, Any]) -> List[str]:
    """Return the texts of a centred confirmation dialog, or ``[]`` if none is seen.

    OCR items whose 4-value bbox ``[x1, y1, x2, y2]`` has its centre inside the
    middle region (x in [0.18, 0.82], y in [0.32, 0.68]) are collected. Bbox
    coordinates are presumably normalised to [0, 1]; confirm against the OCR
    loader. The texts are ordered top-to-bottom then left-to-right and
    de-duplicated. They are returned only if they contain both "取消" and "确定"
    plus a prompt-like term.
    """
    centred: List[Tuple[float, float, str]] = []
    for ocr_item in row.get("ocr_items", []) or []:
        label = safe_text(ocr_item.get("text"))[:80]
        if not label or skip_summary_text(label):
            continue
        box = ocr_item.get("bbox") or []
        if not isinstance(box, list) or len(box) != 4:
            continue
        try:
            left, top, right, bottom = (float(coord) for coord in box)
        except (TypeError, ValueError):
            continue
        mid_x = (left + right) / 2.0
        mid_y = (top + bottom) / 2.0
        if 0.18 <= mid_x <= 0.82 and 0.32 <= mid_y <= 0.68:
            centred.append((top, left, label))

    centred.sort(key=lambda entry: (entry[0], entry[1]))
    # dict.fromkeys keeps the first occurrence of each label, in order.
    ordered = list(dict.fromkeys(label for _, _, label in centred))

    joined = " ".join(ordered)
    has_buttons = "取消" in joined and "确定" in joined
    has_prompt = any(term in joined for term in ["确认", "是否", "删除", "提示", "商品吗"])
    if has_buttons and has_prompt:
        return ordered
    return []
|
|
|
|
|
def _ocr_confidence(item: Dict[str, Any]) -> float:
    """Return an OCR item's confidence ("conf", else "ocr_conf"); 1.0 when absent or unparsable."""
    raw = item.get("conf", item.get("ocr_conf"))
    if raw is None or raw == "":
        return 1.0
    try:
        return float(raw)
    except (TypeError, ValueError):
        return 1.0


def ordered_ocr_texts(row: Dict[str, Any], max_items: int = 14) -> List[str]:
    """Return up to *max_items* distinct OCR texts in display order.

    Texts of a detected centred dialog (``center_dialog_texts``) come first.
    They are followed by the remaining OCR texts with confidence >= 0.5, sorted
    top-to-bottom then left-to-right. Noise filtered by ``skip_summary_text``
    is dropped, and each text is truncated to 80 characters.
    """
    ranked: List[Tuple[float, float, str]] = []
    for item in row.get("ocr_items", []) or []:
        text = safe_text(item.get("text"))[:80]
        if skip_summary_text(text):
            continue
        # A real confidence of 0 must count as low confidence. The previous
        # ``float(conf or 1.0)`` silently promoted it to 1.0.
        if _ocr_confidence(item) < 0.5:
            continue
        y_pos, x_pos = text_bbox(item)
        ranked.append((y_pos, x_pos, text))

    reading_order = [text for _, _, text in sorted(ranked, key=lambda value: (value[0], value[1]))]
    # De-duplicate while keeping the dialog texts first, then enforce the cap on the
    # combined list. Previously dialog texts could push the result past max_items.
    distinct = list(dict.fromkeys(center_dialog_texts(row) + reading_order))
    return distinct[: max(max_items, 0)]
|
|
|
|
|
def gui_item_index(row: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Map every ``id`` / ``ocr_id`` value to its UI or OCR item.

    UI items are indexed before OCR items. When the same id appears more than
    once, the first item seen wins.
    """
    by_id: Dict[str, Dict[str, Any]] = {}
    sources = (row.get("ui_items", []) or [], row.get("ocr_items", []) or [])
    for collection in sources:
        for entry in collection:
            for id_key in ("id", "ocr_id"):
                entry_id = safe_text(entry.get(id_key))
                if entry_id:
                    by_id.setdefault(entry_id, entry)
    return by_id
|
|
|
|
|
def is_passive_search_text(text: str) -> bool:
    """Return True if *text* contains any phrase from ``PASSIVE_SEARCH_TERMS``
    (search history / suggestion panels rather than a search action)."""
    for term in PASSIVE_SEARCH_TERMS:
        if term in text:
            return True
    return False
|
|
|
|
|
def clean_function_list(row: Dict[str, Any], funcs: List[Any], max_items: int) -> List[Dict[str, Any]]:
    """Normalise model-predicted function entries to ``{"name", "evidence_ids"}`` dicts.

    String entries become name-only dicts. Entries are dropped when they have an
    empty or duplicate name, when any evidence text is noise
    (``skip_summary_text``), or when they are search functions backed only by
    passive search texts. At most *max_items* entries are returned (the cap is
    checked after each append).
    """
    index = gui_item_index(row)
    cleaned: List[Dict[str, Any]] = []
    seen = set()
    for item in funcs or []:
        if isinstance(item, dict):
            name = safe_text(item.get("name"))
            raw_ids = item.get("evidence_ids")
            if raw_ids is None:
                # An explicit null previously raised TypeError when iterated.
                raw_ids = []
            elif not isinstance(raw_ids, (list, tuple, set)):
                # A bare id such as "e3" must not be iterated character by character.
                raw_ids = [raw_ids]
            evidence_ids = [text for text in (safe_text(value) for value in raw_ids) if text]
        else:
            name = safe_text(item)
            evidence_ids = []
        if not name or name in seen:
            continue
        evidence_texts = [safe_text(index.get(value, {}).get("text")) for value in evidence_ids]
        if any(skip_summary_text(text) for text in evidence_texts if text):
            continue
        if "搜索" in name and evidence_texts and all(is_passive_search_text(text) for text in evidence_texts):
            continue
        cleaned.append({"name": name, "evidence_ids": evidence_ids})
        seen.add(name)
        if len(cleaned) >= max_items:
            break
    return cleaned
|
|
|
|
|
def collect_gui_functions(row: Dict[str, Any], max_items: int) -> List[Dict[str, Any]]:
    """Derive function entries directly from OCR-backed UI items.

    A UI text becomes "搜索" when it is "搜索", "搜" or starts with "搜索" and is
    not a passive search phrase. Otherwise it takes the first keyword from
    ``GUI_FUNCTION_KEYWORDS`` that it contains. Each name is kept once, with the
    UI item's id as evidence. The cap is checked after each append.
    """
    found: List[Dict[str, Any]] = []
    used_names = set()
    for ui_item in row.get("ui_items", []) or []:
        label = safe_text(ui_item.get("text"))
        evidence_id = safe_text(ui_item.get("id") or ui_item.get("ocr_id"))
        if not label or not evidence_id or skip_summary_text(label):
            continue
        looks_like_search = label in {"搜索", "搜"} or label.startswith("搜索")
        if looks_like_search and not is_passive_search_text(label):
            function_name = "搜索"
        else:
            function_name = next((keyword for keyword in GUI_FUNCTION_KEYWORDS if keyword in label), "")
        if not function_name or function_name in used_names:
            continue
        found.append({"name": function_name, "evidence_ids": [evidence_id]})
        used_names.add(function_name)
        if len(found) >= max_items:
            break
    return found
|
|
|
|
|
def clean_evidence_ids(row: Dict[str, Any], evidence_ids: List[Any], max_items: int) -> List[str]:
    """De-duplicate evidence ids and drop ids whose item text is summary noise.

    Ids not found in the row's item index are kept. The result stops once
    *max_items* ids have been collected.
    """
    lookup = gui_item_index(row)
    kept: List[str] = []
    kept_set = set()
    for raw in evidence_ids or []:
        candidate = safe_text(raw)
        if not candidate or candidate in kept_set:
            continue
        label = safe_text(lookup.get(candidate, {}).get("text"))
        is_noise = bool(label) and skip_summary_text(label)
        if is_noise:
            continue
        kept.append(candidate)
        kept_set.add(candidate)
        if len(kept) >= max_items:
            break
    return kept
|
|
|
|
|
def infer_page_type(texts: List[str]) -> str:
    """Guess a coarse Chinese page-type label from keywords in the visible texts.

    Search pages are checked first (with a sub-case for search plus trending or
    recommendation content). Then shopping, recommendation, interaction and
    personal/settings keywords are tried in that order. Falls back to a generic
    "mobile app page" label.
    """
    joined = " ".join(texts)

    def mentions(*terms: str) -> bool:
        return any(term in joined for term in terms)

    # "历史搜索" contains "搜索", so a single substring test covers both.
    if mentions("搜索"):
        if mentions("热榜", "热搜", "猜你喜欢"):
            return "搜索和推荐内容页面"
        return "搜索相关页面"

    rules = (
        (("购物车", "下单", "购买", "商品", "价格", "优惠"), "购物页面"),
        (("热榜", "榜单", "猜你喜欢", "推荐"), "推荐内容页面"),
        (("评论", "点赞", "收藏", "分享"), "内容互动页面"),
        (("我的", "订单", "设置", "账号", "会员"), "个人或设置页面"),
    )
    for keywords, page_type in rules:
        if mentions(*keywords):
            return page_type
    return "移动应用页面"
|
|
|
|
|
def quoted_join(values: List[str], limit: int) -> str:
    """Wrap the first *limit* values in Chinese quotes and join them with "、"."""
    wrapped = [f"“{value}”" for value in values[:limit]]
    return "、".join(wrapped)
|
|
|
|
|
def build_grounded_summary(row: Dict[str, Any]) -> str:
    """Build a Chinese summary sentence using only OCR evidence from *row*.

    * No usable OCR text: say that no clear text was recognised.
    * Centred confirmation dialog detected: describe its prompt and buttons.
    * Otherwise: list the first five ordered texts, plus up to five more.
    The subject combines the (non-generic) app name with ``infer_page_type``.
    """
    texts = ordered_ocr_texts(row, max_items=14)
    dialog_texts = center_dialog_texts(row)
    app = display_app_name(str(row.get("app") or ""))
    page_type = infer_page_type(texts)
    subject = f"{app}的{page_type}" if app else page_type
    if not texts:
        return f"这张截图展示的是{subject},但当前没有识别到足够清晰的屏幕文字。"
    if dialog_texts:
        prompt = next((text for text in dialog_texts if any(term in text for term in ["确认", "是否", "删除", "商品吗"])), dialog_texts[0])
        actions = [text for text in dialog_texts if text in {"取消", "确定", "删除", "关闭"}]
        action_text = f",提供{quoted_join(actions, 4)}等按钮" if actions else ""
        # The background clause used to be appended to every dialog. Only claim a
        # shopping-cart list when OCR actually saw "购物车" somewhere on screen.
        cart_visible = any("购物车" in safe_text(item.get("text")) for item in row.get("ocr_items", []) or [])
        background = ";背景是购物车商品列表" if cart_visible else ""
        return f"这张截图显示{subject}上弹出确认对话框,提示“{prompt}”{action_text}{background}。"
    primary = quoted_join(texts, 5)
    summary = f"这张截图主要是{subject},屏幕上能看到{primary}等文字。"
    extra = texts[5:10]
    if extra:
        summary += f" 下方还出现{quoted_join(extra, 5)}等条目。"
    return summary
|
|
|
|
|
def merge_items(primary: List[Any], secondary: List[Any], max_items: int) -> List[Any]:
    """Concatenate two lists, keeping the first occurrence of each item.

    Dicts are compared by their key-sorted JSON form and other values by
    ``safe_text``. Items whose key is empty are skipped. Merging stops once
    *max_items* items have been appended (the check follows each append).
    """
    merged: List[Any] = []
    seen_keys = set()
    for source in (primary or [], secondary or []):
        for item in source:
            if isinstance(item, dict):
                identity = json.dumps(item, ensure_ascii=False, sort_keys=True)
            else:
                identity = safe_text(item)
            if not identity or identity in seen_keys:
                continue
            seen_keys.add(identity)
            merged.append(item)
            if len(merged) >= max_items:
                return merged
    return merged
|
|
|
|
|
def gui_ground_prediction(row: Dict[str, Any], pred_obj: Optional[Dict[str, Any]], args: argparse.Namespace) -> Dict[str, Any]:
    """Normalise a model/template prediction and ground it in the row's OCR evidence.

    Accepts Chinese or English field names from *pred_obj* and returns a dict
    with the keys 画面总结, 可见文字, 互动数据, 功能入口 and 关键证据.
    * Summary: replaced by an OCR-built summary when OCR exists and
      ``args.gui_summary_mode`` is "ocr" (or "auto" with no model summary).
      The template summary is used as a last resort.
    * Visible text: replaced by ordered OCR texts when OCR exists.
    * Functions: cleaned model functions, merged with OCR-derived ones when
      ``args.merge_ocr_functions``. Otherwise the OCR-derived ones are used only
      as a fallback.
    * Evidence: function evidence ids first, then model evidence, then template
      evidence if both are empty; cleaned and capped at 8.
    """
    pred = pred_obj if isinstance(pred_obj, dict) else {}
    template = template_prediction(row, max_visible=args.max_visible_text)
    grounded = {
        "画面总结": safe_text(pred.get("画面总结") or pred.get("summary_zh") or pred.get("summary")),
        "可见文字": pred.get("可见文字") or pred.get("visible_text") or [],
        "互动数据": pred.get("互动数据") or pred.get("interaction_data") or [],
        "功能入口": pred.get("功能入口") or pred.get("ui_functions") or [],
        "关键证据": pred.get("关键证据") or pred.get("key_ui_clues") or pred.get("evidence") or [],
    }
    has_ocr = bool(row.get("ocr_items"))
    model_summary = safe_text(grounded.get("画面总结"))
    if has_ocr and (args.gui_summary_mode == "ocr" or (args.gui_summary_mode == "auto" and not model_summary)):
        grounded["画面总结"] = build_grounded_summary(row)
    if has_ocr:
        grounded["可见文字"] = ordered_ocr_texts(row, max_items=args.max_visible_text)

    max_functions = int(getattr(args, "structured_max_functions", None) or 12)
    gui_functions = collect_gui_functions(row, max_functions)
    model_functions = clean_function_list(row, grounded.get("功能入口", []), max_functions)
    if args.merge_ocr_functions:
        grounded["功能入口"] = merge_items(model_functions, gui_functions, max_functions)
    elif model_functions:
        grounded["功能入口"] = model_functions
    else:
        # Fall back when the model produced no function that survives cleaning.
        # Previously the raw list was tested, so all-junk model output yielded [].
        grounded["功能入口"] = gui_functions

    if not grounded.get("互动数据"):
        grounded["互动数据"] = template.get("互动数据", [])

    function_evidence: List[Any] = []
    for function in grounded.get("功能入口", []) or []:
        if isinstance(function, dict):
            function_evidence.extend(function.get("evidence_ids", []) or [])
    evidence_candidates = function_evidence + list(grounded.get("关键证据", []) or [])
    if not evidence_candidates:
        evidence_candidates = list(template.get("关键证据", []) or [])
    grounded["关键证据"] = clean_evidence_ids(row, evidence_candidates, 8)

    if not grounded.get("画面总结"):
        grounded["画面总结"] = template.get("画面总结", "")
    return grounded
|
|
|
|
|
def parse_multipart(headers: Any, body: bytes) -> Dict[str, Any]:
    """Parse a ``multipart/form-data`` request body into a field dict.

    Text fields map to decoded ``str`` values (invalid UTF-8 is replaced). File
    fields map to ``{"filename": str, "content": bytes}``. Only the CRLF pair
    that belongs to each boundary delimiter is removed, so binary content that
    happens to start or end with CR/LF bytes is preserved exactly.

    Args:
        headers: any mapping-like object with ``.get("Content-Type", default)``.
        body: the raw request body.

    Raises:
        ValueError: if ``Content-Type`` carries no ``boundary`` parameter.
    """
    content_type = headers.get("Content-Type", "")
    match = re.search(r"boundary=(?P<q>\"?)([^\";]+)(?P=q)", content_type)
    if not match:
        raise ValueError("Missing multipart boundary.")
    delimiter = ("--" + match.group(2)).encode("utf-8")
    fields: Dict[str, Any] = {}
    for part in body.split(delimiter):
        # Empty preamble, or the "--" that follows the closing delimiter.
        if not part or part.startswith(b"--"):
            continue
        # RFC 2046: the CRLF after a delimiter and the CRLF before the next one are
        # part of the delimiters, not the content. Strip exactly one of each
        # (the old strip()/rstrip() also ate CR/LF bytes of binary uploads).
        if part.startswith(b"\r\n"):
            part = part[2:]
        if part.endswith(b"\r\n"):
            part = part[:-2]
        header_blob, separator, value = part.partition(b"\r\n\r\n")
        if not separator or not header_blob:
            continue
        header_text = header_blob.decode("utf-8", errors="replace")
        disposition = ""
        for line in header_text.split("\r\n"):
            if line.lower().startswith("content-disposition:"):
                disposition = line.split(":", 1)[1]
                break
        # Anchor on ";" or the start so `name=` does not match inside `filename=`.
        name_match = re.search(r'(?:^|;)\s*name="([^"]+)"', disposition)
        if not name_match:
            continue
        field_name = name_match.group(1)
        filename_match = re.search(r'(?:^|;)\s*filename="([^"]*)"', disposition)
        if filename_match:
            fields[field_name] = {
                "filename": filename_match.group(1),
                "content": value,
            }
        else:
            fields[field_name] = value.decode("utf-8", errors="replace")
    return fields
|
|
|
|
|
class RichGuiPredictor:
    """Produce screenshot summaries for the GUI, via the rich model or the template path.

    The checkpoint is loaded lazily on the first model request. Model loading
    and inference run under ``self.lock``, so concurrent request threads share
    one model instance.
    """

    def __init__(self, args: argparse.Namespace):
        """Store the CLI *args* and choose the device (``args.device`` or CUDA/CPU auto)."""
        self.args = args
        self.device = torch.device(args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu"))
        self.model = None
        self.tokenizer = None
        self.image_processor = None
        self.ckpt_args = None
        self.collator = None
        # Serialises model loading and inference across request threads.
        self.lock = threading.Lock()

    def load_model(self) -> None:
        """Load the checkpoint once; no-op in template-only mode or when already loaded.

        Raises:
            FileNotFoundError: if no checkpoint is configured or the path is missing.
        """
        if self.args.template_only or self.model is not None:
            return
        if not self.args.checkpoint:
            raise FileNotFoundError("Checkpoint not set. Train a natural multimodal checkpoint first or pass --checkpoint.")
        checkpoint = Path(self.args.checkpoint)
        if not checkpoint.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")
        self.model, self.tokenizer, self.image_processor, self.ckpt_args = load_rich_checkpoint(str(checkpoint), self.device)
        self.apply_runtime_args(self.ckpt_args)
        self.collator = RichCollator(self.tokenizer, self.image_processor, self.ckpt_args)

    def apply_runtime_args(self, ckpt_args: argparse.Namespace) -> None:
        """Override checkpoint settings with generation/structured options given to the GUI.

        An option is copied only when it is set on ``self.args`` (not None and
        not ""). The two ``*_mode`` options are copied only when truthy.
        ``num_workers`` is always forced to 0.
        """
        optional_names = [
            "generation_no_repeat_ngram_size",
            "generation_repetition_penalty",
            "generation_block_extra_ids",
            "generation_block_title_prefix",
            "generation_force_json_start",
            "context_summary_repair",
            "canonicalize_targets",
            "drop_bare_search_functions",
            "structured_function_threshold",
            "structured_search_threshold",
            "structured_max_functions",
            "structured_strict_search_candidates",
            "structured_evidence_threshold",
            "structured_max_evidence",
            "structured_evidence_fallback_top1",
        ]
        for name in optional_names:
            value = getattr(self.args, name, None)
            if value is not None and value != "":
                setattr(ckpt_args, name, value)
        for name in ["structured_function_mode", "structured_evidence_mode"]:
            value = getattr(self.args, name, "")
            if value:
                setattr(ckpt_args, name, value)
        ckpt_args.num_workers = 0

    def save_image(self, upload: Dict[str, Any]) -> Path:
        """Validate an uploaded image, re-encode it as RGB PNG, and return the saved path.

        Args:
            upload: ``{"filename": str, "content": bytes}`` as produced by the
                multipart parser.

        Raises:
            ValueError: for empty content or a disallowed file extension.
            PIL errors propagate when the bytes are not a readable image.
        """
        import uuid  # local: only needed for unique upload names

        raw = upload.get("content") or b""
        if not raw:
            raise ValueError("Uploaded image is empty.")
        filename = safe_upload_name(upload.get("filename", "upload.png"))
        suffix = Path(filename).suffix.lower()
        if suffix and suffix not in ALLOWED_IMAGE_EXTS:
            raise ValueError(f"Unsupported image type: {suffix}")
        upload_dir = Path(self.args.upload_dir)
        upload_dir.mkdir(parents=True, exist_ok=True)
        with Image.open(BytesIO(raw)) as img:
            image = img.convert("RGB")
        stamp = time.strftime("%Y%m%d_%H%M%S")
        # A random token (instead of the millisecond clock) keeps names unique when
        # concurrent requests upload files with the same name. The stem is capped
        # so long client filenames cannot exceed filesystem name limits.
        token = uuid.uuid4().hex[:8]
        upload_path = upload_dir / f"{stamp}_{token}_{Path(filename).stem[:80]}.png"
        image.save(upload_path)
        return upload_path

    def build_row(self, image_path: Path, app: str, focus: str, ocr_engine: str) -> Tuple[Dict[str, Any], Optional[str]]:
        """Build the inference row for one image, running OCR unless *ocr_engine* is "none".

        Returns:
            ``(row, ocr_error)``. OCR failures are deliberately caught and
            returned as a message so the request can still proceed without OCR.
        """
        model_instruction = safe_text(focus) if self.args.use_focus_in_model else ""
        row: Dict[str, Any] = {
            "screen_id": image_path.stem,
            "image_path": str(image_path),
            "app": safe_text(app) or "移动应用",
            "instruction": model_instruction,
            "display_focus": safe_text(focus),
            "target": {
                "summary_zh": "",
                "visible_text": [],
                "interaction_data": [],
                "ui_functions": [],
                "key_ui_clues": [],
            },
            "ocr_items": [],
            "ui_items": [],
            "weak_evidence_ids": [],
        }
        ocr_error = None
        if ocr_engine != "none":
            try:
                cache_args = argparse.Namespace(ocr_engine=ocr_engine, ocr_lang=self.args.ocr_lang)
                image_sha = sha256_file(image_path)
                ocr_items = load_ocr_items(image_path, image_sha, Path(self.args.ocr_cache_dir), cache_args)
                ocr_items = filter_ocr_items(ocr_items, self.args.min_ocr_conf)[: self.args.max_ocr_items]
                row["ocr_items"] = ocr_items
                row["ui_items"] = build_ocr_ui_items(row, ocr_items, self.args.max_ui_items)
                row["weak_evidence_ids"] = [item.get("id") for item in row["ui_items"][:8] if item.get("id")]
            except Exception as exc:
                ocr_error = str(exc)
        return row, ocr_error

    @torch.no_grad()
    def summarize(self, image_path: Path, app: str, focus: str, ocr_engine: str) -> Dict[str, Any]:
        """Summarise one saved image and return the ``row_result`` dict plus GUI metadata.

        Extra keys added: ocr_count, ui_item_count, ocr_error, image_url, focus,
        model_output_mode, display_mode and elapsed_sec.
        """
        start = time.perf_counter()
        row, ocr_error = self.build_row(image_path, app, focus, ocr_engine)
        if self.args.template_only:
            pred = template_prediction(row, max_visible=self.args.max_visible_text)
            pred = gui_ground_prediction(row, pred, self.args)
            result = row_result(
                row=row,
                raw_text=json.dumps(pred, ensure_ascii=False),
                pred_obj=pred,
                json_valid=True,
                evidence_scores=None,
                allow_template_fallback=False,
                source="template",
            )
        else:
            with self.lock:
                self.load_model()
                batch = self.collator([row])
                batch = move_batch(batch, self.device)
                text = self.model.generate_text(
                    batch,
                    self.tokenizer,
                    num_beams=self.args.num_beams,
                    max_new_tokens=self.args.max_new_tokens,
                )[0]
                _, _, elem_tokens, elem_key_padding = self.model.build_memory(batch)
                masks = (~elem_key_padding)[0].detach().cpu().numpy()
                evidence_head = torch.sigmoid(self.model.evidence_head(elem_tokens).squeeze(-1))[0].detach().cpu()
                function_head = torch.sigmoid(self.model.ui_function_head(elem_tokens).squeeze(-1))[0].detach().cpu()
                search_head = torch.sigmoid(self.model.search_function_head(elem_tokens).squeeze(-1))[0].detach().cpu()
                output_is_summary = target_schema_is_summary(getattr(self.ckpt_args, "target_schema", "zh"))
                output_is_natural = target_schema_is_natural_text(getattr(self.ckpt_args, "target_schema", "zh"))
                if output_is_summary:
                    pred_obj = prediction_from_summary(row, text)
                    ok = True
                elif output_is_natural:
                    pred_obj = natural_prediction_from_text(text)
                    ok = bool(pred_obj.get("画面总结"))
                else:
                    pred_obj, ok = safe_json_loads(text)
                if pred_obj is None:
                    # Unparseable output: keep only the top-scoring valid evidence elements.
                    elements = row.get("ui_items", []) or []
                    ranked = torch.argsort(evidence_head, descending=True).tolist()
                    top_ids = []
                    for idx in ranked:
                        if idx < len(elements) and idx < len(masks) and masks[idx]:
                            evidence_id = safe_text(elements[idx].get("id") or elements[idx].get("ocr_id"))
                            if evidence_id:
                                top_ids.append(evidence_id)
                        if len(top_ids) >= self.args.top_k_clues:
                            break
                    pred_obj = {"关键证据": top_ids}
                # NOTE(review): confirm whether the repair is meant for every output
                # or only for the evidence-only fallback above.
                if self.args.context_summary_repair:
                    pred_obj, _ = repair_prediction_with_context(row, pred_obj)
                    ok = True
                pred_obj = apply_structured_function_predictions(row, pred_obj, function_head, search_head, self.ckpt_args)
                pred_obj = apply_structured_evidence_predictions(row, pred_obj, evidence_head, self.ckpt_args)
                pred_obj = gui_ground_prediction(row, pred_obj, self.args)
                evidence_scores: Dict[str, float] = {}
                for idx, elem in enumerate(row.get("ui_items", []) or []):
                    if idx < len(masks) and masks[idx]:
                        evidence_id = safe_text(elem.get("id") or elem.get("ocr_id"))
                        if evidence_id:
                            evidence_scores[evidence_id] = float(evidence_head[idx])
            result = row_result(
                row=row,
                raw_text=text,
                pred_obj=pred_obj,
                json_valid=ok,
                evidence_scores=evidence_scores,
                allow_template_fallback=self.args.allow_template_fallback,
                source="model",
            )
        result["ocr_count"] = len(row.get("ocr_items", []) or [])
        result["ui_item_count"] = len(row.get("ui_items", []) or [])
        result["ocr_error"] = ocr_error
        result["image_url"] = f"/uploads/{image_path.name}"
        result["focus"] = safe_text(focus)
        if self.args.template_only:
            result["model_output_mode"] = "template"
        elif self.ckpt_args is not None and target_schema_is_summary(getattr(self.ckpt_args, "target_schema", "zh")):
            result["model_output_mode"] = "summary"
        elif self.ckpt_args is not None and target_schema_is_natural_text(getattr(self.ckpt_args, "target_schema", "zh")):
            result["model_output_mode"] = "natural_text"
        else:
            result["model_output_mode"] = "json"
        result["display_mode"] = self.args.gui_summary_mode if row.get("ocr_items") else "model"
        result["elapsed_sec"] = round(time.perf_counter() - start, 3)
        return result
|
|
|
|
|
class RichGuiHandler(BaseHTTPRequestHandler):
    """Request handler for the local rich-summarization GUI.

    Routes:
        GET  / and /index.html -> the embedded HTML page
        GET  /api/health       -> JSON describing the active inference settings
        GET  /uploads/<name>   -> a previously uploaded file from ``args.upload_dir``
        POST /api/summarize    -> multipart upload; replies with the summary result as JSON

    ``self.server`` is expected to be a ``RichGuiServer`` whose ``predictor``
    exposes ``args``, ``save_image`` and ``summarize``.
    """

    server_version = "RichGui/1.0"

    def send_payload(self, status: int, body: bytes, content_type: str) -> None:
        """Send a complete response with caching disabled and an explicit Content-Length."""
        self.send_response(status)
        self.send_header("Content-Type", content_type)
        # Mark every response non-cacheable (HTTP/1.1 header plus the legacy HTTP/1.0 one).
        self.send_header("Cache-Control", "no-store, no-cache, must-revalidate, max-age=0")
        self.send_header("Pragma", "no-cache")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def _send_json(self, payload: Dict[str, Any], status: Optional[int] = None) -> None:
        """Encode ``payload`` with ``json_bytes`` and send it.

        When ``status`` is None, ``json_bytes`` chooses its own default status.
        """
        encoded = json_bytes(payload) if status is None else json_bytes(payload, status)
        self.send_payload(*encoded)

    @staticmethod
    def _read_content_length(raw_value: Optional[str], max_upload_mb: int) -> int:
        """Parse and validate a request's Content-Length header value.

        Args:
            raw_value: The header value, or None when the header is absent (treated as 0).
            max_upload_mb: Upper bound on the body size, in MiB.

        Returns:
            The number of body bytes to read.

        Raises:
            ValueError: The value is not an integer, is negative, or exceeds the limit.
        """
        try:
            length = int(raw_value if raw_value is not None else "0")
        except (TypeError, ValueError):
            raise ValueError(f"Invalid Content-Length header: {raw_value!r}") from None
        # A negative length would become ``rfile.read(-1)``, which blocks until the
        # client closes the connection and ties up the handler thread.
        if length < 0:
            raise ValueError(f"Invalid Content-Length header: {raw_value!r}")
        if length > max_upload_mb * 1024 * 1024:
            raise ValueError(f"Image is larger than {max_upload_mb} MB.")
        return length

    def _serve_upload(self, raw_name: str, upload_dir: Path) -> None:
        """Send one previously uploaded file from ``upload_dir``, or a JSON 404."""
        filename = safe_upload_name(unquote(raw_name))
        path = upload_dir / filename
        # ``safe_upload_name`` is applied first; as defense in depth, also refuse
        # anything that does not resolve to a regular file inside the upload dir.
        if path.is_file() and path.resolve().is_relative_to(upload_dir.resolve()):
            content_type = mimetypes.guess_type(path.name)[0] or "application/octet-stream"
            self.send_payload(HTTPStatus.OK, path.read_bytes(), content_type)
            return
        self._send_json({"error": "Not found"}, HTTPStatus.NOT_FOUND)

    def do_GET(self) -> None:
        """Serve the HTML page, the health endpoint, or an uploaded file; 404 otherwise."""
        parsed = urlparse(self.path)
        if parsed.path in {"/", "/index.html"}:
            self.send_payload(HTTPStatus.OK, HTML_PAGE.encode("utf-8"), "text/html; charset=utf-8")
            return
        args = self.server.predictor.args
        if parsed.path == "/api/health":
            self._send_json(
                {
                    "ok": True,
                    "template_only": args.template_only,
                    "checkpoint": None if args.template_only else args.checkpoint,
                    "structured_function_mode": args.structured_function_mode,
                    "structured_function_threshold": args.structured_function_threshold,
                    "structured_search_threshold": args.structured_search_threshold,
                    "structured_evidence_mode": args.structured_evidence_mode,
                    "structured_evidence_threshold": args.structured_evidence_threshold,
                }
            )
            return
        if parsed.path.startswith("/uploads/"):
            self._serve_upload(parsed.path.removeprefix("/uploads/"), Path(args.upload_dir))
            return
        self._send_json({"error": "Not found"}, HTTPStatus.NOT_FOUND)

    def do_POST(self) -> None:
        """Handle ``POST /api/summarize``; every other POST path gets a JSON 404.

        Any failure while reading, validating or summarizing the upload is
        reported to the client as ``{"error": ...}`` with status 400.
        """
        parsed = urlparse(self.path)
        if parsed.path != "/api/summarize":
            self._send_json({"error": "Not found"}, HTTPStatus.NOT_FOUND)
            return
        predictor = self.server.predictor
        try:
            length = self._read_content_length(self.headers.get("Content-Length"), predictor.args.max_upload_mb)
            fields = parse_multipart(self.headers, self.rfile.read(length))
            upload = fields.get("image")
            if not isinstance(upload, dict):
                raise ValueError("Missing image field.")
            image_path = predictor.save_image(upload)
            app = safe_text(fields.get("app")) or "移动应用"
            focus = safe_text(fields.get("focus")) or safe_text(fields.get("instruction"))
            ocr_engine = safe_text(fields.get("ocr_engine")) or predictor.args.ocr_engine
            # Unknown engines from the form fall back to the server's configured engine.
            if ocr_engine not in {"none", "paddleocr"}:
                ocr_engine = predictor.args.ocr_engine
            result = predictor.summarize(image_path, app, focus, ocr_engine)
            response = json_bytes(result, HTTPStatus.OK)
        except Exception as exc:
            # Request boundary: always answer the client, but keep a server-side trace
            # (log_error goes through log_message, so --quiet still silences it).
            self.log_error("summarize request failed: %r", exc)
            response = json_bytes({"error": str(exc)}, HTTPStatus.BAD_REQUEST)
        self.send_payload(*response)

    def log_message(self, fmt: str, *args: Any) -> None:
        """Forward access/error log lines to the default logger unless ``--quiet`` is set."""
        if not self.server.predictor.args.quiet:
            super().log_message(fmt, *args)
|
|
|
|
|
class RichGuiServer(ThreadingHTTPServer):
    """Thread-per-request HTTP server that carries the shared predictor.

    Request handlers reach the predictor through ``self.server.predictor``.
    """

    def __init__(self, server_address: Tuple[str, int], handler_class: Any, predictor: RichGuiPredictor):
        # Attach the shared predictor first, then let the base class bind and listen.
        self.predictor = predictor
        super().__init__(server_address, handler_class)
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse the command line for the rich GUI server.

    If neither ``--template_only`` nor ``--checkpoint`` is given, the checkpoint
    returned by ``find_latest_rich_checkpoint()`` is filled in.

    Raises:
        FileNotFoundError: No checkpoint was given and none could be located.
    """
    # (flag, add_argument keyword options) in the order they appear in --help.
    option_specs: List[Tuple[str, Dict[str, Any]]] = [
        ("--checkpoint", dict(default=DEFAULT_CHECKPOINT)),
        ("--template_only", dict(action="store_true")),
        ("--host", dict(default="127.0.0.1")),
        ("--port", dict(type=int, default=7860)),
        ("--open_browser", dict(type=str_to_bool, default=True)),
        ("--quiet", dict(action="store_true")),
        ("--upload_dir", dict(default="outputs/rich_gui/uploads")),
        ("--ocr_cache_dir", dict(default="data/rich_cmgui/cache/gui_ocr")),
        ("--ocr_engine", dict(choices=["none", "paddleocr"], default="paddleocr")),
        ("--ocr_lang", dict(default="ch")),
        ("--min_ocr_conf", dict(type=float, default=0.5)),
        ("--max_ocr_items", dict(type=int, default=120)),
        ("--max_ui_items", dict(type=int, default=48)),
        ("--max_upload_mb", dict(type=int, default=20)),
        ("--device", dict(default="")),
        ("--num_beams", dict(type=int, default=1)),
        ("--max_new_tokens", dict(type=int, default=256)),
        ("--top_k_clues", dict(type=int, default=5)),
        ("--max_visible_text", dict(type=int, default=12)),
        ("--gui_summary_mode", dict(choices=["model", "ocr", "auto"], default="model")),
        ("--merge_ocr_functions", dict(type=str_to_bool, default=True)),
        ("--use_focus_in_model", dict(type=str_to_bool, default=False)),
        ("--generation_no_repeat_ngram_size", dict(type=int, default=3)),
        ("--generation_repetition_penalty", dict(type=float, default=1.1)),
        ("--generation_block_extra_ids", dict(type=str_to_bool, default=True)),
        ("--generation_block_title_prefix", dict(type=str_to_bool, default=True)),
        ("--generation_force_json_start", dict(type=str_to_bool, default=False)),
        ("--context_summary_repair", dict(type=str_to_bool, default=None)),
        ("--canonicalize_targets", dict(type=str_to_bool, default=None)),
        ("--drop_bare_search_functions", dict(type=str_to_bool, default=None)),
        ("--structured_function_mode", dict(choices=["", "decoder", "heads"], default="heads")),
        ("--structured_function_threshold", dict(type=float, default=DEFAULT_FUNCTION_THRESHOLD)),
        ("--structured_search_threshold", dict(type=float, default=DEFAULT_SEARCH_THRESHOLD)),
        ("--structured_max_functions", dict(type=int, default=DEFAULT_MAX_STRUCTURED_ITEMS)),
        ("--structured_strict_search_candidates", dict(type=str_to_bool, default=None)),
        ("--structured_evidence_mode", dict(choices=["", "decoder", "heads"], default="heads")),
        ("--structured_evidence_threshold", dict(type=float, default=DEFAULT_EVIDENCE_THRESHOLD)),
        ("--structured_max_evidence", dict(type=int, default=DEFAULT_MAX_STRUCTURED_ITEMS)),
        ("--structured_evidence_fallback_top1", dict(type=str_to_bool, default=False)),
        ("--allow_template_fallback", dict(type=str_to_bool, default=False)),
    ]

    parser = argparse.ArgumentParser(
        description="Start a local GUI for rich screenshot summarization.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    parsed = parser.parse_args()

    # Nothing to resolve when running template-only or when a checkpoint was supplied.
    if parsed.template_only or parsed.checkpoint:
        return parsed

    latest = find_latest_rich_checkpoint()
    if latest is None:
        raise FileNotFoundError("No stage3/stage4 rich checkpoint found. Train first or pass --checkpoint.")
    parsed.checkpoint = str(latest)
    return parsed
|
|
|
|
def _bind_first_free_port(host: str, first_port: int, predictor: RichGuiPredictor, attempts: int = 20) -> RichGuiServer:
    """Return a server bound to the first bindable port in ``[first_port, first_port + attempts)``.

    Raises:
        RuntimeError: No port in the range could be bound; the last bind error is included.
    """
    last_error: Optional[Exception] = None
    for port in range(first_port, first_port + attempts):
        try:
            return RichGuiServer((host, port), RichGuiHandler, predictor)
        except (OSError, OverflowError) as exc:
            # socket.bind reports ports outside 0-65535 with OverflowError rather than
            # OSError; treat it as "this port is unusable" instead of crashing the scan.
            last_error = exc
    raise RuntimeError(f"Could not bind a local port starting at {first_port}: {last_error}")


def _browser_url(host: str, port: int) -> str:
    """Return a URL a local browser can open for a server bound to ``host:port``.

    The wildcard bind addresses "" and "0.0.0.0" are not usable browser targets
    ("" even yields the malformed ``http://:PORT/``), so loopback is used instead.
    """
    reachable_host = "127.0.0.1" if host in {"", "0.0.0.0"} else host
    return f"http://{reachable_host}:{port}/"


def main() -> None:
    """Start the GUI server and serve requests until interrupted.

    Tries up to 20 consecutive ports starting at ``--port``, prints the chosen
    URL and key inference settings as one JSON line, optionally opens a
    browser, then blocks in ``serve_forever``. The listening socket is closed
    when serving stops, including on KeyboardInterrupt.

    Raises:
        RuntimeError: No port in the scanned range could be bound.
    """
    args = parse_args()
    predictor = RichGuiPredictor(args)
    server = _bind_first_free_port(args.host, args.port, predictor)
    url = _browser_url(args.host, server.server_address[1])
    print(
        json.dumps(
            {
                "url": url,
                "template_only": args.template_only,
                "checkpoint": None if args.template_only else args.checkpoint,
                "num_beams": args.num_beams,
                "max_new_tokens": args.max_new_tokens,
                "structured_function_mode": args.structured_function_mode,
                "structured_function_threshold": args.structured_function_threshold,
                "structured_search_threshold": args.structured_search_threshold,
                "structured_evidence_mode": args.structured_evidence_mode,
                "structured_evidence_threshold": args.structured_evidence_threshold,
            },
            ensure_ascii=False,
        )
    )
    # ``with`` guarantees server_close() so the listening socket is released on exit.
    with server:
        if args.open_browser:
            # Short delay so serve_forever() is running before the browser connects;
            # daemon so a pending opener never keeps the process alive after shutdown.
            opener = threading.Timer(0.5, webbrowser.open, args=(url,))
            opener.daemon = True
            opener.start()
        server.serve_forever()
|
|
|
|
|
if __name__ == "__main__":
    # Run the GUI server when this file is executed as a script; main() returns
    # None, so a normal return exits with status 0 exactly as before.
    sys.exit(main())
|
|