hsq12138's picture
Upload CMGUI stage3 screen-grounded summarizer checkpoint
2f0e115 verified
#!/usr/bin/env python
"""Local web GUI for rich CMGUI screenshot summarization."""
from __future__ import annotations
import argparse
import json
import mimetypes
import re
import sys
import threading
import time
import webbrowser
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler, ThreadingHTTPServer
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from urllib.parse import unquote, urlparse
import torch
from PIL import Image
SCRIPT_DIR = Path(__file__).resolve().parent
if str(SCRIPT_DIR) not in sys.path:
sys.path.insert(0, str(SCRIPT_DIR))
from enrich_rich_ocr_evidence import build_ocr_ui_items, filter_ocr_items # noqa: E402
from infer_rich import find_latest_rich_checkpoint, row_result, template_prediction # noqa: E402
from prepare_rich_data import load_ocr_items, safe_text, sha256_file # noqa: E402
from train_rich import ( # noqa: E402
RichCollator,
apply_structured_evidence_predictions,
apply_structured_function_predictions,
load_rich_checkpoint,
move_batch,
natural_prediction_from_text,
prediction_from_summary,
repair_prediction_with_context,
safe_json_loads,
target_schema_is_natural_text,
target_schema_is_summary,
)
# Default for --checkpoint; an empty string means "not set" and makes
# RichGuiPredictor.load_model() raise unless --template_only is used.
DEFAULT_CHECKPOINT = ""
# NOTE(review): the four values below are not referenced in the visible part
# of this file — presumably they are argparse defaults for the structured
# function/search/evidence options further down; confirm.
DEFAULT_FUNCTION_THRESHOLD = 0.20
DEFAULT_SEARCH_THRESHOLD = 0.20
DEFAULT_EVIDENCE_THRESHOLD = 0.50
DEFAULT_MAX_STRUCTURED_ITEMS = 8
# File extensions accepted by RichGuiPredictor.save_image() (compared lower-cased).
ALLOWED_IMAGE_EXTS = {".png", ".jpg", ".jpeg", ".webp", ".bmp"}
# App names treated as "no specific app" by display_app_name().
GENERIC_APP_NAMES = {"", "移动应用", "手机应用", "应用", "App", "APP", "app"}
# Exact OCR strings that skip_summary_text() always rejects.
SUMMARY_SKIP_TEXTS = {
    "ADB Keyboard {ON}",
    "反馈",
    "共享中",
    "分享中",
}
# Keywords scanned in order by collect_gui_functions(); the first keyword
# contained in an element's text becomes the function name.
GUI_FUNCTION_KEYWORDS = ["确定", "取消", "删除", "分享", "评论", "收藏", "点赞", "关注", "返回", "首页", "消息", "购物车", "购买", "下单", "关闭"]
# Substrings that mark search-related text as a passive label (history,
# trending, suggestions) rather than a search entry point.
PASSIVE_SEARCH_TERMS = ["历史搜索", "搜索历史", "热门搜索", "热搜", "猜你想搜", "搜索推荐", "搜索记录"]
HTML_PAGE = r"""
<!doctype html>
<html lang="zh-CN">
<head>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
<title>Rich Screenshot Summarizer</title>
<style>
:root {
color-scheme: light;
--bg: #f6f7f9;
--panel: #ffffff;
--line: #d8dde6;
--text: #1d2433;
--muted: #5e6a7d;
--accent: #166a5a;
--accent-2: #b35c16;
--danger: #b3261e;
--shadow: 0 10px 30px rgba(29, 36, 51, 0.08);
}
* { box-sizing: border-box; }
body {
margin: 0;
min-height: 100vh;
background: var(--bg);
color: var(--text);
font-family: "Segoe UI", "Microsoft YaHei", Arial, sans-serif;
letter-spacing: 0;
}
header {
border-bottom: 1px solid var(--line);
background: rgba(255, 255, 255, 0.92);
position: sticky;
top: 0;
z-index: 10;
backdrop-filter: blur(10px);
}
.bar {
max-width: 1180px;
margin: 0 auto;
padding: 14px 20px;
display: flex;
align-items: center;
justify-content: space-between;
gap: 16px;
}
h1 {
font-size: 18px;
line-height: 1.25;
margin: 0;
font-weight: 650;
}
.status {
color: var(--muted);
font-size: 13px;
white-space: nowrap;
}
main {
max-width: 1180px;
margin: 0 auto;
padding: 22px 20px 32px;
display: grid;
grid-template-columns: 360px minmax(0, 1fr);
gap: 18px;
}
section {
background: var(--panel);
border: 1px solid var(--line);
border-radius: 8px;
box-shadow: var(--shadow);
}
.controls { padding: 18px; }
.preview { padding: 14px; }
label {
display: block;
color: var(--muted);
font-size: 13px;
margin: 0 0 7px;
}
input[type="file"], input[type="text"], textarea, select {
width: 100%;
border: 1px solid var(--line);
border-radius: 6px;
padding: 10px 11px;
font: inherit;
font-size: 14px;
background: #fff;
color: var(--text);
min-height: 40px;
}
textarea { resize: vertical; min-height: 76px; }
.field { margin-bottom: 14px; }
.row { display: grid; grid-template-columns: 1fr 1fr; gap: 10px; }
button {
width: 100%;
border: 1px solid #10584a;
border-radius: 6px;
background: var(--accent);
color: #fff;
padding: 11px 14px;
font: inherit;
font-size: 14px;
font-weight: 650;
cursor: pointer;
min-height: 42px;
}
button:disabled { cursor: wait; opacity: 0.72; }
.image-box {
border: 1px solid var(--line);
border-radius: 8px;
background: #eef1f5;
overflow: hidden;
min-height: 260px;
display: grid;
place-items: center;
}
.image-box img {
max-width: 100%;
max-height: 72vh;
display: block;
object-fit: contain;
}
.placeholder { color: var(--muted); font-size: 14px; padding: 30px; text-align: center; }
.result { padding: 18px; display: grid; gap: 16px; }
.summary {
font-size: 17px;
line-height: 1.65;
padding-bottom: 8px;
border-bottom: 1px solid var(--line);
}
.grid {
display: grid;
grid-template-columns: repeat(2, minmax(0, 1fr));
gap: 14px;
}
.block h2 {
margin: 0 0 8px;
font-size: 14px;
line-height: 1.35;
color: var(--accent);
}
ul { margin: 0; padding-left: 18px; }
li { margin: 4px 0; line-height: 1.45; overflow-wrap: anywhere; }
.meta {
display: flex;
flex-wrap: wrap;
gap: 8px;
color: var(--muted);
font-size: 12px;
}
.pill {
border: 1px solid var(--line);
border-radius: 999px;
padding: 3px 8px;
background: #fafbfc;
}
details {
border-top: 1px solid var(--line);
padding-top: 12px;
}
summary { cursor: pointer; color: var(--accent-2); font-size: 13px; }
pre {
white-space: pre-wrap;
overflow-wrap: anywhere;
background: #f3f5f7;
border: 1px solid var(--line);
border-radius: 6px;
padding: 12px;
max-height: 360px;
overflow: auto;
font-size: 12px;
}
.error {
color: var(--danger);
border: 1px solid rgba(179, 38, 30, 0.35);
background: rgba(179, 38, 30, 0.06);
border-radius: 6px;
padding: 11px;
line-height: 1.45;
}
@media (max-width: 860px) {
main { grid-template-columns: 1fr; }
.grid { grid-template-columns: 1fr; }
.bar { align-items: flex-start; flex-direction: column; }
.status { white-space: normal; }
}
</style>
</head>
<body>
<header>
<div class="bar">
<h1>Rich Screenshot Summarizer</h1>
<div id="status" class="status">就绪</div>
</div>
</header>
<main>
<section class="controls">
<form id="form">
<div class="field">
<label for="image">图片</label>
<input id="image" name="image" type="file" accept="image/*" required />
</div>
<div class="row">
<div class="field">
<label for="app">应用名/类型(可选)</label>
<input id="app" name="app" type="text" placeholder="可留空" />
</div>
<div class="field">
<label for="ocr_engine">OCR</label>
<select id="ocr_engine" name="ocr_engine">
<option value="paddleocr" selected>paddleocr</option>
<option value="none">none</option>
</select>
</div>
</div>
<div class="field">
<label for="focus">关注点(可选)</label>
<textarea id="focus" name="focus" placeholder="可留空,例如:搜索结果、价格信息、按钮入口"></textarea>
</div>
<button id="submit" type="submit">生成总结</button>
</form>
</section>
<section class="preview">
<div class="image-box" id="imageBox"><div class="placeholder">未选择图片</div></div>
</section>
<section class="result" style="grid-column: 1 / -1;">
<div id="output" class="placeholder">结果会显示在这里</div>
</section>
</main>
<script>
const form = document.getElementById('form');
const imageInput = document.getElementById('image');
const imageBox = document.getElementById('imageBox');
const output = document.getElementById('output');
const statusEl = document.getElementById('status');
const submit = document.getElementById('submit');
imageInput.addEventListener('change', () => {
const file = imageInput.files && imageInput.files[0];
if (!file) {
imageBox.innerHTML = '<div class="placeholder">未选择图片</div>';
return;
}
const url = URL.createObjectURL(file);
imageBox.innerHTML = '';
const img = document.createElement('img');
img.src = url;
img.onload = () => URL.revokeObjectURL(url);
imageBox.appendChild(img);
});
function escapeHtml(value) {
return String(value ?? '').replace(/[&<>'"]/g, ch => ({
'&': '&amp;', '<': '&lt;', '>': '&gt;', "'": '&#39;', '"': '&quot;'
}[ch]));
}
function listItems(items, mapFn) {
if (!Array.isArray(items) || items.length === 0) return '<span class="placeholder">无</span>';
return '<ul>' + items.map(item => '<li>' + mapFn(item) + '</li>').join('') + '</ul>';
}
function renderResult(data) {
const pred = data.prediction || {};
const outputMode = data.model_output_mode || 'json';
const parseLabel = outputMode === 'summary' ? '输出 summary' : (outputMode === 'natural_text' ? ('自然文本 ' + (data.json_valid ? 'parsed' : 'fallback')) : (outputMode === 'template' ? '模板输出' : ('JSON ' + (data.json_valid ? 'valid' : 'fallback'))));
const visible = listItems(pred['可见文字'], x => escapeHtml(x));
const funcs = listItems(pred['功能入口'], item => {
const evidence = Array.isArray(item.evidence_ids) ? item.evidence_ids.join(', ') : '';
return escapeHtml(item.name || '') + (evidence ? ' <span class="pill">' + escapeHtml(evidence) + '</span>' : '');
});
const interactions = listItems(pred['互动数据'], item => {
const value = item.value ? ':' + item.value : '';
return escapeHtml((item.name || '') + value);
});
const clues = listItems(data.key_ui_clues, item => {
const score = typeof item.score === 'number' ? ' score=' + item.score.toFixed(3) : '';
const text = item.text ? ' ' + item.text : '';
return '<span class="pill">' + escapeHtml(item.element_id || '') + '</span>' + escapeHtml(text + score);
});
output.className = '';
output.innerHTML = `
<div class="summary">${escapeHtml(data.summary || pred['画面总结'] || '')}</div>
<div class="meta">
<span class="pill">${escapeHtml(data.source || '')}</span>
<span class="pill">${escapeHtml(outputMode)}</span>
<span class="pill">${escapeHtml(data.display_mode || 'model')}</span>
<span class="pill">${escapeHtml(parseLabel)}</span>
<span class="pill">OCR ${data.ocr_count ?? 0}</span>
<span class="pill">${Number(data.elapsed_sec || 0).toFixed(2)}s</span>
</div>
<div class="grid">
<div class="block"><h2>可见文字</h2>${visible}</div>
<div class="block"><h2>功能入口</h2>${funcs}</div>
<div class="block"><h2>互动数据</h2>${interactions}</div>
<div class="block"><h2>关键证据</h2>${clues}</div>
</div>
<details><summary>原始 JSON</summary><pre>${escapeHtml(JSON.stringify(data, null, 2))}</pre></details>
`;
if (data.image_url) {
imageBox.innerHTML = `<img src="${escapeHtml(data.image_url)}" alt="uploaded screenshot" />`;
}
}
form.addEventListener('submit', async event => {
event.preventDefault();
const formData = new FormData(form);
submit.disabled = true;
statusEl.textContent = '生成中';
output.className = 'placeholder';
output.textContent = '正在处理图片';
try {
const response = await fetch('/api/summarize', { method: 'POST', body: formData });
const data = await response.json();
if (!response.ok) throw new Error(data.error || '请求失败');
renderResult(data);
statusEl.textContent = '完成';
} catch (error) {
output.className = 'error';
output.textContent = error.message || String(error);
statusEl.textContent = '出错';
} finally {
submit.disabled = false;
}
});
</script>
</body>
</html>
"""
def str_to_bool(value: Any) -> bool:
    """Interpret *value* as a boolean flag.

    The value is stringified and lower-cased; only "1", "true", "yes" and
    "y" count as true. Everything else (including "on") is false.
    """
    truthy_spellings = ("1", "true", "yes", "y")
    normalized = str(value).lower()
    return normalized in truthy_spellings
