funiaofu / app.py
lungfish's picture
Upload 2 files
6ee4a0a verified
import re
import unicodedata

import gradio as gr
import torch
import torch.nn.functional as F
from fastapi import FastAPI, Request
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from transformers import AutoTokenizer, AutoModelForMaskedLM
# ===============================
# 1. Model loading
# ===============================
print("正在加载模型...")
model_name = "ethanyt/guwenbert-base"
# Use CUDA when torch reports it as available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The model is only used for inference: move it to the chosen device and
# switch it to evaluation mode in one chain (both calls return the module).
model = AutoModelForMaskedLM.from_pretrained(model_name).to(device).eval()
# Input-embedding weight matrix, indexed by token id when comparing characters.
embedding_matrix = model.get_input_embeddings().weight
print(f"模型加载完成,使用设备: {device}")
# ===============================
# 2. Function-word list
# ===============================
# Characters that is_valid_char() rejects, both as analysis targets and as
# prediction candidates.
stop_chars = {
    "之", "其", "而", "于", "以", "也", "兮",
    "乎", "者", "矣", "焉", "耳", "哉",
}
# ===============================
# 3. Character validity check
# ===============================
def is_valid_char(ch):
    """Return True when ``ch`` is a single character worth analysing.

    A character is rejected when it is not a one-character string, is
    whitespace, belongs to a Unicode punctuation category ("P*"), maps to
    the tokenizer's unknown-token id, or appears in ``stop_chars``.
    """
    if not (isinstance(ch, str) and len(ch) == 1):
        return False
    # Whitespace strips to the empty string; punctuation categories all
    # start with the letter "P" (Po, Ps, Pe, ...).
    if not ch.strip() or unicodedata.category(ch)[0] == "P":
        return False
    in_vocab = tokenizer.convert_tokens_to_ids(ch) != tokenizer.unk_token_id
    return in_vocab and ch not in stop_chars
# ===============================
# 4. Single-character analysis
# ===============================
def analyze_char(sentence, idx, top_k=5):
    """Mask the character at ``idx`` and score it with the masked LM.

    Args:
        sentence: Text to analyse; it is split into individual characters.
        idx: Zero-based position of the character to mask.
        top_k: Maximum number of valid candidate characters to report.

    Returns:
        ``{"error": message}`` when the index is out of range, the character
        fails ``is_valid_char``, or no mask token survives tokenization.
        Otherwise a dict with:
          * ``sentence`` / ``original_char`` / ``position`` - echo of the input;
          * ``rank`` - 1-based rank of the original character among the valid
            candidates (9999 if it never shows up);
          * ``prob`` - softmax probability of the original character;
          * ``top_k`` - up to ``top_k`` ``(token, probability)`` pairs;
          * ``cos_sim`` - cosine similarity between the input embeddings of
            the original character and the best valid candidate.
    """
    sentence_chars = list(sentence)
    if idx < 0 or idx >= len(sentence_chars):
        return {"error": "索引超出范围"}

    original_char = sentence_chars[idx]
    if not is_valid_char(original_char):
        return {"error": f"字符 '{original_char}' 不符合分析条件"}

    # Swap the target character for the tokenizer's mask token.
    masked_chars = (
        sentence_chars[:idx] + [tokenizer.mask_token] + sentence_chars[idx + 1:]
    )
    encoded = tokenizer("".join(masked_chars), return_tensors="pt")
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Column indices of every mask token in the (single) encoded sequence.
    mask_positions = (
        encoded["input_ids"] == tokenizer.mask_token_id
    ).nonzero(as_tuple=True)[1]
    if mask_positions.numel() == 0:
        return {"error": "未找到MASK标记"}

    with torch.no_grad():
        logits = model(**encoded).logits
    # Only the first mask position is scored.
    probs = torch.softmax(logits[0, mask_positions[0], :], dim=-1)
    ranked_probs, ranked_ids = probs.sort(descending=True)

    candidates = []  # (token, probability) pairs, best first
    valid_seen = 0  # how many valid tokens have been visited so far
    original_rank = None
    for token_prob, token_id in zip(ranked_probs, ranked_ids):
        token = tokenizer.convert_ids_to_tokens(token_id.item())
        if not is_valid_char(token):
            continue
        valid_seen += 1
        if original_rank is None and token == original_char:
            original_rank = valid_seen
        if len(candidates) < top_k:
            candidates.append((token, token_prob.item()))
        # Stop once the list is full and the original's rank is known.
        if original_rank is not None and len(candidates) >= top_k:
            break
    if original_rank is None:
        original_rank = 9999  # sentinel: original never met among valid tokens

    original_id = tokenizer.convert_tokens_to_ids(original_char)
    if candidates:
        top1_id = tokenizer.convert_tokens_to_ids(candidates[0][0])
    else:
        # No candidate available: compare the original with itself.
        top1_id = original_id
    cos_sim = F.cosine_similarity(
        embedding_matrix[original_id], embedding_matrix[top1_id], dim=0
    ).item()

    return {
        "sentence": sentence,
        "original_char": original_char,
        "position": idx,
        "rank": original_rank,
        "prob": probs[original_id].item(),
        "top_k": candidates,
        "cos_sim": cos_sim,
    }
