File size: 2,304 Bytes
8c5f9df 72f5fca e177968 72f5fca cc16683 72f5fca e177968 72f5fca cc16683 4ba066a 72f5fca e177968 72f5fca e177968 72f5fca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoModelForSequenceClassification, AutoConfig, AutoTokenizer
import torch
import numpy as np
# Hugging Face Hub id of the fine-tuned bot-detection classifier.
MODEL_NAME = "URaBOT2024/debertaV3_FullFeature"
# Load pre-trained models and tokenizers
# num_labels=2: binary classification head (two output logits per input).
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels = 2)
# config is loaded but not referenced below — presumably kept for inspection/debugging.
config = AutoConfig.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# Set hardware target for model
# Prefer GPU when available; all inputs must be moved to the same device at inference time.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device)
model.eval() # Set model to evaluation mode
def verify(display_name, tweet_content, is_verified, likes):
    """Classify a tweet and return a confidence score for the positive class.

    All four features are free-text strings (Gradio Textbox inputs). They are
    joined with the tokenizer's separator token in the order
    content / name / verified / likes — presumably the format used at
    training time (TODO: confirm against the training pipeline).

    Returns:
        float in [0, 1]: sigmoid-ish confidence that the tweet belongs to
        class 1; values below 0.5 indicate class 0 was the majority class.
    """
    # Join the features with the model's separator token.
    text = tokenizer.sep_token.join([tweet_content, display_name, is_verified, likes])
    encoded = tokenizer(text, return_tensors='pt', padding=True, truncation=True).to(device)
    with torch.no_grad():
        outputs = model(**encoded)
    # BUG FIX: move logits to CPU before .numpy() — calling .numpy() on a CUDA
    # tensor raises TypeError, so the original crashed whenever a GPU was used.
    # (.detach() is redundant inside torch.no_grad(), but .cpu() keeps it safe.)
    logits = outputs.logits.cpu().numpy()
    # Element-wise sigmoid of each logit (deliberately not a softmax).
    sigmoid = (1 / (1 + np.exp(-logits))).tolist()[0]
    # Apply Platt Scaling
    # if USE_PS:
    #     sigmoid = [(1/(1+ math.exp(-(A * x + B)))) for x in sigmoid]
    # Find majority class
    label = np.argmax(logits, axis=-1).item()
    # Return sigmoid-ish value for classification. Can instead return label for strict 0/1 binary classification
    if label == 0:
        return 1 - sigmoid[0]
    else:
        return sigmoid[1]
"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
# Build the Gradio UI: four text fields feed verify(), which returns a score.
iface = gr.Interface(
    fn=verify,  # callable invoked on each submission
    inputs=[
        gr.Textbox(label="Display Name"),
        gr.Textbox(label="Tweet Content"),
        gr.Textbox(label="IsVerified"),
        gr.Textbox(label="Number of Likes"),
    ],
    outputs=gr.Textbox(),  # classification score rendered as text
    live=True,  # re-run the model as the user types
)
# Launch the app only when executed as a script (not on import).
if __name__ == "__main__":
    # share=True additionally exposes a temporary public URL for the API.
    iface.launch(share=True)