Spaces:
Sleeping
Sleeping
| import os | |
| import openai | |
| import torch | |
| import tensorflow as tf | |
| from transformers import AutoTokenizer, TFAutoModelForQuestionAnswering | |
| import gradio as gr | |
| import re | |
# Read the OpenAI API key from the environment; os.getenv yields None when
# OPENAI_API_KEY is not set.
openai.api_key = os.getenv("OPENAI_API_KEY")

# Select a torch device: CUDA when torch reports it as available, else CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _load_qa_pair(checkpoint):
    """Load the TF question-answering model and tokenizer for `checkpoint`.

    The model is loaded before the tokenizer. Returns `(model, tokenizer)`.
    """
    model = TFAutoModelForQuestionAnswering.from_pretrained(checkpoint)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    return model, tokenizer


# English models and tokenizers
qa_model_name_v1 = 'salsarra/ConfliBERT-QA'
qa_model_v1, qa_tokenizer_v1 = _load_qa_pair(qa_model_name_v1)

bert_model_name_v1 = 'salsarra/BERT-base-cased-SQuAD-v1'
bert_qa_model_v1, bert_qa_tokenizer_v1 = _load_qa_pair(bert_model_name_v1)

# Spanish models and tokenizers (NewsQA checkpoints)
confli_model_spanish_name = 'salsarra/ConfliBERT-Spanish-Beto-Cased-NewsQA'
confli_model_spanish, confli_tokenizer_spanish = _load_qa_pair(confli_model_spanish_name)

beto_model_spanish_name = 'salsarra/Beto-Spanish-Cased-NewsQA'
beto_model_spanish, beto_tokenizer_spanish = _load_qa_pair(beto_model_spanish_name)

# Additional Spanish models (SQAC checkpoints). Note that here the
# `*_model_spanish` names hold the checkpoint id and the `*_qa` names hold
# the loaded model.
confli_sqac_model_spanish = 'salsarra/ConfliBERT-Spanish-Beto-Cased-SQAC'
confli_sqac_model_spanish_qa, confli_sqac_tokenizer_spanish = _load_qa_pair(confli_sqac_model_spanish)

beto_sqac_model_spanish = 'salsarra/Beto-Spanish-Cased-SQAC'
beto_sqac_model_spanish_qa, beto_sqac_tokenizer_spanish = _load_qa_pair(beto_sqac_model_spanish)
# Error handling: separate "input too long" failures from other issues.
# The patterns are compiled once at import time. Single backslashes are
# required inside the raw strings: a doubled backslash would match a literal
# backslash character and leave the brackets/parentheses unescaped.
_TENSOR_SIZE_MISMATCH_RE = re.compile(
    r"The size of tensor a \((\d+)\) must match the size of tensor b \((\d+)\)"
)
_INDEX_OUT_OF_RANGE_RE = re.compile(
    r"indices\[0,(\d+)\] = \d+ is not in \[0, (\d+)\)"
)


def handle_error_message(e, default_limit=512):
    """Format an exception as a red, bold HTML error message.

    Two message shapes are recognised and reported as an input-size problem
    (presumably these are the errors the underlying frameworks raise when
    the tokenized input exceeds the model's limit -- confirm against the
    installed versions):

    * "The size of tensor a (N) must match the size of tensor b (M)"
    * "indices[0,N] = ... is not in [0, M)"

    Any other exception is reported with its own text.

    Args:
        e: The exception (or any object whose ``str()`` is the message).
        default_limit: Currently unused; kept so existing callers that pass
            it keep working.

    Returns:
        An HTML ``<span>`` string describing the error.
    """
    import html  # local import: only needed for the generic fallback branch

    error_message = str(e)

    for pattern in (_TENSOR_SIZE_MISMATCH_RE, _INDEX_OUT_OF_RANGE_RE):
        match = pattern.search(error_message)
        if match:
            number_1, number_2 = match.groups()
            return f"<span style='color: red; font-weight: bold;'>Error: Text Input is over limit where inserted text size {number_1} is larger than model limits of {number_2}</span>"

    # The message is free-form text that ends up inside an HTML component,
    # so neutralise any markup characters it may contain.
    safe_message = html.escape(error_message, quote=False)
    return f"<span style='color: red; font-weight: bold;'>Error: {safe_message}</span>"