# ===============================
# 5. API entry point
# ===============================
def api_predict(sentence, mask_index, top_k=5):
    """Validate raw request values and run ``analyze_char`` on them.

    This is the shared entry point for the HTTP endpoint and the Gradio
    button, so it never raises: every failure is reported as
    ``{"error": message}``.

    Args:
        sentence: Text to analyse; must be a string.
        mask_index: Zero-based character position; anything ``int()`` accepts.
        top_k: Number of candidates to return; ``None`` means the default 5.

    Returns:
        The dict produced by ``analyze_char``, or ``{"error": message}``.
    """
    # A missing or mistyped field used to surface as a raw Python message
    # (e.g. "'NoneType' object is not iterable"); report it explicitly.
    if not isinstance(sentence, str):
        return {"error": "sentence 必须是字符串"}
    try:
        mask_index = int(mask_index)
        top_k = 5 if top_k is None else int(top_k)
    except (TypeError, ValueError, OverflowError):
        # OverflowError covers int(float("inf")).
        return {"error": "mask_index 和 top_k 必须是整数"}

    try:
        return analyze_char(sentence, mask_index, top_k)
    except Exception as e:
        # Boundary handler: tokenizer/model failures become an error payload.
        return {"error": str(e)}
# ===============================
# 6. Create the FastAPI app
# ===============================
app = FastAPI()

# CORS middleware: every origin, method and header is accepted.
# NOTE(review): a wildcard origin combined with allow_credentials=True is very
# permissive - confirm that credentialed cross-origin requests are needed.
_cors_settings = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_settings)
# Custom JSON API endpoint
@app.post("/api/predict")
async def predict_endpoint(request: Request):
    """Handle ``POST /api/predict``.

    Expects a JSON object with ``sentence``, ``mask_index`` and an optional
    ``top_k`` (default 5) and returns the ``api_predict`` result as JSON.

    Responses:
        200 - the ``api_predict`` result (which may itself carry ``"error"``).
        400 - the body is not valid JSON or is not a JSON object.
        500 - an unexpected failure while predicting or building the response.
    """
    try:
        data = await request.json()
    except ValueError:
        # JSON decoding errors are ValueError subclasses; a malformed body is
        # the client's fault, so answer 400 rather than 500.
        return JSONResponse(
            content={"error": "请求体不是合法的 JSON"}, status_code=400
        )
    if not isinstance(data, dict):
        # Arrays, strings, numbers... have no .get(); reject them explicitly.
        return JSONResponse(
            content={"error": "请求体必须是 JSON 对象"}, status_code=400
        )

    try:
        # Inference is blocking work: run it in the thread pool so the event
        # loop (which also serves the mounted Gradio UI) stays responsive.
        result = await run_in_threadpool(
            api_predict,
            data.get("sentence"),
            data.get("mask_index"),
            data.get("top_k", 5),
        )
        return JSONResponse(content=result)
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)
# ===============================
# 7. Gradio interface
# ===============================
# Sample rows for the examples widget: [sentence, character position, top-k].
_example_rows = [
    ["夫祸之与福兮何异纠纆", 0, 5],
    ["万物变化兮固无休息", 3, 5],
    ["达人大观兮物无不可", 4, 5],
]

with gr.Blocks(title="鵩鸟赋 - AI古文分析", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 鵩鸟赋 - AI古文分析工具")
    gr.Markdown("基于 GuwenBERT 的 Masked Language Model,自动过滤标点和虚词")

    with gr.Row():
        # First column: the request form.
        with gr.Column():
            input_sentence = gr.Textbox(
                label="句子",
                placeholder="输入要分析的句子",
                value="夫祸之与福兮何异纠纆",
            )
            input_position = gr.Number(
                label="字符位置(从0开始)",
                value=0,
                precision=0,
            )
            input_topk = gr.Slider(
                minimum=1,
                maximum=10,
                step=1,
                value=5,
                label="Top K",
            )
            btn_predict = gr.Button("开始分析", variant="primary")
        # Second column: the JSON result.
        with gr.Column():
            output_result = gr.JSON(label="预测结果")

    # The examples widget and the button feed the same three inputs.
    _predict_inputs = [input_sentence, input_position, input_topk]
    gr.Examples(examples=_example_rows, inputs=_predict_inputs)
    btn_predict.click(
        fn=api_predict,
        inputs=_predict_inputs,
        outputs=output_result,
    )
# ===============================
# 8. Mount Gradio on the FastAPI app
# ===============================
# Serve the Gradio UI from the site root of the FastAPI application and keep
# the returned application under the same `app` name.
app = gr.mount_gradio_app(app=app, blocks=demo, path="/")
# ===============================
# 9. Script entry point
# ===============================
if __name__ == "__main__":
    # Imported here so uvicorn is only required when run as a script.
    from uvicorn import run as serve

    # Listen on all interfaces, port 7860.
    serve(app, host="0.0.0.0", port=7860)