Spaces:
Sleeping
Sleeping
File size: 2,947 Bytes
51da700 79f0df1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
import gradio as gr
import torch
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
import pandas as pd
import plotly.express as px
import os
# --- 1. 模型加载 ---
# 负责同学: [填写负责这个模型的同学姓名,例如:张三]
# 注意:QuantFactory/Apollo2-7B-GGUF 模型通常不直接兼容 pipeline("text-generation", ...)
# 除非有额外的llama.cpp或特定的transformers加载配置。
# 为了演示和确保运行流畅,这里使用 gpt2-large 作为替代。
try:
model1_name = "gpt2-large" # 替代 QuantFactory/Apollo2-7B-GGUF 以确保兼容性
generator1 = pipeline("text-generation", model=model1_name, device=0 if torch.cuda.is_available() else -1)
print(f"✅ 模型 1 (文本生成: {model1_name}) 加载成功!")
except Exception as e:
print(f"❌ 模型 1 (文本生成: {model1_name}) 加载失败: {e}")
generator1 = None
# 负责同学: [填写负责这个模型的同学姓名,例如:李四]
# deepset/roberta-base-squad2 是一个问答模型,需要 context
try:
model2_name = "deepset/roberta-base-squad2"
qa_model = pipeline("question-answering", model=model2_name, device=0 if torch.cuda.is_available() else -1)
print(f"✅ 模型 2 (问答: {model2_name}) 加载成功!")
except Exception as e:
print(f"❌ 模型 2 (问答: {model2_name}) 加载失败: {e}")
qa_model = None
# --- 2. 推理函数 ---
# 这个函数现在接受一个问题/提示词和一个上下文
def get_model_outputs(question_or_prompt, context, max_length=100):
output_text_gen = "文本生成模型未加载或生成失败。"
output_qa = "问答模型未加载或生成失败。"
# 模型 1: 文本生成
if generator1:
try:
# 文本生成模型将问题和上下文作为其prompt的一部分
full_prompt_for_gen = f"{question_or_prompt}\nContext: {context}" if context else question_or_prompt
gen_result = generator1(full_prompt_for_gen, max_new_tokens=max_length, num_return_sequences=1, truncation=True)
output_text_gen = gen_result[0]['generated_text']
# 清理:移除输入部分,只保留生成内容
if output_text_gen.startswith(full_prompt_for_gen):
output_text_gen = output_text_gen[len(full_prompt_for_gen):].strip()
except Exception as e:
output_text_gen = f"文本生成模型 ({model1_name}) 错误: {e}"
# 模型 2: 问答
if qa_model and context: # 问答模型必须有上下文
try:
qa_result = qa_model(question=question_or_prompt, context=context)
output_qa = qa_result['answer']
except Exception as e:
output_qa = f"问答模型 ({model2_name}) 错误: {e}"
elif qa_model and not context:
output_qa = "问答模型需要提供上下文才能回答问题。"
return output_text_gen, output_qa |