# ConfliBERT English extractive QA (tokenizer called with truncation=True)
def question_answering_v1(context, question):
    """Answer `question` from `context` with the ConfliBERT English model.

    The question/context pair is tokenized with truncation, the model is
    run, and the span between the arg-max start logit and the arg-max end
    logit (inclusive) is decoded back to text.

    Returns:
        The answer wrapped in a bold HTML ``<span>``, or the HTML produced
        by ``handle_error_message`` if any step raises.
    """
    try:
        encoded = qa_tokenizer_v1(question, context, return_tensors='tf', truncation=True)
        prediction = qa_model_v1(encoded)
        start_idx = tf.argmax(prediction.start_logits, axis=1).numpy()[0]
        # +1 so the slice below includes the predicted end token.
        stop_idx = tf.argmax(prediction.end_logits, axis=1).numpy()[0] + 1
        span_ids = encoded['input_ids'].numpy()[0][start_idx:stop_idx]
        span_tokens = qa_tokenizer_v1.convert_ids_to_tokens(span_ids)
        answer = qa_tokenizer_v1.convert_tokens_to_string(span_tokens)
    except Exception as e:
        return handle_error_message(e)
    return f"<span style='font-weight: bold;'>{answer}</span>"
# BERT English extractive QA (tokenizer called with truncation=True)
def bert_question_answering_v1(context, question):
    """Answer `question` from `context` with the BERT-base English model.

    The question/context pair is tokenized with truncation, the model is
    run, and the span between the arg-max start logit and the arg-max end
    logit (inclusive) is decoded back to text.

    Returns:
        The answer wrapped in a bold HTML ``<span>``, or the HTML produced
        by ``handle_error_message`` if any step raises.
    """
    try:
        encoded = bert_qa_tokenizer_v1(question, context, return_tensors='tf', truncation=True)
        prediction = bert_qa_model_v1(encoded)
        start_idx = tf.argmax(prediction.start_logits, axis=1).numpy()[0]
        # +1 so the slice below includes the predicted end token.
        stop_idx = tf.argmax(prediction.end_logits, axis=1).numpy()[0] + 1
        span_ids = encoded['input_ids'].numpy()[0][start_idx:stop_idx]
        span_tokens = bert_qa_tokenizer_v1.convert_ids_to_tokens(span_ids)
        answer = bert_qa_tokenizer_v1.convert_tokens_to_string(span_tokens)
    except Exception as e:
        return handle_error_message(e)
    return f"<span style='font-weight: bold;'>{answer}</span>"
# Define question_answering_spanish for ConfliBERT-Spanish-Beto-Cased-NewsQA
def question_answering_spanish(context, question):
    """Answer `question` from `context` with ConfliBERT-Spanish (NewsQA).

    The question/context pair is tokenized with truncation, the model is
    run, and the span between the arg-max start logit and the arg-max end
    logit (inclusive) is decoded back to text.

    Returns:
        The answer wrapped in a bold HTML ``<span>``, or the HTML produced
        by ``handle_error_message`` if any step raises.
    """
    try:
        # Call the tokenizer directly (as the English functions do) instead
        # of the deprecated `encode_plus`.
        inputs = confli_tokenizer_spanish(question, context, return_tensors='tf', truncation=True)
        outputs = confli_model_spanish(inputs)
        answer_start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
        # +1 so the slice below includes the predicted end token.
        answer_end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
        answer = confli_tokenizer_spanish.convert_tokens_to_string(
            confli_tokenizer_spanish.convert_ids_to_tokens(inputs['input_ids'].numpy()[0][answer_start:answer_end])
        )
        return f"<span style='font-weight: bold;'>{answer}</span>"
    except Exception as e:
        return handle_error_message(e)
