import logging

import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Lazily-initialised model pipelines (populated by load_models()).
question_answerer = None
class_pipe = None

# Define topic labels (replace with your actual labels).
# NOTE(review): assumes the classifier emits labels of the form "LABEL_<index>"
# whose index matches these keys — confirm against the model's config.
TOPIC_LABELS = {
    0: "Science & Technology",
    1: "Health & Medicine",
    2: "Business & Finance",
    3: "Entertainment",
    4: "Politics",
    5: "Sports",
    6: "Education",
    7: "Culture & Society",
    8: "Computers & Internet",
    9: "Other",
}


def load_models():
    """Load the QA and topic-classification pipelines once and cache them.

    The pipelines are stored in the module-level ``question_answerer`` and
    ``class_pipe`` globals. Once both are loaded, later calls return
    immediately instead of rebuilding the models on every request.

    Returns:
        bool: True if both pipelines are available, False if loading failed
        (the error is logged with its traceback; a later call retries).
    """
    global question_answerer, class_pipe

    # Already loaded: reuse the cached pipelines.
    if question_answerer is not None and class_pipe is not None:
        return True

    try:
        logger.info("Loading QA model...")
        qa = pipeline(
            "question-answering",
            model="distilbert-base-cased-distilled-squad",
        )

        logger.info("Loading classification model...")
        # Replace with your actual fine-tuned model path
        class_model = "Lindalynn/Yahoo_Question_Classifier"
        tokenizer = AutoTokenizer.from_pretrained(class_model)
        model = AutoModelForSequenceClassification.from_pretrained(class_model)
        classifier = pipeline(
            "text-classification",
            model=model,
            tokenizer=tokenizer,
        )
    except Exception:
        logger.exception("Model loading failed")
        return False

    # Publish both pipelines only after both loaded successfully, so a
    # partial failure never leaves one global set and the other missing.
    question_answerer, class_pipe = qa, classifier
    return True


def _resolve_topic(label):
    """Map a classifier label (e.g. ``"LABEL_3"``) to a readable topic name.

    Labels ending in a numeric ``_<index>`` suffix are looked up in
    ``TOPIC_LABELS`` (unknown indices map to ``"General"``). Labels without
    a numeric suffix are assumed to already be readable and are returned as-is.
    """
    suffix = str(label).rsplit("_", 1)[-1]
    try:
        index = int(suffix)
    except ValueError:
        # The model config already provides a human-readable label.
        return str(label)
    return TOPIC_LABELS.get(index, "General")


def process_question(question, context):
    """Answer ``question`` from ``context`` and classify the question's topic.

    Args:
        question: The user's question text.
        context: The passage the answer is extracted from.

    Returns:
        tuple[str, str]: ``(formatted_answer, topic)``. On invalid input or a
        failure, both elements carry a short user-facing error message.
    """
    # Validate cheap inputs first, so bad requests never trigger model loading.
    question = (question or "").strip()
    context = (context or "").strip()
    if not question:
        return "Please enter a question", "Invalid input"
    if not context:
        return "Please enter a context", "Invalid input"

    if not load_models():
        return "Models failed to load", "Error"

    try:
        # QA model: extractive answer span plus confidence score.
        qa_result = question_answerer(question=question, context=context)
        answer = qa_result["answer"]
        score = round(qa_result["score"], 4)

        # Classification model: the pipeline returns a list with one dict per input.
        class_result = class_pipe(question)
        topic = _resolve_topic(class_result[0]["label"])
    except Exception:
        logger.exception("Processing error")
        return "Error processing question", "Error"

    formatted_answer = f"Answer: '{answer}', score: {score}, topic: {topic}"
    return formatted_answer, topic


# Create Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        "# 🤖 QA with Classification\n"
        "*Question Answering with distilbert-base-cased-distilled-squad*"
    )

    with gr.Row():
        with gr.Column():
            context_input = gr.Textbox(
                label="Context",
                placeholder="Enter the context text here...",
                lines=3,
            )
            question_input = gr.Textbox(
                label="Question",
                placeholder="Enter your question here...",
                lines=2,
            )
            submit_btn = gr.Button("Get Answer", variant="primary")

        with gr.Column():
            answer_output = gr.Textbox(label="QA Result", interactive=False)
            topic_output = gr.Textbox(label="Detected Topic", interactive=False)

    # Example context and questions
    gr.Examples(
        examples=[
            [
                "Alice is sitting on the bench. Bob is sitting next to her.",
                "Who is sitting on the bench?",
            ],
            [
                "The company was founded in 2010 by Elon Musk. The current CEO is Tim Cook.",
                "Who is the current CEO?",
            ],
        ],
        inputs=[context_input, question_input],
    )

    submit_btn.click(
        fn=process_question,
        inputs=[question_input, context_input],
        outputs=[answer_output, topic_output],
    )

# Launch the interface
if __name__ == "__main__":
    # Warm the models up once at startup so the first request is not slow;
    # if this fails, process_question() retries lazily.
    load_models()
    demo.launch()