slickdata committed on
Commit
610c96a
·
1 Parent(s): a4031fc

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -0
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import libraries
2
+ import os
3
+ import uuid
4
+ import pandas as pd
5
+ import numpy as np
6
+ from scipy.special import softmax
7
+ import gradio as gr
8
+
9
+ from google.colab import drive
10
+ from datasets import load_dataset
11
+ from sklearn.model_selection import train_test_split
12
+ import torch
13
+ from transformers import AutoTokenizer
14
+ from transformers import AutoConfig
15
+ from transformers import AutoModelForSequenceClassification
16
+ from transformers import TFAutoModelForSequenceClassification
17
+ from transformers import IntervalStrategy
18
+ from transformers import TrainingArguments
19
+ from transformers import EarlyStoppingCallback
20
+ from transformers import pipeline
21
+ from transformers import TrainingArguments
22
+ from transformers import Trainer
23
+ from torch import nn
24
+ from transformers import RobertaTokenizer, RobertaForSequenceClassification
25
+
26
+
27
+
28
# Hugging Face Hub repository that holds the fine-tuned sentiment model.
model_path = "slickdata/finetuned-Sentiment-classfication-ROBERTA-model"

# Initialize the tokenizer.
# NOTE(review): this is loaded from the base 'roberta-base' checkpoint, not
# from model_path — presumably the fine-tuned model kept the base vocabulary;
# confirm against the model repository.
tokenizer = RobertaTokenizer.from_pretrained('roberta-base')

# Load the configuration stored with the fine-tuned model.
# NOTE(review): `config` is not referenced by the code visible in this file.
config = AutoConfig.from_pretrained(model_path)

# Load the fine-tuned sequence-classification model.
model = AutoModelForSequenceClassification.from_pretrained(model_path)
39
+
40
+ # Define a function to preprocess the text data
41
def preprocess(text: str) -> str:
    """Mask user mentions and links in a piece of text.

    The text is split on single space characters. Each token that starts
    with ``@`` and is longer than one character is replaced by ``@user``;
    each token that starts with ``http`` is replaced by ``http``. Tokens are
    then re-joined with single spaces, so runs of spaces in the input are
    preserved.

    Args:
        text: Raw input string.

    Returns:
        The input with mentions and links replaced by placeholders.
    """
    new_text = []
    for token in text.split(" "):
        # Replace user mentions with '@user'; a bare '@' is left untouched.
        token = '@user' if token.startswith('@') and len(token) > 1 else token
        # Replace links with 'http'.
        token = 'http' if token.startswith('http') else token
        new_text.append(token)
    # Join the preprocessed tokens back into a single string.
    return " ".join(new_text)
51
+
52
+ # Define a function to perform sentiment analysis on the input text
53
def sentiment_analysis(text):
    """Score a piece of text as Negative / Neutral / Positive.

    The text is passed through ``preprocess``, tokenized with the module-level
    ``tokenizer``, and fed to the module-level ``model``. The logits of the
    single input are converted to probabilities with softmax.

    Args:
        text: Raw input string (e.g. a tweet).

    Returns:
        A dict mapping each label name to its probability as a ``float``.
    """
    # Preprocess the input text (mask mentions and links)
    text = preprocess(text)

    # Tokenize the input; truncation=True clips over-long input to the
    # tokenizer's maximum length instead of letting the model raise on it.
    encoded_input = tokenizer(text, return_tensors='pt', truncation=True)

    # Inference only: skip autograd bookkeeping for the forward pass.
    with torch.no_grad():
        output = model(**encoded_input)

    # First element of the output holds the logits; take the single row.
    scores_ = output[0][0].detach().numpy()

    # Apply softmax to obtain a probability distribution over the labels
    scores_ = softmax(scores_)

    # NOTE(review): assumes the model's class indices 0/1/2 correspond to
    # Negative/Neutral/Positive — confirm against the model's training setup.
    labels = ['Negative', 'Neutral', 'Positive']
    scores = {label: float(score) for label, score in zip(labels, scores_)}

    return scores
75
+
76
# Define a Gradio interface to interact with the model
demo = gr.Interface(
    fn=sentiment_analysis,  # Function to perform sentiment analysis
    inputs=gr.Textbox(placeholder="Write your tweet here..."),  # Text input field
    # Label output: receives the {label: probability} dict returned by
    # sentiment_analysis, so it can show confidences as well as the top label.
    outputs="label",
    # NOTE(review): `interpretation` may not be accepted by newer Gradio
    # releases — confirm against the installed version.
    interpretation="default",  # Interpretation mode
    examples=[["This is wonderful!"]])  # Example input(s) to display on the interface

# Launch the Gradio interface
demo.launch(share=True, debug=True)