Spaces:
Sleeping
Sleeping
| import torch | |
| import tensorflow as tf | |
| from tf_keras import models, layers | |
| from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, TFAutoModelForQuestionAnswering | |
| import gradio as gr | |
| import re | |
# Pick the compute device: CUDA when PyTorch reports it available, CPU otherwise.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Checkpoint id shared by the tokenizer and the sequence-classification model.
mme_model_name = 'sperkins2116/ConfliBERT-BC-MMEs'

# Both are resolved from the same checkpoint; the model is placed on `device`
# so that tokenized inputs moved to `device` at inference time line up with it.
mme_tokenizer = AutoTokenizer.from_pretrained(mme_model_name)
mme_model = AutoModelForSequenceClassification.from_pretrained(mme_model_name)
mme_model = mme_model.to(device)

# Human-readable labels by class index (index 1 is treated as the positive
# class by the classifier code below).
class_names = ['Negative', 'Positive']
def handle_error_message(e, default_limit=512):
    """Render an inference exception as a red, bold HTML error banner.

    When the exception text contains a size-mismatch message of the form
    "The size of tensor a (N) must match the size of tensor b (M)", the
    banner quotes N as the inserted text size and M as the model limit.
    For any other exception the banner quotes ``default_limit`` instead.

    Args:
        e: The exception raised while classifying.
        default_limit: Limit to report when no sizes can be parsed from ``e``.

    Returns:
        An HTML ``<span>`` string describing the error.
    """
    size_mismatch = re.search(
        r"The size of tensor a \((\d+)\) must match the size of tensor b \((\d+)\)",
        str(e),
    )
    if size_mismatch is None:
        # No sizes available in the message: fall back to the configured limit.
        return f"<span style='color: red; font-weight: bold;'>Error: Text Input is over limit where inserted text size is larger than model limits of {default_limit}</span>"
    text_size, model_limit = size_mismatch.groups()
    return f"<span style='color: red; font-weight: bold;'>Error: Text Input is over limit where inserted text size {text_size} is larger than model limits of {model_limit}</span>"
def mme_classification(text):
    """Classify *text* as containing evidence of a multinational military exercise or not.

    The text is tokenized with truncation and padding enabled, moved to
    ``device``, and run through ``mme_model`` under ``torch.no_grad()``.
    The predicted class is the argmax of the logits. The reported confidence
    is the largest softmax probability, expressed as a percentage.

    Args:
        text: Input text to classify.

    Returns:
        An HTML ``<span>``. It is a green "Positive" message when class index 1
        wins, otherwise a red "Negative" message, each with the confidence
        to two decimals. If any step raises, the exception is passed to
        ``handle_error_message`` and that HTML is returned instead.
    """
    try:
        inputs = mme_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(device)
        with torch.no_grad():
            outputs = mme_model(**inputs)
        logits = outputs.logits
        # NOTE(review): assumes `text` is a single string, so the batch has one
        # row and .item() below reads that row's result.
        predicted_class = torch.argmax(logits, dim=1).item()
        confidence = torch.softmax(logits, dim=1).max().item() * 100
        if predicted_class == 1:  # Positive class
            result = f"<span style='color: green; font-weight: bold;'>Positive: The text contains evidence of a multinational military exercise. (Confidence: {confidence:.2f}%)</span>"
        else:  # Negative class
            result = f"<span style='color: red; font-weight: bold;'>Negative: The text does not contain evidence of a multinational military exercise. (Confidence: {confidence:.2f}%)</span>"
        return result
    except Exception as e:
        # UI boundary: show any failure as an HTML error rather than raising into the UI.
        return handle_error_message(e)
# Click handler wired to the Submit button of the Gradio layout.
def chatbot(text):
    """Classify the submitted text and return the HTML result for display."""
    html_result = mme_classification(text)
    return html_result
