import gradio as gr
import torch
import re
import unicodedata
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForMaskedLM
from fastapi import FastAPI, Request
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware

# ===============================
# 1. Model loading
# ===============================
print("正在加载模型...")
model_name = "ethanyt/guwenbert-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Inference only: switch the model to evaluation mode.
model.eval()
# Input-embedding weights, used later to compare characters by cosine similarity.
embedding_matrix = model.get_input_embeddings().weight
print(f"模型加载完成,使用设备: {device}")

# ===============================
# 2. Function-word (stop character) table
# ===============================
stop_chars = {
    "之", "其", "而", "于", "以", "也", "兮",
    "乎", "者", "矣", "焉", "耳", "哉",
}


# ===============================
# 3. Decide whether a character is eligible for analysis
# ===============================
def is_valid_char(ch):
    """Return True when ``ch`` is a single character worth analysing.

    A candidate is rejected when it is not a one-character string, is
    whitespace, belongs to a Unicode punctuation category ("P*"), maps to
    the tokenizer's unknown-token id, or is listed in ``stop_chars``.
    """
    is_single_char = isinstance(ch, str) and len(ch) == 1
    if not is_single_char:
        return False

    # The checks short-circuit in the same order as a chain of guards would.
    rejected = (
        ch.strip() == ""
        or unicodedata.category(ch).startswith("P")
        or tokenizer.convert_tokens_to_ids(ch) == tokenizer.unk_token_id
        or ch in stop_chars
    )
    return not rejected


