| | import gradio as gr |
| | import torch |
| | from transformers import AutoTokenizer, AutoModelForQuestionAnswering |
| |
|
# Hugging Face Hub id of the fine-tuned extractive-QA checkpoint served by this app.
MODEL_NAME = "Salommee/bert-squad-qa"


# Load tokenizer and model once at import time so every request reuses them.
print("Loading model...")
try:
    # trust_remote_code lets the repo ship custom model/tokenizer classes;
    # only safe because MODEL_NAME is a fixed, known repository.
    tokenizer = AutoTokenizer.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True
    )
    model = AutoModelForQuestionAnswering.from_pretrained(
        MODEL_NAME,
        trust_remote_code=True
    )
    print("Model loaded successfully!")
except Exception as e:
    # Log the failure for the Space logs, then re-raise so the app fails fast
    # instead of starting without a usable model.
    print(f"Error loading model: {e}")
    raise
| |
|
| | |
def answer_question(context, question):
    """Extract an answer to *question* from *context* with the QA model.

    Parameters
    ----------
    context : str
        Passage of text to search for the answer.
    question : str
        Natural-language question about the context.

    Returns
    -------
    tuple[str, str]
        (answer text, confidence text) for the two Gradio output boxes.
    """
    if not context or not context.strip():
        return "❌ Provide context.", "N/A"
    if not question or not question.strip():
        return "❌ Provide question.", "N/A"

    # Standard SQuAD-style encoding: question first, truncate only the context.
    inputs = tokenizer(
        question,
        context,
        truncation="only_second",
        max_length=384,
        return_tensors="pt",
        padding=True,
    )

    with torch.no_grad():
        outputs = model(**inputs)

    # Work on the 1-D logits for the single example in the batch.
    start_logits = outputs.start_logits[0]
    end_logits = outputs.end_logits[0]

    # Restrict the candidate span to context tokens (sequence id 1); otherwise
    # argmax could pick tokens from the question itself. Index 0 ([CLS]) stays
    # unmasked because SQuAD-style models use it as the "no answer" signal.
    sequence_ids = inputs.sequence_ids(0)
    not_context = torch.tensor(
        [i != 0 and sid != 1 for i, sid in enumerate(sequence_ids)]
    )
    start_logits = start_logits.masked_fill(not_context, float("-inf"))
    end_logits = end_logits.masked_fill(not_context, float("-inf"))

    start_idx = int(torch.argmax(start_logits))
    end_idx = int(torch.argmax(end_logits))

    # Confidence = product of the start and end token probabilities.
    start_score = torch.softmax(start_logits, dim=0)[start_idx].item()
    end_score = torch.softmax(end_logits, dim=0)[end_idx].item()
    confidence = start_score * end_score

    # [CLS] (index 0) or an inverted span means the model found no answer.
    if start_idx > end_idx or start_idx == 0 or end_idx == 0:
        return "❌ Answer not found. Try rephrasing your question.", f"{confidence:.2%}"

    answer = tokenizer.decode(
        inputs.input_ids[0][start_idx:end_idx + 1], skip_special_tokens=True
    )
    # Traffic-light indicator for the confidence score.
    emoji = "🟢" if confidence > 0.8 else "🟡" if confidence > 0.5 else "🔴"
    return f"✅ {answer}", f"{emoji} {confidence:.2%}"
| |
|
| | |
# Sample [context, question] pairs shown in the gr.Examples widget below.
examples = [
    ["Paris is the capital of France.", "What is the capital of France?"],
    ["Eiffel Tower built 1887-1889.", "When was the Eiffel Tower built?"],
    ["Machine learning automates model building.", "What is machine learning?"]
]
| |
|
| | |
# Build the Gradio UI: inputs (context + question) on the left,
# outputs (answer + confidence) on the right.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# π€ BERT Question Answering")

    with gr.Row():
        with gr.Column(scale=2):
            context_input = gr.Textbox(label="π Context", lines=6, placeholder="Enter context here...")
            question_input = gr.Textbox(label="β Question", lines=2, placeholder="Ask your question...")
            submit_btn = gr.Button("π Get Answer")
        with gr.Column(scale=1):
            answer_output = gr.Textbox(label="π‘ Answer", lines=2)
            confidence_output = gr.Textbox(label="π Confidence", lines=1)

    # Clickable example pairs; clicking one runs answer_question immediately.
    gr.Examples(
        examples,
        inputs=[context_input, question_input],
        outputs=[answer_output, confidence_output],
        fn=answer_question
    )

    # Wire the button to the inference function.
    submit_btn.click(
        fn=answer_question,
        inputs=[context_input, question_input],
        outputs=[answer_output, confidence_output]
    )
| |
|
if __name__ == "__main__":
    # share=True asks Gradio for a temporary public URL in addition to the
    # local server — NOTE(review): confirm a public link is actually desired.
    demo.launch(share=True)