ma4389 commited on
Commit
e4f73ac
·
verified ·
1 Parent(s): 0aa165e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +95 -95
app.py CHANGED
@@ -1,95 +1,95 @@
1
- import gradio as gr
2
- import torch
3
- from transformers import T5Tokenizer
4
- import torch.nn as nn
5
- from transformers import T5EncoderModel
6
- import re
7
- from nltk.tokenize import word_tokenize
8
- from nltk.corpus import stopwords
9
- from nltk.stem import WordNetLemmatizer
10
- import nltk
11
-
12
- # Download NLTK resources (only first time)
13
- nltk.download('punkt')
14
- nltk.download('stopwords')
15
- nltk.download('wordnet')
16
- nltk.download('omw-1.4')
17
-
18
- # Initialize preprocessing tools
19
- stop_words = set(stopwords.words('english'))
20
- lemmatizer = WordNetLemmatizer()
21
-
22
- # Preprocessing function
23
- def preprocess_text(text):
24
- # Remove non-alphabet characters
25
- text = re.sub(r'[^A-Za-z\s]', '', text)
26
- # Remove URLs
27
- text = re.sub(r'http\S+|www\S+|https\S+', '', text)
28
- # Normalize whitespace
29
- text = re.sub(r'\s+', ' ', text).strip()
30
- # Lowercase
31
- text = text.lower()
32
- # Tokenize
33
- tokens = word_tokenize(text)
34
- # Remove stopwords
35
- tokens = [word for word in tokens if word not in stop_words]
36
- # Lemmatize
37
- tokens = [lemmatizer.lemmatize(word) for word in tokens]
38
- # Re-join
39
- return ' '.join(tokens)
40
-
41
- # Model class
42
- class T5_regression(nn.Module):
43
- def __init__(self):
44
- super(T5_regression, self).__init__()
45
- self.t5 = T5EncoderModel.from_pretrained("t5-base")
46
- self.fc = nn.Linear(self.t5.config.d_model, 1)
47
- self.relu = nn.ReLU()
48
-
49
- def forward(self, input_ids, attention_mask):
50
- output = self.t5(input_ids=input_ids, attention_mask=attention_mask)
51
- pooled_output = output.last_hidden_state[:, 0, :]
52
- rating = self.fc(pooled_output)
53
- return rating.squeeze(-1)
54
-
55
- # Load tokenizer and model
56
- tokenizer = T5Tokenizer.from_pretrained("t5-base")
57
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
- model = T5_regression().to(device)
59
-
60
- # Load trained weights
61
- model.load_state_dict(torch.load("best_model.pth", map_location=device))
62
- model.eval()
63
-
64
- # Prediction function
65
- def predict_rating(review_text):
66
- # Preprocess review
67
- clean_text = preprocess_text(review_text)
68
-
69
- encoding = tokenizer(
70
- clean_text,
71
- truncation=True,
72
- padding='max_length',
73
- max_length=128,
74
- return_tensors='pt'
75
- )
76
-
77
- input_ids = encoding['input_ids'].to(device)
78
- attention_mask = encoding['attention_mask'].to(device)
79
-
80
- with torch.no_grad():
81
- output = model(input_ids, attention_mask)
82
- rating = output.item()
83
-
84
- return round(rating, 1)
85
-
86
- # Gradio UI
87
- iface = gr.Interface(
88
- fn=predict_rating,
89
- inputs=gr.Textbox(lines=4, placeholder="Enter your review here..."),
90
- outputs=gr.Number(label="Predicted Rating"),
91
- title="Review Rating Predictor",
92
- description="Predicts the rating of a mobile app review using a fine-tuned T5 regression model."
93
- )
94
-
95
- iface.launch()
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import T5Tokenizer
4
+ import torch.nn as nn
5
+ from transformers import T5EncoderModel
6
+ import re
7
+ from nltk.tokenize import word_tokenize
8
+ from nltk.corpus import stopwords
9
+ from nltk.stem import WordNetLemmatizer
10
+ import nltk
11
+
12
+ # Download NLTK resources (only first time)
13
+ nltk.download('punkt_tab')
14
+ nltk.download('stopwords')
15
+ nltk.download('wordnet')
16
+ nltk.download('omw-1.4')
17
+
18
+ # Initialize preprocessing tools
19
+ stop_words = set(stopwords.words('english'))
20
+ lemmatizer = WordNetLemmatizer()
21
+
22
+ # Preprocessing function
23
+ def preprocess_text(text):
24
+ # Remove non-alphabet characters
25
+ text = re.sub(r'[^A-Za-z\s]', '', text)
26
+ # Remove URLs
27
+ text = re.sub(r'http\S+|www\S+|https\S+', '', text)
28
+ # Normalize whitespace
29
+ text = re.sub(r'\s+', ' ', text).strip()
30
+ # Lowercase
31
+ text = text.lower()
32
+ # Tokenize
33
+ tokens = word_tokenize(text)
34
+ # Remove stopwords
35
+ tokens = [word for word in tokens if word not in stop_words]
36
+ # Lemmatize
37
+ tokens = [lemmatizer.lemmatize(word) for word in tokens]
38
+ # Re-join
39
+ return ' '.join(tokens)
40
+
41
+ # Model class
42
+ class T5_regression(nn.Module):
43
+ def __init__(self):
44
+ super(T5_regression, self).__init__()
45
+ self.t5 = T5EncoderModel.from_pretrained("t5-base")
46
+ self.fc = nn.Linear(self.t5.config.d_model, 1)
47
+ self.relu = nn.ReLU()
48
+
49
+ def forward(self, input_ids, attention_mask):
50
+ output = self.t5(input_ids=input_ids, attention_mask=attention_mask)
51
+ pooled_output = output.last_hidden_state[:, 0, :]
52
+ rating = self.fc(pooled_output)
53
+ return rating.squeeze(-1)
54
+
55
+ # Load tokenizer and model
56
+ tokenizer = T5Tokenizer.from_pretrained("t5-base")
57
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
58
+ model = T5_regression().to(device)
59
+
60
+ # Load trained weights
61
+ model.load_state_dict(torch.load("best_model.pth", map_location=device))
62
+ model.eval()
63
+
64
+ # Prediction function
65
+ def predict_rating(review_text):
66
+ # Preprocess review
67
+ clean_text = preprocess_text(review_text)
68
+
69
+ encoding = tokenizer(
70
+ clean_text,
71
+ truncation=True,
72
+ padding='max_length',
73
+ max_length=128,
74
+ return_tensors='pt'
75
+ )
76
+
77
+ input_ids = encoding['input_ids'].to(device)
78
+ attention_mask = encoding['attention_mask'].to(device)
79
+
80
+ with torch.no_grad():
81
+ output = model(input_ids, attention_mask)
82
+ rating = output.item()
83
+
84
+ return round(rating, 1)
85
+
86
+ # Gradio UI
87
+ iface = gr.Interface(
88
+ fn=predict_rating,
89
+ inputs=gr.Textbox(lines=4, placeholder="Enter your review here..."),
90
+ outputs=gr.Number(label="Predicted Rating"),
91
+ title="Review Rating Predictor",
92
+ description="Predicts the rating of a mobile app review using a fine-tuned T5 regression model."
93
+ )
94
+
95
+ iface.launch()