Spaces:
Running
Running
| import os | |
| import time | |
| from datetime import datetime | |
| from typing import Any, Sequence | |
| import firebase_admin | |
| import gradio as gr | |
| import pytz | |
| from dotenv import load_dotenv | |
| from firebase_admin import credentials, firestore | |
| import tensorflow as tf | |
| import torch | |
| from gradio import CSVLogger, FlaggingCallback | |
| from gradio.components import Component | |
| from transformers import DebertaV2Tokenizer, TFAutoModelForSequenceClassification, AutoModelForSequenceClassification | |
# Select the inference back end: True -> TensorFlow model, False -> PyTorch model.
USE_TENSORFLOW = True
# Device for the PyTorch back end; ask() moves its encoded inputs here, so the
# PyTorch model must live here too.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Answer label -> index into the model's softmax output (see ask()).
CLASSES = {
    'yes': 0,
    'irrelevant': 1,
    'no': 2,
}
tokenizer = DebertaV2Tokenizer.from_pretrained('cross-encoder/nli-deberta-v3-base', do_lower_case=True)
if USE_TENSORFLOW:
    model = TFAutoModelForSequenceClassification.from_pretrained(
        'MrPio/TheSeagullStory-nli-deberta-v3-base', dtype=tf.float16)
else:
    model = AutoModelForSequenceClassification.from_pretrained('MrPio/TheSeagullStory-nli-deberta-v3-base')
    model.eval()
    # Without this, CUDA inputs (moved in ask()) would meet CPU weights and the
    # forward pass would fail with a device mismatch.
    model.to(device)
    if torch.cuda.is_available():
        # Half precision only on GPU.
        model.half()
# Story text fed to the model alongside every question. Newlines are collapsed
# into spaces (a blank line between paragraphs becomes a single space).
# The `with` block closes the file handle; explicit UTF-8 avoids depending on the
# platform's default encoding.
with open('story.txt', encoding='utf-8') as story_file:
    story = story_file.read().replace("\n\n", "\n").replace("\n", " ").strip()
# Pull variables from a local .env file (if present) into os.environ.
load_dotenv()
# Only the private key is read from the environment; the other service-account
# fields below are hardcoded.
_private_key = os.environ.get("PRIVATE_KEY")
if not _private_key:
    # Fail fast with a clear message instead of an opaque error inside Certificate().
    raise RuntimeError("PRIVATE_KEY environment variable is not set; cannot initialise Firebase.")
# A PEM key stored on one line in an env/secret store may carry literal "\n"
# escapes; restore real newlines (no-op when the key already contains them).
_private_key = _private_key.replace("\\n", "\n")
cred = credentials.Certificate({
    "type": "service_account",
    "project_id": "scheda-dnd",
    "private_key": _private_key,
    "private_key_id": "948666ca297742d06eebd6a97f77f750d033c208",
    "client_email": "firebase-adminsdk-is4pg@scheda-dnd.iam.gserviceaccount.com",
    "client_id": "105104335855166557589",
    "auth_uri": "https://accounts.google.com/o/oauth2/auth",
    "token_uri": "https://oauth2.googleapis.com/token",
    "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
    "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/firebase-adminsdk-is4pg%40scheda-dnd.iam.gserviceaccount.com",
    "universe_domain": "googleapis.com",
})
firebase_admin.initialize_app(cred)
# Firestore client used by Flagger to store flagged answers.
db = firestore.client()
def ask(question):
    """Score *question* against the story and return per-answer probabilities.

    The story and the question are tokenized together as a sequence pair (story
    first) and passed through the sequence-classification model selected by
    USE_TENSORFLOW.

    Args:
        question: The player's question, phrased as an affirmative sentence.

    Returns:
        dict mapping each key of CLASSES ('yes', 'irrelevant', 'no') to its softmax
        probability rounded to 3 decimals, as a plain Python float in both back ends.
    """
    encoded = tokenizer(story, question, truncation=True, padding=True,
                        return_tensors='tf' if USE_TENSORFLOW else 'pt')
    if USE_TENSORFLOW:
        output = model(encoded, training=False)
        probabilities = tf.nn.softmax(output.logits, axis=-1).numpy().squeeze()
        # float() converts numpy scalars to Python floats, matching the PyTorch branch.
        return {c: round(float(probabilities[i]), 3) for c, i in CLASSES.items()}

    # Inputs must be on the same device as the model.
    encoded = {key: value.to(device) for key, value in encoded.items()}
    # Inference only: disable autograd so no graph/activations are kept for backprop.
    with torch.no_grad():
        output = model(**encoded)
    probabilities = torch.softmax(output.logits, 1).squeeze()
    return {c: round(probabilities[i].item(), 3) for c, i in CLASSES.items()}