def json_bytes(obj: Dict[str, Any], status: int = 200) -> Tuple[int, bytes, str]:
    """Serialize *obj* into a ready-to-send JSON response triple.

    Returns ``(status, utf8_body, content_type)``. Non-ASCII characters are
    written as-is rather than as ``\\uXXXX`` escapes.
    """
    serialized = json.dumps(obj, ensure_ascii=False)
    payload = serialized.encode("utf-8")
    return status, payload, "application/json; charset=utf-8"
def safe_upload_name(filename: str) -> str:
    """Reduce a client-supplied file name to a safe ASCII base name.

    Directory components are dropped, every run of characters outside
    ``[A-Za-z0-9._-]`` becomes a single underscore, and leading/trailing dots
    and underscores are stripped. Names that are already safe pass through
    unchanged, so the function can also be used to look up stored uploads.

    When the whole stem is sanitised away (e.g. a purely non-ASCII name such
    as "图片.png"), the stem falls back to "upload" so the extension is kept
    instead of being promoted to the file name.

    Returns "upload" when nothing usable remains.
    """
    base = Path(filename or "upload").name
    cleaned = re.sub(r"[^A-Za-z0-9._-]+", "_", base)
    stripped = cleaned.strip("._")
    suffix = Path(cleaned).suffix
    # The stem collapsed to nothing and only the bare extension survived the
    # strip: restore a placeholder stem so callers still see the real suffix.
    if suffix and stripped == suffix[1:]:
        return f"upload{suffix}"
    return stripped or "upload"
