File size: 2,304 Bytes
8c5f9df
 
72f5fca
 
 
e177968
 
72f5fca
 
 
 
 
 
 
 
 
 
 
 
 
 
cc16683
72f5fca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e177968
 
 
 
 
 
72f5fca
 
 
cc16683
4ba066a
72f5fca
e177968
 
72f5fca
 
e177968
 
72f5fca
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import gradio as gr
from huggingface_hub import InferenceClient
from transformers import AutoModelForSequenceClassification, AutoConfig, AutoTokenizer
import torch
import numpy as np


MODEL_NAME = "URaBOT2024/debertaV3_FullFeature"

# Load the pre-trained classifier (with a two-label head), its config and the
# matching tokenizer, all identified by MODEL_NAME.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
config = AutoConfig.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU,
# and switch the model to inference (evaluation) mode.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device).eval()



def verify(display_name, tweet_content, is_verified, likes):
    """Score one tweet with the sequence-classification model.

    The four fields are joined into a single string, separated by
    ``tokenizer.sep_token``, in the order: tweet content, display name,
    verified flag, number of likes. The string is then tokenized (with
    truncation) and run through ``model`` without gradient tracking.

    Args:
        display_name: Display name of the posting account.
        tweet_content: Text of the tweet.
        is_verified: Verified flag as entered by the user.
        likes: Number of likes as entered by the user.
        Any non-string field value is converted with ``str()`` before joining.

    Returns:
        float: A sigmoid-based score for the single input. Each of the two
        class logits is passed through the logistic sigmoid (element-wise,
        not a softmax). If class 1 has the larger logit, the result is
        ``sigmoid(logit_1)``; otherwise it is ``1 - sigmoid(logit_0)``.
        For a strict 0/1 decision, use the argmax label instead.
    """
    sep = tokenizer.sep_token
    # str() makes bool/int inputs work too; it is a no-op for the strings
    # Gradio textboxes supply.
    model_input = sep.join(
        str(field) for field in (tweet_content, display_name, is_verified, likes)
    )

    tokenized_input = tokenizer(
        model_input, return_tensors="pt", padding=True, truncation=True
    ).to(device)
    with torch.no_grad():
        outputs = model(**tokenized_input)

    # Bring the logits back to host memory before converting to numpy:
    # Tensor.numpy() raises for CUDA tensors, and `device` may be a GPU.
    # [0] selects the single example in the batch.
    logits = outputs.logits.detach().cpu().numpy()[0]

    # Element-wise logistic sigmoid of each class logit.
    sigmoid = (1 / (1 + np.exp(-logits))).tolist()

    # Majority class (ties resolve to class 0, as np.argmax picks the first).
    label = int(np.argmax(logits))

    if label == 0:
        return 1 - sigmoid[0]
    return sigmoid[1]



# For information on how to customize the Interface, see the Gradio docs:
# https://www.gradio.app/docs/gradio/interface
# Set up the Gradio Interface
iface = gr.Interface(
    fn=verify,  # Receives the four textbox values positionally, in the order below
    inputs=[
        gr.Textbox(label="Display Name"),
        gr.Textbox(label="Tweet Content"),
        gr.Textbox(label="IsVerified"),
        gr.Textbox(label="Number of Likes"),
    ],
    outputs=gr.Textbox(),  # Shows the numeric score returned by verify()
    live=True,  # Re-run verify() automatically whenever an input changes
)

# Launch the Gradio app when run as a script. No server port is passed here,
# so Gradio's default port selection applies.
if __name__ == "__main__":
    # share=True asks Gradio to also create a public share link.
    iface.launch(share=True)