Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification | |
| import logging | |
# Configure root logging at import time and create this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Pipeline handles. Both start empty and are assigned by load_models().
question_answerer = None
class_pipe = None

# Classifier label index -> display name (replace with your actual labels).
# NOTE(review): these are placeholder names. The classifier id suggests a
# Yahoo Answers topic model, so confirm this index order against the
# model's config.id2label before trusting the displayed topics.
TOPIC_LABELS = dict(enumerate([
    "Science & Technology",
    "Health & Medicine",
    "Business & Finance",
    "Entertainment",
    "Politics",
    "Sports",
    "Education",
    "Culture & Society",
    "Computers & Internet",
    "Other",
]))
def load_models() -> bool:
    """Load the QA and topic-classification pipelines once and cache them.

    The pipelines are stored in the module globals ``question_answerer`` and
    ``class_pipe``. A pipeline that is already loaded is reused rather than
    re-instantiated, so this is cheap to call on every request.

    Returns:
        True if both pipelines are available, False if loading failed. The
        error is logged with its traceback. A pipeline that loaded
        successfully before the failure stays cached, so the next call
        retries only the missing one.
    """
    global question_answerer, class_pipe
    # Fast path: both pipelines are already in memory.
    if question_answerer is not None and class_pipe is not None:
        return True
    try:
        if question_answerer is None:
            logger.info("Loading QA model...")
            question_answerer = pipeline(
                "question-answering",
                model='distilbert-base-cased-distilled-squad'
            )
        if class_pipe is None:
            logger.info("Loading classification model...")
            # Replace with your actual fine-tuned model path
            class_model = "Lindalynn/Yahoo_Question_Classifier"
            tokenizer = AutoTokenizer.from_pretrained(class_model)
            model = AutoModelForSequenceClassification.from_pretrained(class_model)
            class_pipe = pipeline(
                "text-classification",
                model=model,
                tokenizer=tokenizer
            )
        return True
    except Exception as e:
        # UI boundary: report failure to the caller instead of crashing the app.
        logger.exception("Model loading failed: %s", e)
        return False
def _label_to_topic(label):
    """Map a classifier label string to a human-readable topic name.

    ``"LABEL_<n>"``-style labels (the transformers default when a model
    config has no named ``id2label``) are looked up in ``TOPIC_LABELS`` by
    their numeric suffix. Indices missing from the table fall back to
    ``"General"``. A label without a numeric suffix is assumed to already
    be a readable name and is returned unchanged. Previously such a label
    raised ``ValueError`` and discarded the QA answer.
    """
    try:
        index = int(label.split('_')[-1])
    except ValueError:
        return label
    return TOPIC_LABELS.get(index, "General")


def process_question(question, context):
    """Answer *question* from *context* and classify the question's topic.

    Args:
        question: Question text entered in the UI.
        context: Passage that the extractive QA model searches for the answer.

    Returns:
        A ``(formatted_answer, topic)`` pair of strings for the two output
        textboxes. On invalid input, a model-loading failure, or an inference
        error, the pair instead holds a short message and a status label.
    """
    # Validate first, so an empty submission never triggers model loading.
    if not question or not question.strip():
        return "Please enter a question", "Invalid input"
    if not context or not context.strip():
        return "Please enter a context", "Invalid input"
    if not load_models():
        return "Models failed to load", "Error"
    try:
        qa_result = question_answerer(question=question, context=context)
        answer = qa_result['answer']
        score = round(qa_result['score'], 4)

        class_result = class_pipe(question)
        topic = _label_to_topic(class_result[0]['label'])

        formatted_answer = f"Answer: '{answer}', score: {score}, topic: {topic}"
        return formatted_answer, topic
    except Exception as e:
        # UI boundary: log the full traceback but show the user a short message.
        logger.exception("Processing error: %s", e)
        return "Error processing question", "Error"
# Build the Gradio interface. Components are created in display order, so the
# statement order inside this context determines the page layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
# 🤖 QA with Classification
*Question Answering with distilbert-base-cased-distilled-squad*
""")

    with gr.Row():
        # Left column: user inputs and the submit button.
        with gr.Column():
            context_input = gr.Textbox(
                lines=3,
                label="Context",
                placeholder="Enter the context text here...",
            )
            question_input = gr.Textbox(
                lines=2,
                label="Question",
                placeholder="Enter your question here...",
            )
            submit_btn = gr.Button("Get Answer", variant="primary")

        # Right column: read-only result fields.
        with gr.Column():
            answer_output = gr.Textbox(interactive=False, label="QA Result")
            topic_output = gr.Textbox(interactive=False, label="Detected Topic")

    # Example (context, question) rows; column order matches `inputs` below.
    example_rows = [
        [
            "Alice is sitting on the bench. Bob is sitting next to her.",
            "Who is sitting on the bench?",
        ],
        [
            "The company was founded in 2010 by Elon Musk. The current CEO is Tim Cook.",
            "Who is the current CEO?",
        ],
    ]
    gr.Examples(inputs=[context_input, question_input], examples=example_rows)

    # process_question takes (question, context), so the inputs are listed in
    # that order rather than in on-screen order.
    submit_btn.click(
        outputs=[answer_output, topic_output],
        inputs=[question_input, context_input],
        fn=process_question,
    )

if __name__ == "__main__":
    # Start the web server only when this file is executed directly.
    demo.launch()