# Define beto_question_answering_spanish for Beto-Spanish-Cased-NewsQA
def beto_question_answering_spanish(context, question):
    """Answer `question` from `context` with Beto-Spanish (NewsQA).

    The question/context pair is tokenized with truncation, the model is
    run, and the span between the arg-max start logit and the arg-max end
    logit (inclusive) is decoded back to text.

    Returns:
        The answer wrapped in a bold HTML ``<span>``, or the HTML produced
        by ``handle_error_message`` if any step raises.
    """
    try:
        # Call the tokenizer directly (as the English functions do) instead
        # of the deprecated `encode_plus`.
        inputs = beto_tokenizer_spanish(question, context, return_tensors='tf', truncation=True)
        outputs = beto_model_spanish(inputs)
        answer_start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
        # +1 so the slice below includes the predicted end token.
        answer_end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
        answer = beto_tokenizer_spanish.convert_tokens_to_string(
            beto_tokenizer_spanish.convert_ids_to_tokens(inputs['input_ids'].numpy()[0][answer_start:answer_end])
        )
        return f"<span style='font-weight: bold;'>{answer}</span>"
    except Exception as e:
        return handle_error_message(e)
# Define confli_sqac_question_answering_spanish for ConfliBERT-Spanish-Beto-Cased-SQAC
def confli_sqac_question_answering_spanish(context, question):
    """Answer `question` from `context` with ConfliBERT-Spanish (SQAC).

    The question/context pair is tokenized with truncation, the model is
    run, and the span between the arg-max start logit and the arg-max end
    logit (inclusive) is decoded back to text.

    Returns:
        The answer wrapped in a bold HTML ``<span>``, or the HTML produced
        by ``handle_error_message`` if any step raises.
    """
    # Errors are converted to an HTML message, matching the other QA
    # functions, so one failing model does not abort the whole comparison.
    try:
        # Call the tokenizer directly instead of the deprecated `encode_plus`.
        inputs = confli_sqac_tokenizer_spanish(question, context, return_tensors="tf", truncation=True)
        outputs = confli_sqac_model_spanish_qa(inputs)
        answer_start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
        # +1 so the slice below includes the predicted end token.
        answer_end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
        answer = confli_sqac_tokenizer_spanish.convert_tokens_to_string(
            confli_sqac_tokenizer_spanish.convert_ids_to_tokens(inputs['input_ids'].numpy()[0][answer_start:answer_end])
        )
        return f"<span style='font-weight: bold;'>{answer}</span>"
    except Exception as e:
        return handle_error_message(e)
# Define beto_sqac_question_answering_spanish for Beto-Spanish-Cased-SQAC
def beto_sqac_question_answering_spanish(context, question):
    """Answer `question` from `context` with Beto-Spanish (SQAC).

    The question/context pair is tokenized with truncation, the model is
    run, and the span between the arg-max start logit and the arg-max end
    logit (inclusive) is decoded back to text.

    Returns:
        The answer wrapped in a bold HTML ``<span>``, or the HTML produced
        by ``handle_error_message`` if any step raises.
    """
    # Errors are converted to an HTML message, matching the other QA
    # functions, so one failing model does not abort the whole comparison.
    try:
        # Call the tokenizer directly instead of the deprecated `encode_plus`.
        inputs = beto_sqac_tokenizer_spanish(question, context, return_tensors="tf", truncation=True)
        outputs = beto_sqac_model_spanish_qa(inputs)
        answer_start = tf.argmax(outputs.start_logits, axis=1).numpy()[0]
        # +1 so the slice below includes the predicted end token.
        answer_end = tf.argmax(outputs.end_logits, axis=1).numpy()[0] + 1
        answer = beto_sqac_tokenizer_spanish.convert_tokens_to_string(
            beto_sqac_tokenizer_spanish.convert_ids_to_tokens(inputs['input_ids'].numpy()[0][answer_start:answer_end])
        )
        return f"<span style='font-weight: bold;'>{answer}</span>"
    except Exception as e:
        return handle_error_message(e)
