| import gradio as gr
|
| import torch
|
| import re
|
| import unicodedata
|
| import torch.nn.functional as F
|
| from transformers import AutoTokenizer, AutoModelForMaskedLM
|
| from fastapi import FastAPI, Request
|
| from fastapi.responses import JSONResponse
|
| from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
|
|
|
|
|
# Load the tokenizer and masked-LM weights once at import time; the request
# handlers below reuse these module-level objects.
print("正在加载模型...")

model_name = "ethanyt/guwenbert-base"

# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForMaskedLM.from_pretrained(model_name).to(device)
model.eval()  # evaluation mode: this script only runs inference

# Input-embedding table; analyze_char indexes it by token id to compare the
# original character with the top prediction.
embedding_matrix = model.get_input_embeddings().weight

print(f"模型加载完成,使用设备: {device}")
|
|
|
|
|
|
|
|
|
# Characters that is_valid_char rejects outright (the UI text describes them
# as function words). Each character of the string is one set member.
stop_chars = set("之其而于以也兮乎者矣焉耳哉")
|
|
|
|
|
|
|
|
|
|
|
def is_valid_char(ch):
    """Return True when *ch* is a single character eligible for analysis.

    A value is rejected when it is not a one-character string, is whitespace,
    has a Unicode punctuation category (``P*``), maps to the tokenizer's
    unknown-token id, or is a member of ``stop_chars``.
    """
    if not (isinstance(ch, str) and len(ch) == 1):
        return False

    # Tokenizer-free rejections first: blanks and punctuation.
    if not ch.strip() or unicodedata.category(ch)[0] == "P":
        return False

    is_unknown = tokenizer.convert_tokens_to_ids(ch) == tokenizer.unk_token_id
    return not is_unknown and ch not in stop_chars
|
|
|
|
|
|
|
|
|
|
|
def _collect_top_candidates(sorted_probs, sorted_ids, original_char, top_k):
    """Walk the ranked vocabulary and keep the candidates that pass the filter.

    Args:
        sorted_probs: Probabilities as plain Python floats, highest first.
        sorted_ids: Vocabulary ids (plain ints) aligned with ``sorted_probs``.
        original_char: The character that was masked out.
        top_k: Maximum number of candidates to keep.

    Returns:
        A ``(tokens, probs, original_rank)`` tuple. ``original_rank`` is the
        1-based rank of ``original_char`` counted only over tokens accepted by
        ``is_valid_char``, or ``None`` if it was never encountered.
    """
    tokens = []
    probs = []
    original_rank = None
    rank = 0

    for prob, token_id in zip(sorted_probs, sorted_ids):
        token = tokenizer.convert_ids_to_tokens(token_id)
        if not is_valid_char(token):
            continue

        rank += 1
        if original_rank is None and token == original_char:
            original_rank = rank

        if len(tokens) < top_k:
            tokens.append(token)
            probs.append(prob)

        # Stop once the top-k list is full and the original's rank is known.
        if original_rank is not None and len(tokens) >= top_k:
            break

    return tokens, probs, original_rank


