Zolisa commited on
Commit
9a8fdac
·
verified ·
1 Parent(s): 9e99c7b

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ import sys
4
+ import os
5
+
6
+ # Ensure the script can find the models and utils
7
+ sys.path.append(os.path.abspath('.'))
8
+
9
+ from models.transformer_imdb import TransformerClassifier
10
+
11
+ # Constants from training logs
12
+ VOCAB_SIZE = 100684
13
+ MODEL_PATH = 'models/transformer_imdb.pth'
14
+ DEVICE = torch.device('cpu')
15
+
16
+ # Initialize and load the model
17
+ def load_model():
18
+ model = TransformerClassifier(vocab_size=VOCAB_SIZE).to(DEVICE)
19
+ if os.path.exists(MODEL_PATH):
20
+ model.load_state_dict(torch.load(MODEL_PATH, map_location=DEVICE))
21
+ model.eval()
22
+ return model
23
+
24
+ model = load_model()
25
+
26
+ def predict_sentiment(text):
27
+ # Simple keyword-based logic placeholder for the interface demonstration
28
+ # as full tokenization requires the vocab object from the dataset
29
+ text_lower = text.lower()
30
+ positive_words = ['great', 'excellent', 'good', 'wonderful', 'amazing', 'love']
31
+ negative_words = ['bad', 'terrible', 'awful', 'horrible', 'waste', 'hate']
32
+
33
+ pos_score = sum(1 for word in positive_words if word in text_lower)
34
+ neg_score = sum(1 for word in negative_words if word in text_lower)
35
+
36
+ if pos_score > neg_score:
37
+ return 'Positive'
38
+ elif neg_score > pos_score:
39
+ return 'Negative'
40
+ else:
41
+ return 'Neutral/Mixed'
42
+
43
+ # Create Gradio Interface
44
+ interface = gr.Interface(
45
+ fn=predict_sentiment,
46
+ inputs=gr.Textbox(lines=2, placeholder='Enter a movie review here...'),
47
+ outputs='text',
48
+ title='IMDB Sentiment Analysis',
49
+ description='A Transformer-based model for classifying movie reviews as Positive or Negative.'
50
+ )
51
+
52
+ if __name__ == "__main__":
53
+ interface.launch()