Spaces:
Sleeping
Sleeping
| import warnings | |
| warnings.filterwarnings('ignore') | |
| # Import necessary libraries | |
| import gradio as gr | |
| import torch | |
| from transformers import ( | |
| BertTokenizerFast, | |
| BertForQuestionAnswering, | |
| AutoTokenizer, | |
| BartForQuestionAnswering, | |
| DistilBertTokenizerFast, | |
| DistilBertForQuestionAnswering | |
| ) | |
| import gc | |
# Module-level state shared by the app.
# context_store: list created empty here; nothing in the visible source
# appends to it or reads from it (the UI keeps the context in a gr.State).
context_store = []
# selected_model: tracks which model is currently selected; None means no
# model has been selected yet.
selected_model = None  # To track the selected model
| # Define models and tokenizers | |
def _load_qa_model(model_cls, tokenizer_cls, repo_id):
    """Load a question-answering model and its tokenizer from *repo_id*.

    Shared implementation behind the three public loaders so that device
    selection, eval mode and memory cleanup are done identically for every
    model.

    Args:
        model_cls: Transformers class providing ``from_pretrained`` for the
            model weights.
        tokenizer_cls: Transformers class providing ``from_pretrained`` for
            the tokenizer.
        repo_id: Hugging Face Hub identifier (or local path) to load from.

    Returns:
        A ``(tokenizer, model, device)`` tuple; the model is in eval mode and
        already moved to ``device``.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model_cls.from_pretrained(repo_id)
    tokenizer = tokenizer_cls.from_pretrained(repo_id)
    model.eval().to(device)
    # Reclaim whatever a previously loaded model left behind.
    gc.collect()
    torch.cuda.empty_cache()
    return tokenizer, model, device


def load_bert_model_and_tokenizer():
    """Return ``(tokenizer, model, device)`` for the fine-tuned BERT model."""
    return _load_qa_model(
        BertForQuestionAnswering,
        BertTokenizerFast,
        "LivisLiquoro/BERT_Model_Squad1.1",
    )


def load_bart_model_and_tokenizer():
    """Return ``(tokenizer, model, device)`` for the fine-tuned BART model."""
    return _load_qa_model(
        BartForQuestionAnswering,
        AutoTokenizer,
        "valhalla/bart-large-finetuned-squadv1",
    )


def load_distilbert_model_and_tokenizer():
    """Return ``(tokenizer, model, device)`` for the fine-tuned DistilBERT model."""
    return _load_qa_model(
        DistilBertForQuestionAnswering,
        DistilBertTokenizerFast,
        "LivisLiquoro/DistilBert_model_squad1.1",
    )
def clean_answer(tokens, tok=None):
    """Turn a list of answer tokens into a readable string.

    Special tokens are dropped and the remaining tokens are handed to the
    tokenizer's ``convert_tokens_to_string`` unchanged. Sub-word markers such
    as the WordPiece ``##`` prefix are deliberately left in place: the
    tokenizer needs them to glue word pieces back together
    (``play`` + ``##ing`` -> ``playing``). Stripping them first would yield
    ``play ing``.

    Args:
        tokens: Iterable of token strings (e.g. from
            ``convert_ids_to_tokens``).
        tok: Optional tokenizer to use. When omitted, the module-level
            ``tokenizer`` global is used, as before.

    Returns:
        The cleaned answer string, or ``None`` when nothing is left after
        removing special tokens.
    """
    active = tokenizer if tok is None else tok

    # Always drop the BERT markers the original code handled, plus whatever
    # the active tokenizer declares as special (e.g. BART's <s> and </s>).
    special = {'[SEP]', '[CLS]'}
    special.update(getattr(active, 'all_special_tokens', None) or ())

    kept = [token for token in tokens if token and token not in special]
    if not kept:
        return None
    return active.convert_tokens_to_string(kept).strip() or None
def generate_answer(context, question):
    """Extract the answer to *question* from *context* with the active model.

    Uses the module-level ``tokenizer``, ``model`` and ``device`` globals,
    which the caller must have set beforehand. The tokenizer must be a
    "fast" tokenizer, because offset mappings and sequence ids are used.

    The question is encoded first and the context second. Long contexts are
    split into overlapping token windows, so an answer near a window edge is
    still seen whole in a neighbouring window. Every window is scored and the
    highest-scoring span wins. The answer text is sliced out of the original
    context, so its casing and spacing are preserved.

    Args:
        context: The passage to search for an answer.
        question: The question to answer.

    Returns:
        The answer with its first character upper-cased, or
        ``"❌ No valid answer found."`` when no span could be extracted.
    """
    no_answer = "❌ No valid answer found."
    if not context or not question:
        return no_answer

    max_length = 512  # tokens per window, question included
    stride = 128      # tokens shared by consecutive windows

    encoding = tokenizer(
        question,
        context,
        return_tensors='pt',
        truncation='only_second',  # only the context is ever cut
        max_length=max_length,
        stride=stride,
        return_overflowing_tokens=True,
        return_offsets_mapping=True,
        padding=True,  # windows must share one length to form a tensor
    )
    # Bookkeeping entries that the model's forward() does not accept.
    offsets = encoding.pop('offset_mapping')
    encoding.pop('overflow_to_sample_mapping', None)

    best_score = float('-inf')
    best_answer = None

    with torch.no_grad():
        for window in range(offsets.shape[0]):
            # Only context tokens (sequence id 1) may start or end an answer;
            # question, special and padding tokens are masked out.
            in_context = torch.tensor(
                [seq_id == 1 for seq_id in encoding.sequence_ids(window)],
                dtype=torch.bool,
            )
            if not in_context.any():
                continue

            window_inputs = {
                name: tensor[window:window + 1].to(device)
                for name, tensor in encoding.items()
            }
            outputs = model(**window_inputs)

            blocked = ~in_context.to(outputs.start_logits.device)
            start_logits = outputs.start_logits[0].masked_fill(blocked, float('-inf'))
            end_logits = outputs.end_logits[0].masked_fill(blocked, float('-inf'))

            # Best start, then the best end at or after it, so the span is
            # never inverted.
            start_idx = int(torch.argmax(start_logits))
            end_idx = start_idx + int(torch.argmax(end_logits[start_idx:]))
            score = float(start_logits[start_idx] + end_logits[end_idx])
            if score <= best_score:
                continue

            start_char = int(offsets[window, start_idx, 0])
            end_char = int(offsets[window, end_idx, 1])
            candidate = context[start_char:end_char].strip()
            if candidate:
                best_score = score
                best_answer = candidate

    if not best_answer:
        return no_answer
    # Upper-case the first character only; str.capitalize() would also
    # lower-case the rest (e.g. proper nouns).
    return best_answer[0].upper() + best_answer[1:]
| # Define the Gradio interface with light theme and organized layout | |
def chatbot_interface():
    """Build and return the Gradio Blocks app for the QA platform.

    The layout has a chat panel with a question box on the left, and context
    controls, the model selector, a status line and instructions on the
    right. The context is kept in a ``gr.State``.

    Returns:
        The ``gr.Blocks`` instance; the caller is responsible for launching it.
    """
    # Maps the radio-button label to the function that loads that model.
    model_loaders = {
        "BERT": load_bert_model_and_tokenizer,
        "BART": load_bart_model_and_tokenizer,
        "DistilBERT": load_distilbert_model_and_tokenizer,
    }

    with gr.Blocks() as demo:
        # Custom CSS for light theme and layout
        gr.Markdown("""
        <style>
        body { background-color: #f9f9f9; }
        .chatbot-container { background-color: #ffffff; border-radius: 10px; padding: 20px; color: #333; font-family: Arial, sans-serif; }
        .gr-button { background-color: #4CAF50; color: white; border: none; border-radius: 5px; padding: 10px 20px; font-size: 14px; cursor: pointer; }
        .gr-button:hover { background-color: #45a049; }
        .gr-textbox { background-color: #ffffff; color: #333; border-radius: 5px; border: 1px solid #ddd; padding: 10px; }
        .gr-chatbot { background-color: #e6e6e6; border-radius: 10px; padding: 15px; color: #333; }
        .footer { text-align: right; font-size: 12px; color: #777; font-style: italic; }
        .note { text-align: right; font-size: 10px; color: #777; font-style: italic; position: absolute; bottom: 10px; right: 10px; }
        </style>
        """)

        # Header
        gr.Markdown("<h1 style='text-align: center; color: #4CAF50;'>EDITH: Multi-Model Question Answering Platform</h1>")
        gr.Markdown("<p style='text-align: center; color: #777;'>Switch between BERT, BART, and DistilBERT models and ask questions based on the context.</p>")

        context_state = gr.State()

        with gr.Row():
            with gr.Column(scale=11):  # Left panel: chatbot and question input (11/20 = 55%)
                chatbot = gr.Chatbot(label="Chatbot")
                question_input = gr.Textbox(label="Ask a Question", placeholder="Enter your question here...", lines=1)
                submit_btn = gr.Button("Submit Question")
            with gr.Column(scale=9):  # Right panel: context and instructions (9/20 = 45%)
                context_input = gr.Textbox(label="Set Context", placeholder="Enter the context here...", lines=4)
                set_context_btn = gr.Button("Set Context")
                clear_context_btn = gr.Button("Clear Context")
                # Model selection buttons
                model_selection = gr.Radio(choices=["BERT", "BART", "DistilBERT"], label="Select Model", value="BERT")
                status_message = gr.Markdown("")
                gr.Markdown("<strong>Instructions:</strong><br>1. Set a context.<br>2. Select the model (BERT, BART, or DistilBERT).<br>3. Ask questions based on the context.<br><br><strong>Note:</strong> <span class='note'>The BART model is pre-trained from Hugging Face. Credits to Hugging Face and the person who fine-tuned this model ('valhalla/bart-large-finetuned-squadv1')</span>")

        gr.Markdown("<div class='footer'>Prepared by: Team EDITH</div>")

        def set_context(context):
            """Store a non-blank context and hide the context textbox."""
            if not context or not context.strip():
                return gr.update(), "Please enter a valid context.", None
            return gr.update(visible=False), "Context has been set. You can now ask questions.", context

        def clear_context():
            """Forget the stored context and show the context textbox again."""
            return gr.update(visible=True), "Context has been cleared. Please set a new context.", None

        def handle_question(question, history, context, model_choice):
            """Answer *question* and append the exchange to the chat history.

            Returns ``(history, question_box_value, status_text)``. On a
            validation problem the question box keeps what the user typed and
            the explanation goes to the status line.
            """
            global tokenizer, model, device, selected_model

            history = history or []
            if not context:
                return history, question, "Please set the context before asking questions."
            if not question or not question.strip():
                return history, question, "Please enter a valid question."

            loader = model_loaders.get(model_choice)
            if loader is None:
                return history, question, "Please select a model before asking questions."

            # Load only when the choice changed; reloading the same weights on
            # every question is slow and gains nothing.
            if selected_model != model_choice:
                # Drop the old model first so it can be freed before the new
                # one is allocated. selected_model is cleared as well, so a
                # failed load is retried on the next question.
                selected_model = None
                tokenizer = model = None
                tokenizer, model, device = loader()
                selected_model = model_choice

            answer = generate_answer(context, question)
            # Show the selected model with the answer
            history = history + [[f"👤: {question}", f"🤖 ({model_choice}): {answer}"]]
            return history, "", ""

        question_inputs = [question_input, chatbot, context_state, model_selection]
        question_outputs = [chatbot, question_input, status_message]

        set_context_btn.click(set_context, inputs=context_input, outputs=[context_input, status_message, context_state])
        clear_context_btn.click(clear_context, inputs=None, outputs=[context_input, status_message, context_state])
        submit_btn.click(handle_question, inputs=question_inputs, outputs=question_outputs)
        # Enable "Enter" key to trigger the "Submit" button
        question_input.submit(handle_question, inputs=question_inputs, outputs=question_outputs)

    return demo
| # Run the Gradio interface | |
# Build the app at module level so `interface` stays importable, but only
# start the web server when this file is executed as a script. Importing the
# module no longer launches a server as a side effect.
interface = chatbot_interface()

if __name__ == "__main__":
    interface.launch(share=True)