# app.py — Hugging Face Space "TEST_7" by ThieLin (revision 8742642, ~2.66 kB)
import gradio as gr
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
class ModelComparator:
    """Compare two local (CPU) models on the same question.

    One answer comes from an extractive question-answering pipeline, the
    other from free-form GPT-2 generation; the two answers are scored for
    semantic similarity with sentence embeddings.
    """

    def __init__(self):
        # Extractive QA model (distilled SQuAD variant — small and CPU-friendly).
        self.qa_pipeline = pipeline(
            "question-answering",
            model="distilbert-base-uncased-distilled-squad",
        )
        # Lightweight text-generation model; cap new tokens to keep it fast.
        self.text_gen_pipeline = pipeline(
            "text-generation", model="gpt2", max_new_tokens=50
        )
        # Sentence-embedding model used for cosine-similarity scoring.
        self.sim_model = SentenceTransformer("all-MiniLM-L6-v2")

    def get_qa_answer(self, question, context=None):
        """Return the extractive answer to `question` from `context`.

        Returns a placeholder message when no usable context is given, and
        an error message string if the pipeline itself fails.
        """
        # BUGFIX: the original only checked `context is None`, but Gradio
        # delivers "" (not None) for an empty textbox, and the QA pipeline
        # cannot work with an empty context. Treat None, "", and
        # whitespace-only strings uniformly as "no context".
        if not context or not context.strip():
            return "No context provided for QA model."
        try:
            result = self.qa_pipeline(question=question, context=context)
            return result['answer']
        except Exception as e:
            # Surface the failure in the UI instead of crashing the app.
            return f"Error in QA pipeline: {e}"

    def get_text_gen_answer(self, prompt):
        """Return GPT-2's continuation of `prompt` (prompt text stripped)."""
        try:
            generated = self.text_gen_pipeline(prompt)[0]['generated_text']
            # GPT-2 echoes the prompt at the start of its output; drop it so
            # only the continuation remains.
            answer = generated[len(prompt):].strip()
            # If the model produced nothing beyond the prompt, fall back to
            # the full generated text rather than returning "".
            return answer if answer else generated.strip()
        except Exception as e:
            return f"Error in text generation pipeline: {e}"

    def compare_answers(self, answer1, answer2):
        """Return the cosine similarity of the two answers, rounded to 3 dp."""
        emb1 = self.sim_model.encode(answer1, convert_to_tensor=True)
        emb2 = self.sim_model.encode(answer2, convert_to_tensor=True)
        similarity = util.cos_sim(emb1, emb2).item()
        return round(similarity, 3)

    def respond(self, question, context):
        """Gradio callback: run both models and report their similarity."""
        qa_answer = self.get_qa_answer(question, context)
        gen_answer = self.get_text_gen_answer(question)
        similarity = self.compare_answers(qa_answer, gen_answer)
        return (f"Model QA answer:\n{qa_answer}\n\n"
                f"Model GPT-2 generated answer:\n{gen_answer}\n\n"
                f"Semantic similarity score: {similarity}")
# --- Gradio interface wiring ---
comparator = ModelComparator()

with gr.Blocks() as demo:
    gr.Markdown("## Comparador de respostas entre dois modelos locais (CPU)")

    # Inputs: the question feeds both models; the context is used only by
    # the extractive QA model.
    question_box = gr.Textbox(label="Pergunta")
    context_box = gr.Textbox(
        label="Contexto para o modelo de QA (opcional)", lines=5
    )

    # Output area plus the trigger button.
    result_box = gr.Textbox(label="Respostas e Similaridade", lines=15)
    compare_btn = gr.Button("Comparar")

    compare_btn.click(
        fn=comparator.respond,
        inputs=[question_box, context_box],
        outputs=result_box,
    )

if __name__ == "__main__":
    demo.launch()