Spaces:
Running
Running
| import os | |
| import torch | |
| import gradio as gr | |
| from transformers import pipeline, AutoModelForQuestionAnswering, AutoTokenizer, AutoModelForSequenceClassification | |
# ----------------------------------------------------------------
# 1. Basic environment setup
# ----------------------------------------------------------------
os.environ["TOKENIZERS_PARALLELISM"] = "false"  # silence tokenizer fork warnings
os.environ["TF_USE_LEGACY_KERAS"] = "1"         # keep transformers on legacy Keras
print("正在初始化 AI 模型...")
# ----------------------------------------------------------------
# 1. Fill-mask (cloze) feature
# ----------------------------------------------------------------
# Start from a None sentinel; it stays None if loading fails so the
# handler below can report a friendly error instead of crashing.
unmasker = None
try:
    unmasker = pipeline('fill-mask', model='bert-base-chinese', device=-1)
except Exception as e:
    print(f"Fill-Mask 加载警告: {e}")
def fill_mask_ai(text):
    """Predict the most likely tokens for each ``[MASK]`` slot in *text*.

    Args:
        text: A sentence containing at least one literal ``[MASK]`` token.

    Returns:
        A human-readable ranking of candidate tokens (one block per mask),
        or an error message when the model is unavailable or the input
        contains no ``[MASK]``.
    """
    if unmasker is None:
        return "模型加载出错"
    if "[MASK]" not in text:
        return "⚠️ 错误:请在句子中包含 [MASK] 符号"
    try:
        results = unmasker(text)
        # With a single [MASK] the pipeline returns a flat list of dicts;
        # with several masks it returns one list per mask.  Normalize to
        # the nested form so both shapes render correctly (the original
        # code fell into the except-branch on multi-mask input).
        if results and isinstance(results[0], dict):
            results = [results]
        lines = []
        for mask_idx, candidates in enumerate(results):
            if len(results) > 1:
                lines.append(f"--- 第 {mask_idx+1} 个 [MASK] ---")
            for rank, res in enumerate(candidates):
                score = res['score'] * 100
                token = res['token_str']
                lines.append(f"第 {rank+1} 名: 【{token}】 (置信度: {score:.1f}%)")
        # Trailing newline kept for parity with the previous output format.
        return "\n".join(lines) + "\n"
    except Exception as e:
        return f"运行出错: {e}"
# ----------------------------------------------------------------
# 2. Reading comprehension (model loaded manually to avoid pipeline errors)
# ----------------------------------------------------------------
qa_model_name = "uer/roberta-base-chinese-extractive-qa"
# None sentinels let the handler detect a failed load gracefully.
qa_model = None
qa_tokenizer = None
try:
    print("正在手动加载 QA 模型...")
    qa_model = AutoModelForQuestionAnswering.from_pretrained(qa_model_name)
    qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
    print("QA 模型加载成功!(原生模式)")
except Exception as e:
    print(f"QA 模型加载失败: {e}")
    # Reset both: the model may have loaded before the tokenizer failed.
    qa_model = None
    qa_tokenizer = None
def reading_comprehension(context, question):
    """Extractive QA: return the span of *context* that answers *question*.

    Args:
        context: The passage to search for an answer.
        question: The question to answer.

    Returns:
        A formatted answer string including the model's raw score, or an
        error / fallback message.
    """
    if qa_model is None or qa_tokenizer is None:
        return "QA 模型未能成功加载"
    if not context or not question:
        return "请填写完整的文章和问题。"
    try:
        # truncation is required: BERT-style encoders have at most 512
        # positions, and a longer passage would crash the embedding layer.
        inputs = qa_tokenizer(question, context, return_tensors="pt",
                              truncation=True, max_length=512)
        with torch.no_grad():
            outputs = qa_model(**inputs)
        start_idx = int(outputs.start_logits.argmax())
        end_idx = int(outputs.end_logits.argmax())
        # The two argmaxes are independent; an inverted span means the
        # model found no coherent answer.
        if end_idx < start_idx:
            return "(未能找到明确答案)"
        answer_tokens = inputs.input_ids[0, start_idx:end_idx + 1]
        answer = qa_tokenizer.decode(answer_tokens, skip_special_tokens=True)
        if not answer.strip():
            return "(未能找到明确答案)"
        # Average of the best start/end logits as a rough confidence score.
        start_score = outputs.start_logits.max().item()
        end_score = outputs.end_logits.max().item()
        confidence = (start_score + end_score) / 2.0
        return f"🤔 AI 的回答:{answer}\n(模型得分: {confidence:.2f})"
    except Exception as e:
        return f"推理出错: {e}"
