Workspace / app.py
ThieLin's picture
TEST_8
767e590 verified
raw
history blame
2.35 kB
import gradio as gr
from transformers import pipeline
from sentence_transformers import SentenceTransformer, util
class ModelComparator:
    """Answer a question with two different models and score their agreement.

    * An extractive question-answering pipeline
      (``distilbert-base-uncased-distilled-squad``) picks an answer out of a
      user-supplied context.
    * A ``gpt2`` text-generation pipeline continues the question as a prompt.
    * A SentenceTransformer (``all-MiniLM-L6-v2``) embeds both answers so their
      cosine similarity can be reported.
    """

    def __init__(self):
        # All three models are loaded eagerly, once per instance.
        self.qa_pipeline = pipeline(
            "question-answering", model="distilbert-base-uncased-distilled-squad"
        )
        # max_new_tokens=20 keeps the GPT-2 generation short.
        self.text_gen_pipeline = pipeline("text-generation", model="gpt2", max_new_tokens=20)
        self.sim_model = SentenceTransformer("all-MiniLM-L6-v2")

    def _try_qa(self, question, context):
        """Run the QA model and return ``(text, ok)``.

        ``ok`` is True only when ``text`` is an answer produced by the model;
        otherwise ``text`` is a status or error message meant for display.
        """
        # A missing or blank context gives the extractive model nothing to pick from.
        if not context or (isinstance(context, str) and not context.strip()):
            return "No context provided for QA model.", False
        try:
            result = self.qa_pipeline(question=question, context=context)
            return result['answer'], True
        except Exception as e:  # best-effort: report the failure in the UI instead of crashing
            return f"Error in QA pipeline: {e}", False

    def get_qa_answer(self, question, context=None):
        """Return the QA model's answer to *question* given *context*.

        Pipeline failures do not propagate: a missing/blank context or any
        exception from the pipeline is returned as a human-readable message.
        """
        return self._try_qa(question, context)[0]

    def _try_text_gen(self, prompt):
        """Run GPT-2 on *prompt* and return ``(text, ok)`` (same contract as ``_try_qa``)."""
        try:
            generated = self.text_gen_pipeline(prompt)[0]['generated_text']
            # Presumably generated_text starts with the prompt (full-text output);
            # slice it off to keep only the model's continuation.
            continuation = generated[len(prompt):].strip()
            if continuation:
                return continuation, True
            # Nothing new was generated: still display the full text, but it is
            # just the prompt echoed back, so it must not be scored as an answer.
            return generated.strip(), False
        except Exception as e:  # best-effort: report the failure in the UI instead of crashing
            return f"Error in text generation pipeline: {e}", False

    def get_text_gen_answer(self, prompt):
        """Return GPT-2's continuation of *prompt*.

        If the continuation is empty, the stripped full generated text is
        returned instead; pipeline exceptions are returned as a message string.
        """
        return self._try_text_gen(prompt)[0]

    def compare_answers(self, answer1, answer2):
        """Return the cosine similarity of the two texts' embeddings, rounded to 3 decimals."""
        emb1 = self.sim_model.encode(answer1, convert_to_tensor=True)
        emb2 = self.sim_model.encode(answer2, convert_to_tensor=True)
        similarity = util.cos_sim(emb1, emb2).item()
        return round(similarity, 3)

    def respond(self, question, context):
        """Run both models on *question* and format a plain-text report.

        The similarity score is computed only when both models produced a real
        answer; otherwise it is shown as ``N/A``, because comparing a status or
        error message with model output would yield a meaningless number.
        """
        qa_answer, qa_ok = self._try_qa(question, context)
        gen_answer, gen_ok = self._try_text_gen(question)
        if qa_ok and gen_ok:
            similarity = self.compare_answers(qa_answer, gen_answer)
        else:
            similarity = "N/A"
        return (f"Model QA answer:\n{qa_answer}\n\n"
                f"Model GPT-2 generated answer:\n{gen_answer}\n\n"
                f"Semantic similarity score: {similarity}")
# Single shared comparator; constructing it loads all models at import time.
model_comparator = ModelComparator()

# UI: a question and an optional QA context go in, one combined report comes out.
with gr.Blocks() as demo:
    gr.Markdown(value="## Comparador rápido para Hugging Face Spaces")
    question_input = gr.Textbox(label="Pergunta")
    context_input = gr.Textbox(lines=3, label="Contexto para o modelo de QA (opcional)")
    output = gr.Textbox(lines=15, label="Respostas e Similaridade")
    btn = gr.Button(value="Comparar")
    # On click: respond(question, context) -> text shown in the output box.
    btn.click(
        model_comparator.respond,
        [question_input, context_input],
        output,
    )

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script.
    demo.launch()