def display_app_name(app: str) -> str:
    """Return the cleaned app name, or "" when it is one of the generic placeholders."""
    label = safe_text(app)
    if label in GENERIC_APP_NAMES:
        return ""
    return label
def text_bbox(item: Dict[str, Any]) -> Tuple[float, float]:
    """Return a ``(y, x)`` sort key taken from the first two bbox entries.

    ``bbox[0]`` is used as x and ``bbox[1]`` as y. Anything that is not a
    list of at least four float-convertible values yields ``(0.0, 0.0)``.
    """
    fallback = (0.0, 0.0)
    box = item.get("bbox") or []
    if not isinstance(box, list) or len(box) < 4:
        return fallback
    try:
        y_value = float(box[1])
        x_value = float(box[0])
    except (TypeError, ValueError):
        return fallback
    return y_value, x_value
def skip_summary_text(text: str) -> bool:
    """Return True when an OCR string is noise that should not reach the summary.

    Rejected: empty text, entries of ``SUMMARY_SKIP_TEXTS``, any single
    character other than "搜", and strings fully matching one of the noise
    patterns below.
    """
    if not text or text in SUMMARY_SKIP_TEXTS:
        return True
    # "搜" is the only single character kept; every other one is dropped.
    if len(text) == 1 and text != "搜":
        return True
    noise_patterns = (
        r"[0-9::/ ._\-]+",  # digits, colons, slashes, spaces, dots, underscores, dashes
        r"\d{1,2}:\d{2}",  # clock-like strings (already covered by the pattern above)
        r"[A-Za-z0-9%+\- ]{1,5}",  # very short ASCII fragments
    )
    return any(re.fullmatch(pattern, text) for pattern in noise_patterns)
def center_dialog_texts(row: Dict[str, Any]) -> List[str]:
    """Return the texts of a centred confirmation dialog, or [] if none is detected.

    OCR items are kept when their text survives ``skip_summary_text`` and the
    centre of their four-value bbox lies in the window 0.18–0.82 (x) by
    0.32–0.68 (y). The kept texts, ordered by ``(y1, x1)`` and de-duplicated,
    count as a dialog only when together they contain both "取消" and "确定"
    plus at least one prompt term.

    NOTE(review): the window constants presume bbox coordinates normalised to
    0..1 — confirm against the OCR pipeline.
    """
    positioned: List[Tuple[float, float, str]] = []
    for entry in row.get("ocr_items", []) or []:
        label = safe_text(entry.get("text"))[:80]
        if not label or skip_summary_text(label):
            continue
        box = entry.get("bbox") or []
        if not isinstance(box, list) or len(box) != 4:
            continue
        try:
            left, top, right, bottom = (float(coord) for coord in box)
        except (TypeError, ValueError):
            continue
        center_x = (left + right) / 2.0
        center_y = (top + bottom) / 2.0
        inside_window = 0.18 <= center_x <= 0.82 and 0.32 <= center_y <= 0.68
        if inside_window:
            positioned.append((top, left, label))

    positioned.sort(key=lambda entry: (entry[0], entry[1]))
    # dict.fromkeys keeps the first occurrence of each text in order.
    unique_texts = list(dict.fromkeys(label for _, _, label in positioned))

    combined = " ".join(unique_texts)
    if "取消" not in combined or "确定" not in combined:
        return []
    prompt_terms = ["确认", "是否", "删除", "提示", "商品吗"]
    if not any(term in combined for term in prompt_terms):
        return []
    return unique_texts
def _ocr_confidence(item: Dict[str, Any]) -> float:
    """Return the OCR confidence of *item*, defaulting to 1.0 when unknown.

    Reads "conf" and falls back to "ocr_conf". Only a missing, ``None``,
    empty or unparsable value is treated as unknown; a genuine 0 stays 0 so
    that zero-confidence items can be filtered out.
    """
    raw = item.get("conf")
    if raw is None:
        raw = item.get("ocr_conf")
    if raw is None or raw == "":
        return 1.0
    try:
        return float(raw)
    except (TypeError, ValueError):
        return 1.0


def ordered_ocr_texts(row: Dict[str, Any], max_items: int = 14) -> List[str]:
    """Return up to *max_items* distinct OCR texts for display.

    Texts of a detected centre dialog (see ``center_dialog_texts``) come
    first; the remaining OCR texts follow in ``(y, x)`` order as given by
    ``text_bbox``. Items rejected by ``skip_summary_text`` or with a
    confidence below 0.5 are dropped, and each text is truncated to 80
    characters before comparison.

    The result never holds more than ``max_items`` entries (none when
    ``max_items`` is zero or negative).
    """
    limit = max(max_items, 0)
    texts: List[str] = []
    seen = set()
    for text in center_dialog_texts(row):
        if text not in seen:
            seen.add(text)
            texts.append(text)

    candidates: List[Tuple[float, float, str]] = []
    for item in row.get("ocr_items", []) or []:
        text = safe_text(item.get("text"))[:80]
        if skip_summary_text(text):
            continue
        if _ocr_confidence(item) < 0.5:
            continue
        y_pos, x_pos = text_bbox(item)
        candidates.append((y_pos, x_pos, text))

    for _, _, text in sorted(candidates, key=lambda value: (value[0], value[1])):
        # Check the cap before appending so dialog texts that already fill
        # the quota do not let one extra entry through.
        if len(texts) >= limit:
            break
        if text in seen:
            continue
        seen.add(text)
        texts.append(text)
    return texts[:limit]