# ----------------------------------------------------------------
# 3. Sentiment analysis (label mapping fixed in the handler below)
# ----------------------------------------------------------------
senti_model_name = "uer/roberta-base-finetuned-dianping-chinese"
# None sentinel lets the handler detect a failed load gracefully.
sentiment_pipeline = None
try:
    print("正在手动加载情感分析模型...")
    senti_model = AutoModelForSequenceClassification.from_pretrained(senti_model_name)
    senti_tokenizer = AutoTokenizer.from_pretrained(senti_model_name)
    sentiment_pipeline = pipeline('text-classification', model=senti_model, tokenizer=senti_tokenizer, device=-1)
except Exception as e:
    print(f"情感分析加载失败: {e}")
def sentiment_analysis(text):
    """Classify the sentiment of *text* and return a formatted verdict."""
    if sentiment_pipeline is None:
        return "模型未能加载"
    try:
        top = sentiment_pipeline(text)[0]
        raw_label = str(top['label']).lower()  # lowercase so label casing never matters
        score = top['score'] * 100
        # Debug trace: log the raw label so mis-mapped labels are easy to spot.
        print(f"输入: {text} | 原始标签: {raw_label} | 分数: {score}")
        # Keyword matching: this checkpoint typically emits labels such as
        # "positive (5 stars)" or "negative (1 star)".  Anything not matched
        # as positive or neutral (negative, 1-3 stars, unknown labels) is
        # treated as negative.
        if 'positive' in raw_label or '5 star' in raw_label or '4 star' in raw_label:
            label = "😊 正面/积极"
        elif 'neutral' in raw_label:  # rarely present, but handled
            label = "😐 中性/平和"
        else:
            label = "😡 负面/消极"
        return f"分析结果:{label}\n强度:{score:.1f}%"
    except Exception as e:
        return f"分析出错: {e}"
# ----------------------------------------------------------------
# 4. UI construction
# ----------------------------------------------------------------
# Force a serif CJK font stack on every Gradio widget (injected via
# the `css=` argument of gr.Blocks below).
custom_css = """
body, .gradio-container, .prose, input, button, textarea, span, label {
font-family: 'SimSun', 'STSong', 'Songti SC', serif !important;
}
"""
# Assemble the three-tab Gradio interface and wire buttons to handlers.
with gr.Blocks(title="BERT Playground", theme=gr.themes.Soft(), css=custom_css) as demo:
    gr.Markdown("# 🤖 BERT 语言模型AI助手")
    gr.Markdown("\n\n基于 Google BERT 及其变体模型构建的中文 AI 演示")

    # Tab 1: cloze / fill-mask demo.
    with gr.Tab("🧩 完形填空 (Fill Mask)"):
        gr.Markdown("输入一句话,用 `[MASK]` 代替你想让 AI 猜的词。")
        mask_in = gr.Textbox(label="输入句子", value="我要打王者[MASK]耀。")
        mask_btn = gr.Button("开始猜词", variant="primary")
        mask_out = gr.Textbox(label="AI 的猜测")
        mask_btn.click(fill_mask_ai, inputs=mask_in, outputs=mask_out)

    # Tab 2: extractive question answering over a pasted passage.
    with gr.Tab("📖 阅读理解 (Q&A)"):
        gr.Markdown("粘贴一段短文,然后问 AI 一个问题。")
        default_context = """A:“小明,你的牙齿真好看!”
B:“哦,那是假的!”
A:“啊?真的假的?”
B:“真的"""
        context_in = gr.Textbox(label="文章 (Context)", lines=5, value=default_context)
        question_in = gr.Textbox(label="你的问题 (Question)", value="小明的牙齿是真的还是假的?")
        qa_btn = gr.Button("寻找答案", variant="primary")
        qa_out = gr.Textbox(label="BERT 的对答")
        qa_btn.click(reading_comprehension, inputs=[context_in, question_in], outputs=qa_out)

    # Tab 3: sentiment classification of a single sentence.
    with gr.Tab("❤️ 情感分析 (Sentiment)"):
        gr.Markdown("输入一句话,AI 判断语气是积极还是消极。")
        senti_in = gr.Textbox(label="输入评价", value="我不想上早八。")
        senti_btn = gr.Button("分析情绪", variant="primary")
        senti_out = gr.Textbox(label="分析结果")
        senti_btn.click(sentiment_analysis, inputs=senti_in, outputs=senti_out)

if __name__ == "__main__":
    # ssr_mode=False — presumably to dodge server-side-rendering issues
    # on Hugging Face Spaces; confirm before changing.
    demo.launch(ssr_mode=False)