Commit
·
4a4406e
1
Parent(s):
edec6bb
Update app.py
Browse files
app.py
CHANGED
|
@@ -24,16 +24,20 @@ def sentiment_analysis_DB(input):
|
|
| 24 |
pad_to_max_length=True,
|
| 25 |
return_token_type_ids=True
|
| 26 |
)
|
| 27 |
-
ids = inputs['input_ids']
|
| 28 |
-
mask = inputs['attention_mask']
|
| 29 |
-
token_type_ids = inputs["token_type_ids"]
|
|
|
|
|
|
|
| 30 |
output = model_DB(ids, mask, token_type_ids)
|
| 31 |
-
|
| 32 |
-
final_outputs =
|
|
|
|
| 33 |
if final_outputs == True:
|
| 34 |
result = 1
|
| 35 |
else:
|
| 36 |
result = 0
|
|
|
|
| 37 |
return result
|
| 38 |
|
| 39 |
# Streamlit app
|
|
|
|
| 24 |
pad_to_max_length=True,
|
| 25 |
return_token_type_ids=True
|
| 26 |
)
|
| 27 |
+
ids = torch.tensor([inputs['input_ids']]) # Convert to PyTorch tensor
|
| 28 |
+
mask = torch.tensor([inputs['attention_mask']]) # Convert to PyTorch tensor
|
| 29 |
+
token_type_ids = torch.tensor([inputs["token_type_ids"]]) # Convert to PyTorch tensor
|
| 30 |
+
|
| 31 |
+
# Assuming model_DB is a PyTorch model
|
| 32 |
output = model_DB(ids, mask, token_type_ids)
|
| 33 |
+
|
| 34 |
+
final_outputs = output[0].item() # Extract the scalar value
|
| 35 |
+
|
| 36 |
if final_outputs == True:
|
| 37 |
result = 1
|
| 38 |
else:
|
| 39 |
result = 0
|
| 40 |
+
|
| 41 |
return result
|
| 42 |
|
| 43 |
# Streamlit app
|