class Flagger(FlaggingCallback):
    """Gradio flagging callback that logs every flag to CSV and mirrors some to Firestore.

    Every flag is forwarded to a gradio ``CSVLogger``. When the flagged question is a
    string longer than ``MIN_QUESTION_LENGTH`` characters and the flagged output is a
    dict containing ``'confidences'``, the flag is also stored as a document in the
    Firestore collection ``other_apps/seagull_story/seagull_story_flags``, with the
    document id set to ``time.time_ns()``.
    """

    # Questions with this many characters or fewer are not uploaded to Firestore.
    MIN_QUESTION_LENGTH = 3

    def __init__(self):
        super().__init__()
        self.base_logger = CSVLogger()
        self.flags_collection = db.collection("other_apps/seagull_story/seagull_story_flags")

    def setup(self, components: Sequence[Component], flagging_dir: str):
        """Prepare the underlying CSV logger for *components* inside *flagging_dir*."""
        self.base_logger.setup(components=components, flagging_dir=flagging_dir)

    def flag(self, flag_data: list[Any], flag_option: str | None = None, username: str | None = None) -> int:
        """Record one flag.

        Args:
            flag_data: Component values; index 0 is the question text and index 1
                the Label output.
            flag_option: The flag button the user chose (one of the interface's
                flagging_options).
            username: Name of the flagging user, if any.

        Returns:
            Whatever ``CSVLogger.flag`` returns for this flag.
        """
        question, prediction = flag_data[0], flag_data[1]
        # Type guards: the question or Label value may be None (or an unexpected
        # type); a bare len()/`in` check would raise instead of skipping the upload.
        if (
            isinstance(question, str)
            and len(question) > self.MIN_QUESTION_LENGTH
            and isinstance(prediction, dict)
            and 'confidences' in prediction
        ):
            self.flags_collection.document(str(time.time_ns())).set({
                "question": question,
                "prediction": prediction.get('label'),
                "confidences": prediction['confidences'],
                "flag": flag_option,
                "timestamp": datetime.now(pytz.utc),
                "username": username,
            })
        return self.base_logger.flag(flag_data=flag_data, flag_option=flag_option, username=username)
# Gradio UI: one question textbox in, a 3-class Label out, with manual flagging
# handled by Flagger. NOTE: the module-level name `gradio` shadows the package name
# (imported here as `gr`); it is kept as-is so existing references keep working.
gradio = gr.Interface(
    ask,
    inputs=[gr.Textbox(value="", label="Your question, as an affirmative sentence:")],
    outputs=[gr.Label(label="Answer", num_top_classes=3)],
    title="The Seagull Story",
    flagging_mode='manual',
    flagging_callback=Flagger(),
    flagging_options=['Yes', 'No', 'Irrelevant'],
    description="“ Albert and Dave find themselves on the pier. They go to a nearby restaurant where Albert orders "
                "seagull meat. The waiter promptly serves Albert the meal. After taking a bite, he realizes "
                "something. Albert pulls a gun out of his ruined jacket and shoots himself. ”\n\nWhy did Albert shoot "
                "himself?\n\nCan you unravel the truth behind this epilogue by asking only yes/no questions?\n\nPlease be specific about the time period you have in mind with your question.",
    article='Please refrain from embarrassing DeBERTa with dumb questions.\n\nCheck the repository for more detail: https://github.com/MrPio/The-Seagull-Story',
    # Example questions shown under the input box (phrased as affirmative sentences).
    examples=['Albert shot himself for a reason',
              'Dave has a watch on his wrist',
              'Albert and Dave came to the pier on their own']
)
if __name__ == "__main__":
    gradio.launch(share=True)