# Define a function to get ChatGPT's answer in English via the OpenAI chat API
def chatgpt_question_answering(context, question):
    """Ask gpt-3.5-turbo to answer `question` given `context` (English).

    Sends a system prompt plus one user message containing the context and
    question, capped at 150 completion tokens.

    Returns:
        The stripped text of the first choice, or the HTML produced by
        ``handle_error_message`` if the request (or reading its response)
        raises, so a ChatGPT failure does not discard the other models'
        answers.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": f"Context: {context}\nQuestion: {question}\nAnswer:"}
    ]
    try:
        # NOTE(review): `openai.ChatCompletion` is the pre-1.0 client
        # interface -- confirm the pinned `openai` version still provides it.
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=messages,
            max_tokens=150
        )
        return response['choices'][0]['message']['content'].strip()
    except Exception as e:
        return handle_error_message(e)
# Define a function to get ChatGPT's answer in Spanish via the OpenAI chat API
def chatgpt_question_answering_spanish(context, question):
    """Ask gpt-3.5-turbo to answer `question` given `context` (Spanish).

    Sends a system prompt asking for Spanish responses plus one user message
    containing the context and question, capped at 150 completion tokens.

    Returns:
        The stripped text of the first choice, or the HTML produced by
        ``handle_error_message`` if the request (or reading its response)
        raises, so a ChatGPT failure does not discard the other models'
        answers.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant that responds in Spanish."},
        {"role": "user", "content": f"Contexto: {context}\nPregunta: {question}\nRespuesta:"}
    ]
    try:
        # NOTE(review): `openai.ChatCompletion` is the pre-1.0 client
        # interface -- confirm the pinned `openai` version still provides it.
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=messages,
            max_tokens=150
        )
        return response['choices'][0]['message']['content'].strip()
    except Exception as e:
        return handle_error_message(e)
# Main comparison function with language selection
def compare_question_answering(language, context, question):
    """Run every model for the chosen language and render the answers as HTML.

    Args:
        language: "English" or "Spanish". Any other value (including None or
            an empty string, e.g. when nothing is selected) yields an error
            message instead of answers.
        context: Passage the models should extract or derive the answer from.
        question: Question to answer.

    Returns:
        An HTML string: one section per model answer followed by links to
        the model pages, or a red error message when no supported language
        was given.
    """
    if language == "English":
        confli_answer_v1 = question_answering_v1(context, question)
        bert_answer_v1 = bert_question_answering_v1(context, question)
        chatgpt_answer = chatgpt_question_answering(context, question)
        return f"""
        <div>
        <h2 style='color: #2e8b57; font-weight: bold;'>Answers:</h2>
        </div><br>
        <div>
        <strong style='color: green; font-weight: bold;'>ConfliBERT-cont-cased-SQuAD-v1:</strong><br><span style='font-weight: bold;'>{confli_answer_v1}</span></div><br>
        <div>
        <strong style='color: orange; font-weight: bold;'>BERT-base-cased-SQuAD-v1:</strong><br><span style='font-weight: bold;'>{bert_answer_v1}</span>
        </div><br>
        <div>
        <strong style='color: #74AA9C; font-weight: bold;'>ChatGPT:</strong><br><span style='font-weight: bold;'>{chatgpt_answer}</span>
        </div><br>
        <div>
        <strong>Model Information:</strong><br>
        <a href='https://huggingface.co/salsarra/ConfliBERT-QA' target='_blank'>ConfliBERT-cont-cased-SQuAD-v1</a><br>
        <a href='https://huggingface.co/salsarra/BERT-base-cased-SQuAD-v1' target='_blank'>BERT-base-cased-SQuAD-v1</a><br>
        <a href='https://platform.openai.com/docs/models/gpt-3-5' target='_blank'>ChatGPT (GPT-3.5 Turbo)</a><br>
        </div>
        """
    elif language == "Spanish":
        confli_answer_spanish = question_answering_spanish(context, question)
        beto_answer_spanish = beto_question_answering_spanish(context, question)
        confli_sqac_answer_spanish = confli_sqac_question_answering_spanish(context, question)
        beto_sqac_answer_spanish = beto_sqac_question_answering_spanish(context, question)
        chatgpt_answer_spanish = chatgpt_question_answering_spanish(context, question)
        return f"""
        <div>
        <h2 style='color: #2e8b57; font-weight: bold;'>Answers:</h2>
        </div><br>
        <div>
        <strong style='color: green; font-weight: bold;'>ConfliBERT-Spanish-Beto-Cased-NewsQA:</strong><br><span style='font-weight: bold;'>{confli_answer_spanish}</span></div><br>
        <div>
        <strong style='color: orange; font-weight: bold;'>Beto-Spanish-Cased-NewsQA:</strong><br><span style='font-weight: bold;'>{beto_answer_spanish}</span>
        </div><br>
        <div>
        <strong style='color: green; font-weight: bold;'>ConfliBERT-Spanish-Beto-Cased-SQAC:</strong><br><span style='font-weight: bold;'>{confli_sqac_answer_spanish}</span>
        </div><br>
        <div>
        <strong style='color: orange; font-weight: bold;'>Beto-Spanish-Cased-SQAC:</strong><br><span style='font-weight: bold;'>{beto_sqac_answer_spanish}</span>
        </div><br>
        <div>
        <strong style='color: #74AA9C; font-weight: bold;'>ChatGPT:</strong><br><span style='font-weight: bold;'>{chatgpt_answer_spanish}</span>
        </div><br>
        <div>
        <strong>Model Information:</strong><br>
        <a href='https://huggingface.co/salsarra/ConfliBERT-Spanish-Beto-Cased-NewsQA' target='_blank'>ConfliBERT-Spanish-Beto-Cased-NewsQA</a><br>
        <a href='https://huggingface.co/salsarra/Beto-Spanish-Cased-NewsQA' target='_blank'>Beto-Spanish-Cased-NewsQA</a><br>
        <a href='https://huggingface.co/salsarra/ConfliBERT-Spanish-Beto-Cased-SQAC' target='_blank'>ConfliBERT-Spanish-Beto-Cased-SQAC</a><br>
        <a href='https://huggingface.co/salsarra/Beto-Spanish-Cased-SQAC' target='_blank'>Beto-Spanish-Cased-SQAC</a><br>
        <a href='https://platform.openai.com/docs/models/gpt-3-5' target='_blank'>ChatGPT (GPT-3.5 Turbo)</a><br>
        </div>
        """
    # No (or an unsupported) language: say so explicitly instead of falling
    # through and returning None, which would render as an empty output.
    return "<span style='color: red; font-weight: bold;'>Error: Please select a language (English or Spanish).</span>"