# ===============================
# 4.
# Single-character analysis
# ===============================
def analyze_char(sentence, idx, top_k=5):
    """Mask one character of ``sentence`` and score it with the masked LM.

    The character at position ``idx`` is replaced by the tokenizer's mask
    token, the model predicts a distribution for that slot, and candidates
    that fail ``is_valid_char`` (punctuation, function words, unknown or
    multi-character tokens) are skipped when ranking.

    Args:
        sentence: Text to analyse; it is split into characters with ``list``.
        idx: Zero-based position of the character to mask.
        top_k: Maximum number of filtered candidates to return.

    Returns:
        On success, a dict with the keys ``sentence``, ``original_char``,
        ``position``, ``rank`` (rank of the original character among the
        filtered candidates, or 9999 if it was never encountered),
        ``prob`` (model probability of the original character),
        ``top_k`` (list of ``(token, probability)`` pairs) and ``cos_sim``
        (cosine similarity between the input embeddings of the original
        character and the top filtered candidate).
        On failure, a dict with a single ``error`` key.
    """
    chars = list(sentence)
    if idx < 0 or idx >= len(chars):
        return {"error": "索引超出范围"}

    original_char = chars[idx]
    if not is_valid_char(original_char):
        return {"error": f"字符 '{original_char}' 不符合分析条件"}

    chars[idx] = tokenizer.mask_token
    masked_sentence = "".join(chars)

    inputs = tokenizer(masked_sentence, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    mask_token_index = (
        inputs["input_ids"] == tokenizer.mask_token_id
    ).nonzero(as_tuple=True)[1]
    if len(mask_token_index) == 0:
        return {"error": "未找到MASK标记"}

    original_id = tokenizer.convert_tokens_to_ids(original_char)

    with torch.no_grad():
        logits = model(**inputs).logits
        mask_logits = logits[0, mask_token_index[0], :]
        probs = torch.softmax(mask_logits, dim=-1)
        sorted_probs, sorted_indices = torch.sort(probs, descending=True)
        original_prob = probs[original_id].item()

    # Move the ranking to Python lists in one transfer each; calling .item()
    # on every element inside the loop would synchronise with the device
    # once per vocabulary entry.
    ranked_probs = sorted_probs.tolist()
    ranked_ids = sorted_indices.tolist()

    filtered_tokens = []
    filtered_probs = []
    filtered_rank = 0
    original_rank = None

    for prob, token_id in zip(ranked_probs, ranked_ids):
        token = tokenizer.convert_ids_to_tokens(token_id)
        if not is_valid_char(token):
            continue

        filtered_rank += 1
        if token == original_char and original_rank is None:
            original_rank = filtered_rank

        if len(filtered_tokens) < top_k:
            filtered_tokens.append(token)
            filtered_probs.append(prob)

        # Stop as soon as the top-k list is full and the original is ranked.
        if len(filtered_tokens) >= top_k and original_rank is not None:
            break

    if original_rank is None:
        # Sentinel rank reported when the original character was never seen.
        original_rank = 9999

    # With no filtered candidate, compare the original with itself.
    if filtered_tokens:
        top1_id = tokenizer.convert_tokens_to_ids(filtered_tokens[0])
    else:
        top1_id = original_id

    # The embedding matrix is a trainable parameter; no_grad avoids building
    # an autograd graph for a value that is only read out as a float.
    with torch.no_grad():
        cos_sim = F.cosine_similarity(
            embedding_matrix[original_id], embedding_matrix[top1_id], dim=0
        ).item()

    return {
        "sentence": sentence,
        "original_char": original_char,
        "position": idx,
        "rank": original_rank,
        "prob": original_prob,
        "top_k": list(zip(filtered_tokens, filtered_probs)),
        "cos_sim": cos_sim,
    }


# ===============================
# 5.
# API interface
# ===============================
def api_predict(sentence, mask_index, top_k=5):
    """Coerce the numeric arguments and run ``analyze_char``.

    This is the shared entry point for both the HTTP endpoint and the Gradio
    button, so any exception is converted into an ``{"error": ...}`` dict
    instead of propagating to the caller.

    Args:
        sentence: Text to analyse.
        mask_index: Zero-based character position; anything ``int()`` accepts.
        top_k: Number of candidates to return; anything ``int()`` accepts.

    Returns:
        The dict produced by ``analyze_char``, or ``{"error": message}``.
    """
    try:
        mask_index = int(mask_index)
        top_k = int(top_k)
        return analyze_char(sentence, mask_index, top_k)
    except Exception as e:
        # Boundary handler: report the failure to the client/UI as data.
        return {"error": str(e)}


# ===============================
# 6. Create the FastAPI app
# ===============================
app = FastAPI()

# Add CORS middleware.
# Credentials are disabled because wildcard origins must not be combined with
# allow_credentials=True; this service does not read cookies or auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


# Custom API endpoint.
@app.post("/api/predict")
async def predict_endpoint(request: Request):
    """Handle ``POST /api/predict`` with a JSON object body.

    The body is read with ``sentence``, ``mask_index`` and optional ``top_k``
    (default 5) keys. A body that is not valid JSON, or is not a JSON object,
    yields HTTP 400. Analysis results, including analysis-level
    ``{"error": ...}`` results, are returned as produced by ``api_predict``.
    Unexpected failures yield HTTP 500.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this module.
    from fastapi.concurrency import run_in_threadpool

    try:
        data = await request.json()
    except ValueError as e:
        # Malformed JSON is a client error, not a server fault.
        return JSONResponse(content={"error": str(e)}, status_code=400)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)

    if not isinstance(data, dict):
        return JSONResponse(
            content={"error": "请求体必须是 JSON 对象"}, status_code=400
        )

    try:
        # Model inference is blocking; run it off the event loop so other
        # requests (including the mounted Gradio UI) are not stalled.
        result = await run_in_threadpool(
            api_predict,
            data.get("sentence"),
            data.get("mask_index"),
            data.get("top_k", 5),
        )
        return JSONResponse(content=result)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)


# ===============================
# 7. Gradio interface
# ===============================
with gr.Blocks(title="鵩鸟赋 - AI古文分析", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 鵩鸟赋 - AI古文分析工具")
    gr.Markdown("基于 GuwenBERT 的 Masked Language Model,自动过滤标点和虚词")

    with gr.Row():
        with gr.Column():
            input_sentence = gr.Textbox(
                label="句子",
                placeholder="输入要分析的句子",
                value="夫祸之与福兮何异纠纆",
            )
            input_position = gr.Number(
                label="字符位置(从0开始)", value=0, precision=0
            )
            input_topk = gr.Slider(
                minimum=1, maximum=10, step=1, value=5, label="Top K"
            )
            btn_predict = gr.Button("开始分析", variant="primary")

        with gr.Column():
            output_result = gr.JSON(label="预测结果")

    gr.Examples(
        examples=[
            ["夫祸之与福兮何异纠纆", 0, 5],
            ["万物变化兮固无休息", 3, 5],
            ["达人大观兮物无不可", 4, 5],
        ],
        inputs=[input_sentence, input_position, input_topk],
    )

    # The button uses the same entry point as the HTTP endpoint.
    btn_predict.click(
        fn=api_predict,
        inputs=[input_sentence, input_position, input_topk],
        outputs=output_result,
    )

# ===============================
# 8. Mount Gradio onto FastAPI
# ===============================
app = gr.mount_gradio_app(app, demo, path="/")

# ===============================
# 9.
# Launch notes
# ===============================
if __name__ == "__main__":
    # uvicorn is only needed when this file is executed directly as a script.
    import uvicorn

    # Serve the combined FastAPI + Gradio app on all interfaces, port 7860.
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
    )