# Custom stylesheet passed to gr.Blocks(css=css) for the UI below.
# It styles the page body and headings, Gradio container/input/output/button
# classes, and the custom classes used inside the layout's Markdown HTML
# (header-title-center, footer, inline).
# NOTE(review): `.header` and `.gr-button` are class selectors, but the layout
# assigns elem_id="header" and elem_id="gr-button" (ids). Confirm these rules
# actually apply, or switch them to `#header` / `#gr-button`.
css = """
body {
background-color: #f0f8ff;
font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif;
color: black; /* Ensure text is visible in dark mode */
}
h1 {
color: #2e8b57;
text-align: center;
font-size: 2em;
}
h2 {
color: #ff8c00;
text-align: center;
font-size: 1.5em;
}
.gradio-container {
max-width: 100%;
margin: 10px auto;
padding: 10px;
background-color: #ffffff;
border-radius: 10px;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
.gr-input, .gr-output {
background-color: #ffffff;
border: 1px solid #ddd;
border-radius: 5px;
padding: 10px;
font-size: 1em;
color: black; /* Ensure text is visible in dark mode */
}
.gr-title {
font-size: 1.5em;
font-weight: bold;
color: #2e8b57;
margin-bottom: 10px;
text-align: center;
}
.gr-description {
font-size: 1.2em;
color: #ff8c00;
margin-bottom: 10px;
text-align: center;
}
.header {
display: flex;
justify-content: center;
align-items: center;
padding: 10px;
flex-wrap: wrap;
}
.header-title-center a {
font-size: 4em; /* Increased font size */
font-weight: bold; /* Made text bold */
color: darkorange; /* Darker orange color */
text-align: center;
display: block;
}
.gr-button {
background-color: #ff8c00;
color: white;
border: none;
padding: 10px 20px;
font-size: 1em;
border-radius: 5px;
cursor: pointer;
}
.gr-button:hover {
background-color: #ff4500;
}
.footer {
text-align: center;
margin-top: 10px;
font-size: 0.9em; /* Updated font size */
color: black; /* Ensure text is visible in dark mode */
width: 100%;
}
.footer a {
color: #2e8b57;
font-weight: bold;
text-decoration: none;
}
.footer a:hover {
text-decoration: underline;
}
.footer .inline {
display: inline;
color: black; /* Ensure text is visible in dark mode */
}
"""
# Build the single-page UI with the custom stylesheet.
# NOTE(review): the nesting below was reconstructed from a copy whose
# indentation was lost. Confirm that only the title Markdown sits inside the
# header Row and that demo.launch() is outside the Blocks context.
with gr.Blocks(css=css) as demo:
    # Header row containing the linked "ConfliBERT-MME" title.
    with gr.Row(elem_id="header"):
        gr.Markdown("<div class='header-title-center'><a href='https://eventdata.utdallas.edu/conflibert/'>ConfliBERT-MME</a></div>", elem_id="header-title-center")
    gr.Markdown("<span style='color: black;'>Provide the text for MME Classification.</span>")
    # Input textbox, an HTML output area (the classifier returns HTML spans), and the submit button.
    text_input = gr.Textbox(lines=5, placeholder="Enter the text here...", label="Text")
    output = gr.HTML(label="Output")
    submit_button = gr.Button("Submit", elem_id="gr-button")
    # On click, send the textbox value to chatbot() and render its return value in `output`.
    submit_button.click(fn=chatbot, inputs=text_input, outputs=output)
    # Footer: institution links, then author / fine-tuner credits.
    gr.Markdown("<div class='footer'><a href='https://eventdata.utdallas.edu/'>UTD Event Data</a> | <a href='https://www.utdallas.edu/'>University of Texas at Dallas</a> | <a href='https://www.wvu.edu/'>West Virginia University</a></div>")
    gr.Markdown("<div class='footer'><span class='inline'>Developed By: <a href='https://www.linkedin.com/in/sultan-alsarra-phd-56977a63/' target='_blank'>Sultan Alsarra</a> | Finetuned By: Spencer Perkins</span></div>")
# Start the app. NOTE(review): share=True requests a public share link; confirm
# this is wanted and honored in the hosting environment.
demo.launch(share=True)