# Gradio interface setup: language selector, context/question inputs, an HTML
# output panel, and Clear/Submit buttons.
with gr.Blocks(css="""
    body {
        background-color: #f0f8ff;
        font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
    }
    h1, h1 a {
        color: #2e8b57;
        text-align: center;
        font-size: 2em;
        text-decoration: none;
    }
    h1 a:hover {
        color: #ff8c00;
    }
    h2 {
        color: #ff8c00;
        text-align: center;
        font-size: 1.5em;
    }
""") as demo:
    gr.Markdown("# [ConfliBERT-QA](https://eventdata.utdallas.edu/conflibert/)", elem_id="title")
    gr.Markdown("Compare answers between ConfliBERT, BERT, and ChatGPT for English, and ConfliBERT, BETO, ConfliBERT-SQAC, Beto-SQAC, and ChatGPT for Spanish.")
    language = gr.Dropdown(choices=["English", "Spanish"], label="Select Language")
    context = gr.Textbox(lines=5, placeholder="Enter the context here...", label="Context")
    question = gr.Textbox(lines=2, placeholder="Enter your question here...", label="Question")
    output = gr.HTML(label="Output")
    with gr.Row():
        clear_btn = gr.Button("Clear")
        submit_btn = gr.Button("Submit")
    submit_btn.click(fn=compare_question_answering, inputs=[language, context, question], outputs=output)
    # Reset the dropdown with None rather than "": the empty string is not one
    # of its choices. The text boxes and HTML panel are reset to "".
    clear_btn.click(fn=lambda: (None, "", "", ""), inputs=[], outputs=[language, context, question, output])
    # The lines of this Markdown string are kept flush-left on purpose, so
    # the <div> can never be read as an indented code block.
    gr.Markdown("""
<div style="text-align: center; margin-top: 20px;">
Built by: <a href="https://www.linkedin.com/in/sultan-alsarra-phd-56977a63/" target="_blank">Sultan Alsarra</a>
</div>
""")

# Start the app; share=True asks Gradio for a public share link.
demo.launch(share=True)