def gui_item_index(row: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Map every element identifier in *row* to its item dict.

    UI items are scanned before OCR items, and both the "id" and "ocr_id"
    fields are indexed. The first item seen for an identifier wins.
    """
    ui_entries = list(row.get("ui_items", []) or [])
    ocr_entries = list(row.get("ocr_items", []) or [])
    lookup: Dict[str, Dict[str, Any]] = {}
    for entry in ui_entries + ocr_entries:
        for field in ("id", "ocr_id"):
            identifier = safe_text(entry.get(field))
            if identifier:
                # setdefault keeps an earlier entry for the same identifier.
                lookup.setdefault(identifier, entry)
    return lookup
def is_passive_search_text(text: str) -> bool:
    """Return True when *text* contains any entry of ``PASSIVE_SEARCH_TERMS``."""
    for term in PASSIVE_SEARCH_TERMS:
        if term in text:
            return True
    return False
def clean_function_list(row: Dict[str, Any], funcs: List[Any], max_items: int) -> List[Dict[str, Any]]:
    """Normalise and filter a model-produced list of UI functions.

    Each entry may be a dict with "name" and "evidence_ids", or any other
    value, which is treated as a bare name without evidence. An entry is
    dropped when its name is empty or repeated, when any of its evidence
    items carries text rejected by ``skip_summary_text``, or when it is a
    "搜索" function backed only by passive search labels.

    Returns at most *max_items* dicts of the form
    ``{"name": str, "evidence_ids": [str, ...]}``.
    """
    index = gui_item_index(row)
    cleaned: List[Dict[str, Any]] = []
    seen = set()
    for item in funcs or []:
        if isinstance(item, dict):
            name = safe_text(item.get("name"))
            # An explicit None must behave like a missing key, and a lone
            # string is one id rather than a sequence of characters.
            raw_ids = item.get("evidence_ids") or []
            if isinstance(raw_ids, (str, bytes)):
                raw_ids = [raw_ids]
            evidence_ids = [eid for eid in (safe_text(value) for value in raw_ids) if eid]
        else:
            name = safe_text(item)
            evidence_ids = []
        if not name or name in seen:
            continue
        evidence_texts = [safe_text(index.get(value, {}).get("text")) for value in evidence_ids]
        if any(skip_summary_text(text) for text in evidence_texts if text):
            continue
        if "搜索" in name and evidence_texts and all(is_passive_search_text(text) for text in evidence_texts):
            continue
        cleaned.append({"name": name, "evidence_ids": evidence_ids})
        seen.add(name)
        if len(cleaned) >= max_items:
            break
    return cleaned
def collect_gui_functions(row: Dict[str, Any], max_items: int) -> List[Dict[str, Any]]:
    """Derive UI functions directly from the texts of ``row["ui_items"]``.

    An element contributes "搜索" when its text is "搜索"/"搜" or starts with
    "搜索" and is not a passive search label; otherwise it contributes the
    first entry of ``GUI_FUNCTION_KEYWORDS`` found in its text. Each function
    name is reported once, with the element's id as its only evidence.
    Collection stops once *max_items* functions have been gathered.
    """
    found: List[Dict[str, Any]] = []
    used_names = set()
    for element in row.get("ui_items", []) or []:
        label = safe_text(element.get("text"))
        source_id = safe_text(element.get("id") or element.get("ocr_id"))
        if not (label and source_id) or skip_summary_text(label):
            continue

        looks_like_search = label in {"搜索", "搜"} or label.startswith("搜索")
        if looks_like_search and not is_passive_search_text(label):
            function_name = "搜索"
        else:
            function_name = next((keyword for keyword in GUI_FUNCTION_KEYWORDS if keyword in label), "")

        if not function_name or function_name in used_names:
            continue
        used_names.add(function_name)
        found.append({"name": function_name, "evidence_ids": [source_id]})
        if len(found) >= max_items:
            break
    return found
def clean_evidence_ids(row: Dict[str, Any], evidence_ids: List[Any], max_items: int) -> List[str]:
    """Return distinct, non-noise evidence ids in their original order.

    Ids are normalised with ``safe_text``. Empty or repeated ids are skipped,
    as are ids whose element text is rejected by ``skip_summary_text``. Ids
    unknown to the row (no text found) are kept. Stops after *max_items* ids.
    """
    lookup = gui_item_index(row)
    kept: List[str] = []
    for raw in evidence_ids or []:
        identifier = safe_text(raw)
        # Only accepted ids count as duplicates, so the result list itself
        # serves as the "already seen" record.
        if not identifier or identifier in kept:
            continue
        label = safe_text(lookup.get(identifier, {}).get("text"))
        if label and skip_summary_text(label):
            continue
        kept.append(identifier)
        if len(kept) >= max_items:
            break
    return kept
def infer_page_type(texts: List[str]) -> str:
    """Guess a coarse page category from keywords in the visible texts.

    Search keywords are checked first (refined by recommendation keywords),
    then shopping, recommendation, interaction and account keywords in that
    order. Falls back to the generic "移动应用页面".
    """
    blob = " ".join(texts)

    def mentions(*terms: str) -> bool:
        return any(term in blob for term in terms)

    if mentions("历史搜索", "搜索"):
        if mentions("热榜", "热搜", "猜你喜欢"):
            return "搜索和推荐内容页面"
        return "搜索相关页面"

    # Ordered: the first category with a matching keyword wins.
    categories = (
        (("购物车", "下单", "购买", "商品", "价格", "优惠"), "购物页面"),
        (("热榜", "榜单", "猜你喜欢", "推荐"), "推荐内容页面"),
        (("评论", "点赞", "收藏", "分享"), "内容互动页面"),
        (("我的", "订单", "设置", "账号", "会员"), "个人或设置页面"),
    )
    for terms, label in categories:
        if mentions(*terms):
            return label
    return "移动应用页面"
def quoted_join(values: List[str], limit: int) -> str:
    """Wrap the first *limit* values in “ ” quotes and join them with "、"."""
    quoted = [f"“{value}”" for value in values[:limit]]
    return "、".join(quoted)
def build_grounded_summary(row: Dict[str, Any]) -> str:
    """Compose a one-paragraph summary purely from the row's OCR texts.

    Three shapes are produced:
    * no usable text: a sentence saying nothing legible was recognised;
    * a centre confirmation dialog: the dialog prompt, its buttons, and,
      only when the OCR texts mention "购物车" or "商品", a note that the
      background is a shopping-cart list;
    * otherwise: the first five texts, plus up to five further ones.

    The subject of the sentence is the non-generic app name (if any)
    followed by the page type from ``infer_page_type``.
    """
    texts = ordered_ocr_texts(row, max_items=14)
    dialog_texts = center_dialog_texts(row)
    app = display_app_name(str(row.get("app") or ""))
    page_type = infer_page_type(texts)
    subject = f"{app}{page_type}" if app else page_type
    if not texts:
        return f"这张截图展示的是{subject},但当前没有识别到足够清晰的屏幕文字。"
    if dialog_texts:
        prompt = next(
            (text for text in dialog_texts if any(term in text for term in ["确认", "是否", "删除", "商品吗"])),
            dialog_texts[0],
        )
        actions = [text for text in dialog_texts if text in {"取消", "确定", "删除", "关闭"}]
        action_text = f",提供{quoted_join(actions, 4)}等按钮" if actions else ""
        # Only claim a cart background when the recognised text supports it;
        # a confirmation dialog can appear on any kind of page.
        joined = " ".join(texts)
        background = ";背景是购物车商品列表" if ("购物车" in joined or "商品" in joined) else ""
        # The closing ” after the prompt was missing, leaving the quote unbalanced.
        return f"这张截图显示{subject}上弹出确认对话框,提示“{prompt}”{action_text}{background}。"
    primary = quoted_join(texts, 5)
    summary = f"这张截图主要是{subject},屏幕上能看到{primary}等文字。"
    extra = texts[5:10]
    if extra:
        summary += f" 下方还出现{quoted_join(extra, 5)}等条目。"
    return summary
def merge_items(primary: List[Any], secondary: List[Any], max_items: int) -> List[Any]:
    """Concatenate two lists, dropping duplicates, *primary* first.

    Dicts are compared by their key-sorted JSON serialisation; any other
    value by its ``safe_text`` form. Items whose comparison key is empty are
    skipped. Merging stops as soon as *max_items* items have been collected.
    """
    combined: List[Any] = []
    fingerprints = set()
    for entry in [*(primary or []), *(secondary or [])]:
        if isinstance(entry, dict):
            fingerprint = json.dumps(entry, ensure_ascii=False, sort_keys=True)
        else:
            fingerprint = safe_text(entry)
        if fingerprint and fingerprint not in fingerprints:
            fingerprints.add(fingerprint)
            combined.append(entry)
            if len(combined) >= max_items:
                break
    return combined
def gui_ground_prediction(row: Dict[str, Any], pred_obj: Optional[Dict[str, Any]], args: argparse.Namespace) -> Dict[str, Any]:
    """Reconcile a model prediction with the OCR/UI evidence of *row*.

    Args:
        row: Screenshot row; "ocr_items" and "ui_items" are read here.
        pred_obj: Parsed model output, or anything non-dict to start empty.
            Both the Chinese keys and their English aliases are accepted.
        args: Runtime options. Reads ``max_visible_text``,
            ``gui_summary_mode``, ``merge_ocr_functions`` and, if present,
            ``structured_max_functions``.

    Returns:
        A dict with exactly the keys "画面总结", "可见文字", "互动数据",
        "功能入口" and "关键证据".
    """
    pred = pred_obj if isinstance(pred_obj, dict) else {}
    # Template prediction is only used to fill fields the model left empty.
    template = template_prediction(row, max_visible=args.max_visible_text)
    grounded = {
        "画面总结": safe_text(pred.get("画面总结") or pred.get("summary_zh") or pred.get("summary")),
        "可见文字": pred.get("可见文字") or pred.get("visible_text") or [],
        "互动数据": pred.get("互动数据") or pred.get("interaction_data") or [],
        "功能入口": pred.get("功能入口") or pred.get("ui_functions") or [],
        "关键证据": pred.get("关键证据") or pred.get("key_ui_clues") or pred.get("evidence") or [],
    }
    model_summary = safe_text(grounded.get("画面总结"))
    # "ocr" mode always rewrites the summary from OCR; "auto" only does so
    # when the model produced none. Without OCR items nothing is rewritten.
    if row.get("ocr_items") and (args.gui_summary_mode == "ocr" or (args.gui_summary_mode == "auto" and not model_summary)):
        grounded["画面总结"] = build_grounded_summary(row)
    # Whenever OCR items exist, visible text comes from OCR rather than the model.
    if row.get("ocr_items"):
        grounded["可见文字"] = ordered_ocr_texts(row, max_items=args.max_visible_text)
    # Falls back to 12 when the option is missing, None or 0.
    max_functions = int(getattr(args, "structured_max_functions", None) or 12)
    gui_functions = collect_gui_functions(row, max_functions)
    model_functions = clean_function_list(row, grounded.get("功能入口", []), max_functions)
    if args.merge_ocr_functions:
        # Model functions take precedence; keyword-derived ones fill the rest.
        grounded["功能入口"] = merge_items(
            model_functions,
            gui_functions,
            max_functions,
        )
    elif not grounded.get("功能入口"):
        # Checks the raw model list, not the cleaned one.
        grounded["功能入口"] = gui_functions
    else:
        grounded["功能入口"] = model_functions
    if not grounded.get("互动数据"):
        grounded["互动数据"] = template.get("互动数据", [])
    # Evidence attached to the final functions is listed ahead of the
    # model's own key-evidence entries.
    function_evidence = []
    for function in grounded.get("功能入口", []) or []:
        if isinstance(function, dict):
            function_evidence.extend(function.get("evidence_ids", []) or [])
    evidence_candidates = function_evidence + list(grounded.get("关键证据", []) or [])
    if not evidence_candidates:
        evidence_candidates = list(template.get("关键证据", []) or [])
    # Key evidence is capped at 8 ids.
    grounded["关键证据"] = clean_evidence_ids(row, evidence_candidates, 8)
    # Last resort: borrow the template summary.
    if not grounded.get("画面总结"):
        grounded["画面总结"] = template.get("画面总结", "")
    return grounded
def parse_multipart(headers: Any, body: bytes) -> Dict[str, Any]:
    """Parse a ``multipart/form-data`` request body into a field dict.

    Args:
        headers: Mapping-like request headers; only "Content-Type" is read.
        body: The complete raw request body.

    Returns:
        ``{field_name: value}``. Plain fields map to their text (decoded as
        UTF-8 with replacement, trailing CR/LF removed). File fields map to
        ``{"filename": str, "content": bytes}``, with the content returned
        byte-for-byte as sent.

    Raises:
        ValueError: If the Content-Type header carries no boundary.
    """
    content_type = headers.get("Content-Type", "")
    match = re.search(r'boundary=(?P<q>"?)([^";]+)(?P=q)', content_type)
    if not match:
        raise ValueError("Missing multipart boundary.")
    # A delimiter is CRLF + "--" + boundary. Splitting on the whole sequence
    # consumes exactly the one CRLF that belongs to the delimiter, so CR/LF
    # bytes that are part of an uploaded file are left intact. The body is
    # prefixed with CRLF so the very first delimiter matches as well.
    delimiter = b"\r\n--" + match.group(2).encode("utf-8")
    chunks = (b"\r\n" + body).split(delimiter)
    fields: Dict[str, Any] = {}
    # chunks[0] is the preamble before the first delimiter; skip it.
    for chunk in chunks[1:]:
        if chunk.startswith(b"--"):
            # Closing delimiter ("--boundary--"); anything after is epilogue.
            break
        if chunk.startswith(b"\r\n"):
            chunk = chunk[2:]
        header_blob, _, value = chunk.partition(b"\r\n\r\n")
        if not header_blob:
            continue
        header_text = header_blob.decode("utf-8", errors="replace")
        disposition = ""
        for line in header_text.split("\r\n"):
            if line.lower().startswith("content-disposition:"):
                disposition = line.split(":", 1)[1]
                break
        # \b keeps this from matching the tail of filename="..." when the
        # filename parameter is listed before name.
        name_match = re.search(r'\bname="([^"]+)"', disposition)
        if not name_match:
            continue
        field_name = name_match.group(1)
        filename_match = re.search(r'filename="([^"]*)"', disposition)
        if filename_match:
            fields[field_name] = {
                "filename": filename_match.group(1),
                "content": value,
            }
        else:
            fields[field_name] = value.rstrip(b"\r\n").decode("utf-8", errors="replace")
    return fields
class RichGuiPredictor:
    """Holds the checkpoint, OCR settings and upload handling behind the GUI.

    One instance is shared by all request handlers (see RichGuiServer).
    NOTE(review): the nesting inside ``summarize`` was reconstructed from a
    source listing without indentation — confirm the scope of the lock and
    of the parse-failure fallback against the original file.
    """

    def __init__(self, args: argparse.Namespace):
        """Store the runtime options and choose the torch device; no model is loaded here."""
        self.args = args
        # An explicit --device wins; otherwise CUDA is used when available.
        self.device = torch.device(args.device if args.device else ("cuda" if torch.cuda.is_available() else "cpu"))
        # Filled in by load_model().
        self.model = None
        self.tokenizer = None
        self.image_processor = None
        self.ckpt_args = None
        self.collator = None
        # Taken by summarize() around model loading and inference.
        self.lock = threading.Lock()

    def load_model(self) -> None:
        """Load the checkpoint once; a no-op in template-only mode or when already loaded.

        Raises:
            FileNotFoundError: If no checkpoint path is configured or the
                path does not exist.
        """
        if self.args.template_only or self.model is not None:
            return
        if not self.args.checkpoint:
            raise FileNotFoundError("Checkpoint not set. Train a natural multimodal checkpoint first or pass --checkpoint.")
        checkpoint = Path(self.args.checkpoint)
        if not checkpoint.exists():
            raise FileNotFoundError(f"Checkpoint not found: {checkpoint}")
        self.model, self.tokenizer, self.image_processor, self.ckpt_args = load_rich_checkpoint(str(checkpoint), self.device)
        # Command-line options override what was stored with the checkpoint.
        self.apply_runtime_args(self.ckpt_args)
        self.collator = RichCollator(self.tokenizer, self.image_processor, self.ckpt_args)

    def apply_runtime_args(self, ckpt_args: argparse.Namespace) -> None:
        """Copy runtime overrides from ``self.args`` onto *ckpt_args* in place.

        Options that are absent, ``None`` or "" on ``self.args`` leave the
        checkpoint value untouched. ``num_workers`` is always forced to 0.
        """
        optional_names = [
            "generation_no_repeat_ngram_size",
            "generation_repetition_penalty",
            "generation_block_extra_ids",
            "generation_block_title_prefix",
            "generation_force_json_start",
            "context_summary_repair",
            "canonicalize_targets",
            "drop_bare_search_functions",
            "structured_function_threshold",
            "structured_search_threshold",
            "structured_max_functions",
            "structured_strict_search_candidates",
            "structured_evidence_threshold",
            "structured_max_evidence",
            "structured_evidence_fallback_top1",
        ]
        for name in optional_names:
            value = getattr(self.args, name, None)
            # False and 0 are real overrides; only None / "" mean "not given".
            if value is not None and value != "":
                setattr(ckpt_args, name, value)
        # The mode options are only applied when truthy.
        for name in ["structured_function_mode", "structured_evidence_mode"]:
            value = getattr(self.args, name, "")
            if value:
                setattr(ckpt_args, name, value)
        ckpt_args.num_workers = 0

    def save_image(self, upload: Dict[str, Any]) -> Path:
        """Validate an uploaded image and store it as an RGB PNG.

        Args:
            upload: Dict with "content" (bytes) and optionally "filename".

        Returns:
            Path of the written file inside ``args.upload_dir``.

        Raises:
            ValueError: If the content is empty or the file extension is not
                in ``ALLOWED_IMAGE_EXTS``. Errors from ``Image.open`` on
                undecodable data propagate unchanged.
        """
        raw = upload.get("content") or b""
        if not raw:
            raise ValueError("Uploaded image is empty.")
        filename = safe_upload_name(upload.get("filename", "upload.png"))
        suffix = Path(filename).suffix.lower()
        # A name without any extension is accepted; decoding below is the real check.
        if suffix and suffix not in ALLOWED_IMAGE_EXTS:
            raise ValueError(f"Unsupported image type: {suffix}")
        upload_dir = Path(self.args.upload_dir)
        upload_dir.mkdir(parents=True, exist_ok=True)
        with Image.open(BytesIO(raw)) as img:
            image = img.convert("RGB")
        # Name: <local timestamp>_<millisecond part>_<sanitised stem>.png
        stamp = time.strftime("%Y%m%d_%H%M%S")
        upload_path = upload_dir / f"{stamp}_{int(time.time() * 1000) % 1000:03d}_{Path(filename).stem}.png"
        image.save(upload_path)
        return upload_path

    def build_row(self, image_path: Path, app: str, focus: str, ocr_engine: str) -> Tuple[Dict[str, Any], Optional[str]]:
        """Assemble the inference row for one screenshot, running OCR if requested.

        Returns:
            ``(row, ocr_error)`` where *ocr_error* is the message of an OCR
            failure, or ``None`` when OCR succeeded or was disabled.
        """
        # The focus text only reaches the model when --use_focus_in_model is on;
        # it is always kept under "display_focus".
        model_instruction = safe_text(focus) if self.args.use_focus_in_model else ""
        row: Dict[str, Any] = {
            "screen_id": image_path.stem,
            "image_path": str(image_path),
            "app": safe_text(app) or "移动应用",
            "instruction": model_instruction,
            "display_focus": safe_text(focus),
            # Empty target: there is no ground truth at inference time.
            "target": {
                "summary_zh": "",
                "visible_text": [],
                "interaction_data": [],
                "ui_functions": [],
                "key_ui_clues": [],
            },
            "ocr_items": [],
            "ui_items": [],
            "weak_evidence_ids": [],
        }
        ocr_error = None
        if ocr_engine != "none":
            try:
                cache_args = argparse.Namespace(ocr_engine=ocr_engine, ocr_lang=self.args.ocr_lang)
                image_sha = sha256_file(image_path)
                ocr_items = load_ocr_items(image_path, image_sha, Path(self.args.ocr_cache_dir), cache_args)
                ocr_items = filter_ocr_items(ocr_items, self.args.min_ocr_conf)[: self.args.max_ocr_items]
                row["ocr_items"] = ocr_items
                row["ui_items"] = build_ocr_ui_items(row, ocr_items, self.args.max_ui_items)
                # Ids of the first eight UI items.
                row["weak_evidence_ids"] = [item.get("id") for item in row["ui_items"][:8] if item.get("id")]
            except Exception as exc:  # OCR should not prevent image-only inference.
                ocr_error = str(exc)
        return row, ocr_error

    @torch.no_grad()
    def summarize(self, image_path: Path, app: str, focus: str, ocr_engine: str) -> Dict[str, Any]:
        """Produce the JSON-serialisable result for one saved screenshot.

        In template-only mode the prediction comes from
        ``template_prediction``; otherwise the checkpoint generates text that
        is parsed according to the checkpoint's ``target_schema`` and then
        post-processed. In both cases the prediction passes through
        ``gui_ground_prediction`` and ``row_result``, after which GUI
        metadata (counts, image URL, modes, elapsed time) is attached.
        """
        start = time.perf_counter()
        row, ocr_error = self.build_row(image_path, app, focus, ocr_engine)
        if self.args.template_only:
            pred = template_prediction(row, max_visible=self.args.max_visible_text)
            pred = gui_ground_prediction(row, pred, self.args)
            result = row_result(
                row=row,
                raw_text=json.dumps(pred, ensure_ascii=False),
                pred_obj=pred,
                json_valid=True,
                evidence_scores=None,
                allow_template_fallback=False,
                source="template",
            )
        else:
            # NOTE(review): lock scope reconstructed — assumed to cover model
            # loading and every forward pass; confirm.
            with self.lock:
                self.load_model()
                batch = self.collator([row])
                batch = move_batch(batch, self.device)
                # Single-row batch, hence the [0] on every model output below.
                text = self.model.generate_text(
                    batch,
                    self.tokenizer,
                    num_beams=self.args.num_beams,
                    max_new_tokens=self.args.max_new_tokens,
                )[0]
                _, _, elem_tokens, elem_key_padding = self.model.build_memory(batch)
                # Inverted padding mask — presumably True marks real elements; confirm.
                masks = (~elem_key_padding)[0].detach().cpu().numpy()
                # Per-element sigmoid scores from the three heads.
                evidence_head = torch.sigmoid(self.model.evidence_head(elem_tokens).squeeze(-1))[0].detach().cpu()
                function_head = torch.sigmoid(self.model.ui_function_head(elem_tokens).squeeze(-1))[0].detach().cpu()
                search_head = torch.sigmoid(self.model.search_function_head(elem_tokens).squeeze(-1))[0].detach().cpu()
            # The checkpoint's target schema decides how the text is parsed.
            output_is_summary = target_schema_is_summary(getattr(self.ckpt_args, "target_schema", "zh"))
            output_is_natural = target_schema_is_natural_text(getattr(self.ckpt_args, "target_schema", "zh"))
            if output_is_summary:
                pred_obj = prediction_from_summary(row, text)
                ok = True
            elif output_is_natural:
                pred_obj = natural_prediction_from_text(text)
                ok = bool(pred_obj.get("画面总结"))
            else:
                pred_obj, ok = safe_json_loads(text)
            # NOTE(review): nesting of this fallback (and of the repair step
            # inside it) reconstructed — confirm.
            if pred_obj is None:
                # Parsing failed: fall back to the top-scoring evidence ids.
                elements = row.get("ui_items", []) or []
                ranked = torch.argsort(evidence_head, descending=True).tolist()
                top_ids = []
                for idx in ranked:
                    # Skip indices beyond the real elements or masked out.
                    if idx < len(elements) and idx < len(masks) and masks[idx]:
                        evidence_id = safe_text(elements[idx].get("id") or elements[idx].get("ocr_id"))
                        if evidence_id:
                            top_ids.append(evidence_id)
                    if len(top_ids) >= self.args.top_k_clues:
                        break
                pred_obj = {"关键证据": top_ids}
                if self.args.context_summary_repair:
                    pred_obj, _ = repair_prediction_with_context(row, pred_obj)
                    ok = True
            # Head scores adjust functions and evidence before GUI grounding.
            pred_obj = apply_structured_function_predictions(row, pred_obj, function_head, search_head, self.ckpt_args)
            pred_obj = apply_structured_evidence_predictions(row, pred_obj, evidence_head, self.ckpt_args)
            pred_obj = gui_ground_prediction(row, pred_obj, self.args)
            # Evidence score per element id, for unmasked elements only.
            evidence_scores: Dict[str, float] = {}
            for idx, elem in enumerate(row.get("ui_items", []) or []):
                if idx < len(masks) and masks[idx]:
                    evidence_id = safe_text(elem.get("id") or elem.get("ocr_id"))
                    if evidence_id:
                        evidence_scores[evidence_id] = float(evidence_head[idx])
            result = row_result(
                row=row,
                raw_text=text,
                pred_obj=pred_obj,
                json_valid=ok,
                evidence_scores=evidence_scores,
                allow_template_fallback=self.args.allow_template_fallback,
                source="model",
            )
        # Metadata consumed by the web page.
        result["ocr_count"] = len(row.get("ocr_items", []) or [])
        result["ui_item_count"] = len(row.get("ui_items", []) or [])
        result["ocr_error"] = ocr_error
        result["image_url"] = f"/uploads/{image_path.name}"
        result["focus"] = safe_text(focus)
        if self.args.template_only:
            result["model_output_mode"] = "template"
        elif self.ckpt_args is not None and target_schema_is_summary(getattr(self.ckpt_args, "target_schema", "zh")):
            result["model_output_mode"] = "summary"
        elif self.ckpt_args is not None and target_schema_is_natural_text(getattr(self.ckpt_args, "target_schema", "zh")):
            result["model_output_mode"] = "natural_text"
        else:
            result["model_output_mode"] = "json"
        # Without OCR items the summary can only have come from the model.
        result["display_mode"] = self.args.gui_summary_mode if row.get("ocr_items") else "model"
        result["elapsed_sec"] = round(time.perf_counter() - start, 3)
        return result
class RichGuiHandler(BaseHTTPRequestHandler):
    """HTTP handler for the GUI page, the health check, uploads and the summarize API.

    Expects ``self.server`` to expose a ``predictor`` attribute
    (see RichGuiServer).
    """

    server_version = "RichGui/1.0"

    def send_payload(self, status: int, body: bytes, content_type: str) -> None:
        """Send a complete response with caching disabled."""
        self.send_response(status)
        self.send_header("Content-Type", content_type)
        self.send_header("Cache-Control", "no-store, no-cache, must-revalidate, max-age=0")
        self.send_header("Pragma", "no-cache")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self) -> None:  # noqa: N802
        """Serve "/", "/api/health" and "/uploads/<name>"; anything else is a 404."""
        parsed = urlparse(self.path)
        if parsed.path in {"/", "/index.html"}:
            self.send_payload(HTTPStatus.OK, HTML_PAGE.encode("utf-8"), "text/html; charset=utf-8")
            return
        if parsed.path == "/api/health":
            args = self.server.predictor.args
            status, body, content_type = json_bytes(
                {
                    "ok": True,
                    "template_only": args.template_only,
                    "checkpoint": None if args.template_only else args.checkpoint,
                    "structured_function_mode": args.structured_function_mode,
                    "structured_function_threshold": args.structured_function_threshold,
                    "structured_search_threshold": args.structured_search_threshold,
                    "structured_evidence_mode": args.structured_evidence_mode,
                    "structured_evidence_threshold": args.structured_evidence_threshold,
                }
            )
            self.send_payload(status, body, content_type)
            return
        if parsed.path.startswith("/uploads/"):
            # safe_upload_name drops directory components, so the lookup
            # cannot leave the upload directory.
            filename = safe_upload_name(unquote(parsed.path.removeprefix("/uploads/")))
            path = Path(self.server.predictor.args.upload_dir) / filename
            if path.exists() and path.is_file():
                content_type = mimetypes.guess_type(path.name)[0] or "application/octet-stream"
                self.send_payload(HTTPStatus.OK, path.read_bytes(), content_type)
            else:
                status, body, content_type = json_bytes({"error": "Not found"}, HTTPStatus.NOT_FOUND)
                self.send_payload(status, body, content_type)
            return
        status, body, content_type = json_bytes({"error": "Not found"}, HTTPStatus.NOT_FOUND)
        self.send_payload(status, body, content_type)

    def do_POST(self) -> None:  # noqa: N802
        """Handle ``POST /api/summarize`` with a multipart image upload.

        Every failure — malformed request, oversized upload, OCR/model
        error — is reported as a 400 JSON body ``{"error": message}``.
        """
        parsed = urlparse(self.path)
        if parsed.path != "/api/summarize":
            status, body, content_type = json_bytes({"error": "Not found"}, HTTPStatus.NOT_FOUND)
            self.send_payload(status, body, content_type)
            return
        try:
            predictor = self.server.predictor
            length = int(self.headers.get("Content-Length", "0"))
            # rfile.read() with a negative size reads until EOF, which would
            # block this worker thread for as long as the client stays connected.
            if length < 0:
                raise ValueError("Invalid Content-Length.")
            if length > predictor.args.max_upload_mb * 1024 * 1024:
                raise ValueError(f"Image is larger than {predictor.args.max_upload_mb} MB.")
            fields = parse_multipart(self.headers, self.rfile.read(length))
            upload = fields.get("image")
            if not isinstance(upload, dict):
                raise ValueError("Missing image field.")
            image_path = predictor.save_image(upload)
            app = safe_text(fields.get("app")) or "移动应用"
            # "instruction" is accepted as an alias of "focus".
            focus = safe_text(fields.get("focus")) or safe_text(fields.get("instruction"))
            ocr_engine = safe_text(fields.get("ocr_engine")) or predictor.args.ocr_engine
            # Unknown engine names fall back to the server default.
            if ocr_engine not in {"none", "paddleocr"}:
                ocr_engine = predictor.args.ocr_engine
            result = predictor.summarize(image_path, app, focus, ocr_engine)
            status, body, content_type = json_bytes(result, HTTPStatus.OK)
        except Exception as exc:
            status, body, content_type = json_bytes({"error": str(exc)}, HTTPStatus.BAD_REQUEST)
        self.send_payload(status, body, content_type)

    def log_message(self, fmt: str, *args: Any) -> None:
        """Suppress the default access log when the server runs with --quiet."""
        if not self.server.predictor.args.quiet:
            super().log_message(fmt, *args)
class RichGuiServer(ThreadingHTTPServer):
    """Threading HTTP server that carries the shared predictor for its handlers."""

    def __init__(self, server_address: Tuple[str, int], handler_class: Any, predictor: RichGuiPredictor):
        """Set up the underlying server, then attach *predictor* as ``self.predictor``."""
        ThreadingHTTPServer.__init__(self, server_address, handler_class)
        # Handlers reach the predictor through ``self.server.predictor``.
        self.predictor = predictor
def parse_args() -> argparse.Namespace:
    """Parse the command line and resolve which checkpoint to load.

    Returns:
        The parsed namespace.  When neither ``--template_only`` nor
        ``--checkpoint`` is given, ``checkpoint`` is filled in with the path
        returned by ``find_latest_rich_checkpoint``.

    Raises:
        FileNotFoundError: If a checkpoint is required but none is found.
    """
    parser = argparse.ArgumentParser(
        description="Start a local GUI for rich screenshot summarization.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    add = parser.add_argument

    # Model selection and server/network options.
    add("--checkpoint", default=DEFAULT_CHECKPOINT)
    add("--template_only", action="store_true")
    add("--host", default="127.0.0.1")
    add("--port", type=int, default=7860)
    add("--open_browser", type=str_to_bool, default=True)
    add("--quiet", action="store_true")

    # Upload storage and OCR options.
    add("--upload_dir", default="outputs/rich_gui/uploads")
    add("--ocr_cache_dir", default="data/rich_cmgui/cache/gui_ocr")
    add("--ocr_engine", choices=["none", "paddleocr"], default="paddleocr")
    add("--ocr_lang", default="ch")
    add("--min_ocr_conf", type=float, default=0.5)
    add("--max_ocr_items", type=int, default=120)
    add("--max_ui_items", type=int, default=48)
    add("--max_upload_mb", type=int, default=20)

    # Inference and summary options.
    add("--device", default="")
    add("--num_beams", type=int, default=1)
    add("--max_new_tokens", type=int, default=256)
    add("--top_k_clues", type=int, default=5)
    add("--max_visible_text", type=int, default=12)
    add("--gui_summary_mode", choices=["model", "ocr", "auto"], default="model")
    add("--merge_ocr_functions", type=str_to_bool, default=True)
    add("--use_focus_in_model", type=str_to_bool, default=False)

    # Generation options.
    add("--generation_no_repeat_ngram_size", type=int, default=3)
    add("--generation_repetition_penalty", type=float, default=1.1)
    add("--generation_block_extra_ids", type=str_to_bool, default=True)
    add("--generation_block_title_prefix", type=str_to_bool, default=True)
    add("--generation_force_json_start", type=str_to_bool, default=False)

    # Tri-state switches: the default stays None when the flag is not passed.
    for flag in ("--context_summary_repair", "--canonicalize_targets", "--drop_bare_search_functions"):
        add(flag, type=str_to_bool, default=None)

    # Structured function / evidence options.
    add("--structured_function_mode", choices=["", "decoder", "heads"], default="heads")
    add("--structured_function_threshold", type=float, default=DEFAULT_FUNCTION_THRESHOLD)
    add("--structured_search_threshold", type=float, default=DEFAULT_SEARCH_THRESHOLD)
    add("--structured_max_functions", type=int, default=DEFAULT_MAX_STRUCTURED_ITEMS)
    add("--structured_strict_search_candidates", type=str_to_bool, default=None)
    add("--structured_evidence_mode", choices=["", "decoder", "heads"], default="heads")
    add("--structured_evidence_threshold", type=float, default=DEFAULT_EVIDENCE_THRESHOLD)
    add("--structured_max_evidence", type=int, default=DEFAULT_MAX_STRUCTURED_ITEMS)
    add("--structured_evidence_fallback_top1", type=str_to_bool, default=False)
    add("--allow_template_fallback", type=str_to_bool, default=False)

    args = parser.parse_args()
    if args.template_only or args.checkpoint:
        # Nothing to resolve: no model is needed, or a checkpoint was given.
        return args
    latest = find_latest_rich_checkpoint()
    if latest is None:
        raise FileNotFoundError("No stage3/stage4 rich checkpoint found. Train first or pass --checkpoint.")
    args.checkpoint = str(latest)
    return args
def main() -> None:
    """Start the local GUI server and serve requests until interrupted.

    Builds the predictor from the parsed arguments, binds the first free port
    among up to 20 consecutive ports starting at ``--port``, prints the
    startup configuration as one JSON line, optionally opens the browser, and
    then serves requests.  The listening socket is closed on exit, including
    when the server is stopped with Ctrl+C.

    Raises:
        RuntimeError: If none of the candidate ports could be bound.
    """
    args = parse_args()
    predictor = RichGuiPredictor(args)
    last_error: Optional[OSError] = None
    server: Optional[RichGuiServer] = None
    # Stop at 65535: binding a larger port raises OverflowError, which is not
    # an OSError and would escape the retry loop.
    for port in range(args.port, min(args.port + 20, 65536)):
        try:
            server = RichGuiServer((args.host, port), RichGuiHandler, predictor)
            break
        except OSError as exc:
            last_error = exc
    if server is None:
        raise RuntimeError(
            f"Could not bind a local port starting at {args.port}: {last_error}"
        ) from last_error
    actual_port = server.server_address[1]
    url = f"http://{args.host}:{actual_port}/"
    print(
        json.dumps(
            {
                "url": url,
                "template_only": args.template_only,
                "checkpoint": None if args.template_only else args.checkpoint,
                "num_beams": args.num_beams,
                "max_new_tokens": args.max_new_tokens,
                "structured_function_mode": args.structured_function_mode,
                "structured_function_threshold": args.structured_function_threshold,
                "structured_search_threshold": args.structured_search_threshold,
                "structured_evidence_mode": args.structured_evidence_mode,
                "structured_evidence_threshold": args.structured_evidence_threshold,
            },
            ensure_ascii=False,
        )
    )
    if args.open_browser:
        # Delay slightly so the server is already accepting connections.
        threading.Timer(0.5, lambda: webbrowser.open(url)).start()
    try:
        server.serve_forever()
    except KeyboardInterrupt:
        # Ctrl+C is the expected way to stop a local server; exit without a
        # traceback.  The notice goes to stderr so stdout keeps only the
        # startup JSON line.
        print("Interrupted, shutting down.", file=sys.stderr)
    finally:
        server.server_close()
# Start the GUI server only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()