Workspace / app.py
ThieLin's picture
TEST_4
84311ea verified
raw
history blame
2.17 kB
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoTokenizer, AutoModelForQuestionAnswering, pipeline
from sentence_transformers import SentenceTransformer, util
# --- Models -----------------------------------------------------------------
# Extractive question-answering model; the same checkpoint id is used for
# both the model weights and its tokenizer.
model_name = "deepset/roberta-base-squad2"
qa_pipeline = pipeline(
    task="question-answering",
    model=model_name,
    tokenizer=model_name,
)

# Generative chat model, queried through huggingface_hub's InferenceClient.
client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta")

# Sentence-embedding model used to compare the two answers (cosine similarity).
similarity_model = SentenceTransformer(model_name_or_path="all-MiniLM-L6-v2")
def get_qa_pipeline_answer(question, context):
    """Return the answer text chosen by the extractive QA pipeline.

    Args:
        question: The question to answer.
        context: The passage the answer span is extracted from.

    Returns:
        The ``"answer"`` field of the pipeline's prediction.
    """
    qa_input = {"question": question, "context": context}
    prediction = qa_pipeline(qa_input)
    return prediction["answer"]
def get_zephyr_answer(question, context, *, max_tokens=512, temperature=0.7, top_p=0.95):
    """Ask the Zephyr chat model to answer ``question`` using ``context``.

    The prompt is a fixed system message followed by one user turn that
    embeds the context and the question. The request goes through the
    module-level ``client`` (a ``huggingface_hub.InferenceClient``).

    Args:
        question: The question to answer.
        context: Passage the model is given alongside the question.
        max_tokens: Maximum number of tokens to generate (default 512).
        temperature: Sampling temperature (default 0.7). A non-zero value
            makes replies, and therefore any similarity score computed from
            them, vary between runs.
        top_p: Nucleus-sampling cutoff (default 0.95).

    Returns:
        The model's reply with surrounding whitespace stripped, or ``""`` if
        the response carried no text content.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": f"Context: {context}\n\nQuestion: {question}"},
    ]
    response = client.chat_completion(
        messages,
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    # ``message.content`` is typed Optional[str] by huggingface_hub; fall back
    # to "" so ``.strip()`` cannot raise AttributeError on a None content.
    content = response.choices[0].message.content
    return (content or "").strip()
def compare_answers(answer1, answer2):
    """Score how semantically close two answers are.

    Each answer is embedded separately with ``similarity_model``, and the
    cosine similarity of the two embeddings is returned as a float rounded
    to 3 decimal places.
    """
    # Encode answer1 first, then answer2 (same order as the calls they replace).
    first_vec, second_vec = (
        similarity_model.encode(text, convert_to_tensor=True)
        for text in (answer1, answer2)
    )
    score_tensor = util.cos_sim(first_vec, second_vec)
    return round(score_tensor.item(), 3)
def respond(question, context):
    """Answer ``question`` with both models and report their semantic similarity.

    Args:
        question: Question text entered by the user.
        context: Passage text entered by the user.

    Returns:
        A plain-text report containing the RoBERTa answer, the Zephyr answer
        and their similarity score. If either input is empty or whitespace
        only, a short warning is returned instead and no model is called.
    """
    # The extractive QA pipeline is expected to reject an empty question or
    # context with an error; checking here also avoids a pointless remote
    # Zephyr request.
    if not (question or "").strip() or not (context or "").strip():
        return "⚠️ Preencha a pergunta e o contexto."
    answer1 = get_qa_pipeline_answer(question, context)
    answer2 = get_zephyr_answer(question, context)
    similarity_score = compare_answers(answer1, answer2)
    # The report is shown in a plain-text Textbox, which does not render
    # Markdown, so the score is not wrapped in ``**bold**`` markup.
    return (
        f"📘 Roberta-base-squad2:\n{answer1}\n\n"
        f"🧠 Zephyr-7b:\n{answer2}\n\n"
        f"🔍 Similaridade Semântica: {similarity_score}"
    )
# --- Gradio interface ---------------------------------------------------------
# Two text inputs, one button, and one text output that shows both answers
# plus their similarity score.
with gr.Blocks() as demo:
    gr.Markdown(
        value="# 🔎 Perguntas com dois modelos\nCompare duas respostas e veja a similaridade."
    )
    # Question and context inputs sit side by side.
    with gr.Row():
        question = gr.Textbox(label="Pergunta")
        context = gr.Textbox(label="Contexto")
    submit_btn = gr.Button(value="Obter Respostas")
    output = gr.Textbox(label="Respostas e Similaridade")
    # Clicking the button runs respond() on both inputs and fills `output`.
    submit_btn.click(fn=respond, inputs=[question, context], outputs=output)

# Start the web server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()