def analyze_char(sentence, idx, top_k=5):
    """Mask the character at *idx* and report how the model predicts it.

    Args:
        sentence: Text to analyse; it is split into individual characters.
        idx: 0-based index of the character to mask.
        top_k: Maximum number of filtered predictions to return.

    Returns:
        On success, a dict with the keys ``sentence``, ``original_char``,
        ``position``, ``rank`` (rank of the original character among filtered
        predictions, ``9999`` if not found), ``prob`` (model probability of
        the original character), ``top_k`` (list of ``(token, prob)`` pairs)
        and ``cos_sim`` (cosine similarity between the input embeddings of the
        original character and the top filtered prediction).
        On failure, ``{"error": message}``.
    """
    chars = list(sentence)
    if idx < 0 or idx >= len(chars):
        return {"error": "索引超出范围"}

    original_char = chars[idx]
    if not is_valid_char(original_char):
        return {"error": f"字符 '{original_char}' 不符合分析条件"}

    mask_token = tokenizer.mask_token
    # Literal mask-token text that the user typed before position idx is also
    # tokenised as a mask; count it so the logits are read at the mask that
    # replaced the requested character rather than at the first mask found.
    masks_before = "".join(chars[:idx]).count(mask_token)
    chars[idx] = mask_token
    masked_sentence = "".join(chars)

    inputs = tokenizer(masked_sentence, return_tensors="pt")
    inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    mask_positions = (inputs["input_ids"] == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
    if len(mask_positions) <= masks_before:
        return {"error": "未找到MASK标记"}

    with torch.no_grad():
        logits = model(**inputs).logits
        probs = torch.softmax(logits[0, mask_positions[masks_before], :], dim=-1)
        sorted_probs, sorted_ids = torch.sort(probs, descending=True)

    # One bulk copy to Python lists instead of a tensor element plus an
    # .item() call per candidate inside the loop.
    filtered_tokens, filtered_probs, original_rank = _collect_top_candidates(
        sorted_probs.tolist(), sorted_ids.tolist(), original_char, top_k
    )
    if original_rank is None:
        original_rank = 9999  # sentinel: original never appeared in the ranking

    original_id = tokenizer.convert_tokens_to_ids(original_char)
    # With no surviving candidate, compare the original with itself.
    top1_id = tokenizer.convert_tokens_to_ids(filtered_tokens[0]) if filtered_tokens else original_id

    # The embedding table is a trainable parameter; no_grad avoids building an
    # autograd graph for a value that is only reported.
    with torch.no_grad():
        cos_sim = F.cosine_similarity(
            embedding_matrix[original_id], embedding_matrix[top1_id], dim=0
        ).item()
    original_prob = probs[original_id].item()

    return {
        "sentence": sentence,
        "original_char": original_char,
        "position": idx,
        "rank": original_rank,
        "prob": original_prob,
        "top_k": list(zip(filtered_tokens, filtered_probs)),
        "cos_sim": cos_sim,
    }
|
|
|
|
|
|
|
|
|
|
|
def api_predict(sentence, mask_index, top_k=5):
    """Coerce the raw inputs to integers and delegate to ``analyze_char``.

    Used both as the Gradio click handler and by the HTTP endpoint. Any
    exception raised during conversion or analysis is reported as
    ``{"error": str(exc)}`` instead of propagating to the caller.
    """
    try:
        # Convert before the call so a bad mask_index is reported first.
        position, limit = int(mask_index), int(top_k)
        return analyze_char(sentence, position, limit)
    except Exception as exc:
        return {"error": str(exc)}
|
|
|
|
|
|
|
|
|
|
|
app = FastAPI()

# CORS: the API is open to every origin. Credentialed requests stay disabled
# because the FastAPI CORS documentation states that allow_origins,
# allow_methods and allow_headers cannot be ["*"] when allow_credentials is
# True. The endpoint visible in this file reads only the JSON body, so nothing
# here needs cookies. If credentials are ever required, list the allowed
# origins explicitly instead of using the wildcard.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
@app.post("/api/predict")
async def predict_endpoint(request: Request):
    """HTTP wrapper around ``api_predict``.

    Expects a JSON object with ``sentence``, ``mask_index`` and an optional
    ``top_k`` (default 5). Responds with the dict returned by ``api_predict``.
    A body that is not valid JSON, or is valid JSON but not an object, yields
    a 400 response; any other failure yields a 500 response. Error responses
    always have the shape ``{"error": message}``.
    """
    # Local import keeps this block self-contained; fastapi is already a
    # dependency of this file.
    from fastapi.concurrency import run_in_threadpool

    try:
        try:
            data = await request.json()
        except ValueError as e:
            # json.JSONDecodeError and UnicodeDecodeError are ValueErrors:
            # the client sent a malformed body.
            return JSONResponse(content={"error": str(e)}, status_code=400)

        if not isinstance(data, dict):
            # Previously surfaced as a 500 AttributeError from data.get.
            return JSONResponse(content={"error": "请求体必须是 JSON 对象"}, status_code=400)

        # Model inference is blocking; run it in a worker thread so the event
        # loop (shared with the mounted Gradio UI) keeps serving other requests.
        result = await run_in_threadpool(
            api_predict,
            data.get("sentence"),
            data.get("mask_index"),
            data.get("top_k", 5),
        )
        return JSONResponse(content=result)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
|
|
|
|
|
|
|
|
|
|
|
# Gradio front end: sentence / position / top-k inputs in the first column,
# the JSON result in the second, with the button wired to api_predict.
with gr.Blocks(title="鵩鸟赋 - AI古文分析", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 鵩鸟赋 - AI古文分析工具")
    gr.Markdown("基于 GuwenBERT 的 Masked Language Model,自动过滤标点和虚词")

    with gr.Row():
        with gr.Column():
            input_sentence = gr.Textbox(
                value="夫祸之与福兮何异纠纆",
                label="句子",
                placeholder="输入要分析的句子",
            )
            input_position = gr.Number(value=0, precision=0, label="字符位置(从0开始)")
            input_topk = gr.Slider(label="Top K", minimum=1, maximum=10, step=1, value=5)
            btn_predict = gr.Button("开始分析", variant="primary")

        with gr.Column():
            output_result = gr.JSON(label="预测结果")

    # The same three components feed both the examples and the click handler.
    predict_inputs = [input_sentence, input_position, input_topk]

    # Each row is (sentence, character position, top-k).
    example_rows = [
        ["夫祸之与福兮何异纠纆", 0, 5],
        ["万物变化兮固无休息", 3, 5],
        ["达人大观兮物无不可", 4, 5],
    ]
    gr.Examples(examples=example_rows, inputs=predict_inputs)

    btn_predict.click(fn=api_predict, inputs=predict_inputs, outputs=output_result)
|
|
|
|
|
|
|
|
|
# Serve the Gradio UI from the site root of the same FastAPI application that
# exposes /api/predict; the returned app replaces the module-level name.
app = gr.mount_gradio_app(app, demo, path="/")


if __name__ == "__main__":
    # Only needed when the file is executed directly as a script.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860}
    uvicorn.run(app, **server_options)