# app.py — Gradio demo: question answering + topic classification
# (Hugging Face Space by Lindalynn; commit b6110ce, verified)
import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
import logging
# Set up module-level logging (INFO so model-load progress is visible).
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Module-global pipelines, populated lazily by load_models().
question_answerer = None  # transformers "question-answering" pipeline
class_pipe = None  # transformers "text-classification" pipeline

# Maps the classifier's numeric class index -> human-readable topic name.
# NOTE(review): these labels must match the label order of the fine-tuned
# "Lindalynn/Yahoo_Question_Classifier" model — confirm against that model's
# config (the original author's comment says "replace with your actual labels").
TOPIC_LABELS = {
    0: "Science & Technology",
    1: "Health & Medicine",
    2: "Business & Finance",
    3: "Entertainment",
    4: "Politics",
    5: "Sports",
    6: "Education",
    7: "Culture & Society",
    8: "Computers & Internet",
    9: "Other"
}
def load_models():
    """Load the QA and classification pipelines into the module globals.

    Returns:
        bool: True on success (or if the models are already loaded),
        False if loading failed (the failure is logged).

    Safe to call repeatedly: if both pipelines are already initialised the
    expensive downloads/initialisation are skipped. The original version
    re-downloaded and rebuilt both models on every call, which made each
    request pay the full model-loading cost.
    """
    global question_answerer, class_pipe
    # Fast path: both pipelines already initialised — nothing to do.
    if question_answerer is not None and class_pipe is not None:
        return True
    try:
        logger.info("Loading QA model...")
        question_answerer = pipeline(
            "question-answering",
            model='distilbert-base-cased-distilled-squad'
        )
        logger.info("Loading classification model...")
        # Fine-tuned topic classifier hosted on the Hugging Face Hub.
        class_model = "Lindalynn/Yahoo_Question_Classifier"
        tokenizer = AutoTokenizer.from_pretrained(class_model)
        model = AutoModelForSequenceClassification.from_pretrained(class_model)
        class_pipe = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer
        )
        return True
    except Exception as e:
        # logger.exception captures the traceback, aiding Space debugging.
        logger.exception("Model loading failed: %s", e)
        return False
def process_question(question, context):
    """Answer *question* from *context* and classify the question's topic.

    Args:
        question: The user's question text.
        context: The passage the answer should be extracted from.

    Returns:
        tuple[str, str]: (formatted answer string, detected topic). On any
        failure a human-readable error pair is returned instead of raising.
    """
    # Lazy initialisation: only invoke load_models() when a pipeline is
    # missing, so each request does not pay the model-loading cost again.
    if question_answerer is None or class_pipe is None:
        if not load_models():
            return "Models failed to load", "Error"
    if not question.strip():
        return "Please enter a question", "Invalid input"
    # The QA pipeline needs non-empty context; fail with a clear message
    # instead of surfacing a pipeline exception as a generic error.
    if not context or not context.strip():
        return "Please enter some context text", "Invalid input"
    try:
        # Extractive QA: pull the answer span out of the supplied context.
        qa_result = question_answerer(question=question, context=context)
        answer = qa_result['answer']
        score = round(qa_result['score'], 4)
        # Topic classification runs on the question text alone.
        class_result = class_pipe(question)
        label = class_result[0]['label']
        try:
            # Default transformers labels look like "LABEL_3"; map the
            # trailing index through TOPIC_LABELS.
            topic = TOPIC_LABELS.get(int(label.split('_')[-1]), "General")
        except ValueError:
            # The model already emits a human-readable label string.
            topic = label
        formatted_answer = f"Answer: '{answer}', score: {score}, topic: {topic}"
        return formatted_answer, topic
    except Exception as e:
        logger.error(f"Processing error: {str(e)}")
        return "Error processing question", "Error"
# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# 🤖 QA with Classification
*Question Answering with distilbert-base-cased-distilled-squad*
""")
with gr.Row():
with gr.Column():
context_input = gr.Textbox(
label="Context",
placeholder="Enter the context text here...",
lines=3
)
question_input = gr.Textbox(
label="Question",
placeholder="Enter your question here...",
lines=2
)
submit_btn = gr.Button("Get Answer", variant="primary")
with gr.Column():
answer_output = gr.Textbox(
label="QA Result",
interactive=False
)
topic_output = gr.Textbox(
label="Detected Topic",
interactive=False
)
# Example context and questions
gr.Examples(
examples=[
[
"Alice is sitting on the bench. Bob is sitting next to her.",
"Who is sitting on the bench?"
],
[
"The company was founded in 2010 by Elon Musk. The current CEO is Tim Cook.",
"Who is the current CEO?"
]
],
inputs=[context_input, question_input]
)
submit_btn.click(
fn=process_question,
inputs=[question_input, context_input],
outputs=[answer_output, topic_output]
)
# Launch the interface only when run as a script (not when imported,
# e.g. by the Hugging Face Spaces runtime importing this module).
if __name__ == "__main__":
    demo.launch()