# demo / app.py — last updated by miasambolec ("Update app.py", commit 1260695, verified)
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
from huggingface_hub import hf_hub_download
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch
import torch.nn as nn
import pickle
import numpy as np
import re
import fasttext
# Classic-ML baseline: a fitted vectorizer plus an SVM, both stored as pickles
# in the same Hub model repo.
svm_repo_id = "HighFive-OPJ/svm-sentiment-model"

# NOTE(security): pickle.load executes whatever code the pickle contains. This
# is only acceptable because the artifacts come from the project's own repo;
# consider pinning `revision=` in hf_hub_download so a later push cannot
# silently change what gets unpickled here.
svm_model_path = hf_hub_download(repo_id=svm_repo_id, filename="svm_model.pkl")
with open(svm_model_path, "rb") as f:
    svm_model = pickle.load(f)

vectorizer_path = hf_hub_download(repo_id=svm_repo_id, filename="vectorizer.pkl")
with open(vectorizer_path, "rb") as f:
    vectorizer = pickle.load(f)
# FastText word embeddings that text_to_fasttext_tensor feeds to the LSTM.
# The binary lives in a *dataset* repo, hence repo_type="dataset".
fasttext_repo_id = "HighFive-OPJ/Deep_Learning"
fasttext_path = hf_hub_download(
    repo_id=fasttext_repo_id,
    filename="FastText.bin",
    repo_type="dataset",
)
ft_model = fasttext.load_model(fasttext_path)
class LSTMClassifier(nn.Module):
    """Bidirectional LSTM sentence classifier.

    The final hidden states of the forward and backward directions are
    concatenated and projected to ``num_classes`` logits.

    Args:
        input_dim: Size of each input embedding vector.
        hidden_dim: Hidden size of each LSTM direction.
        num_classes: Number of output logits.
    """

    def __init__(self, input_dim=300, hidden_dim=256, num_classes=3):
        super().__init__()
        # The attribute names ``lstm`` and ``fc`` are the state_dict key
        # prefixes used when the saved checkpoint is loaded, so keep them.
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc = nn.Linear(2 * hidden_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)``.

        ``x`` is laid out ``(batch, seq_len, input_dim)`` because the LSTM is
        built with ``batch_first=True``.
        """
        # nn.LSTM returns (output, (h_n, c_n)); only h_n is needed.
        h_n = self.lstm(x)[1][0]
        # The last two entries of h_n are the top layer's forward and
        # backward final states.
        both_directions = torch.cat([h_n[-2], h_n[-1]], dim=1)
        return self.fc(both_directions)
lstm_repo_id = "HighFive-OPJ/lstm-sentiment-model"
lstm_model_path = hf_hub_download(repo_id=lstm_repo_id, filename="fasttext_lstm.pt")

# Shared by every model and input tensor below. CUDA_VISIBLE_DEVICES is set to
# "-1" at the top of the file, so in practice no GPU should be visible and this
# resolves to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# NOTE(security): torch.load unpickles the checkpoint. On torch versions whose
# default is weights_only=False, that can run code embedded in the file.
# The file only needs to hold a state_dict, so weights_only=True is an option.
lstm_model = LSTMClassifier()
lstm_model.load_state_dict(torch.load(lstm_model_path, map_location=device))
lstm_model.to(device).eval()  # inference only: disables dropout and similar
# Tokenizer and sequence-classification weights both come from the
# "bertic_sentiment" Hub repo.
bert_repo_id = "HighFive-OPJ/bertic_sentiment"
bert_tokenizer = AutoTokenizer.from_pretrained(bert_repo_id)
bert_model = AutoModelForSequenceClassification.from_pretrained(bert_repo_id)
# Same device as the LSTM; eval() switches to inference mode.
bert_model.to(device).eval()
def preprocess_text(text):
    """Normalize *text* for FastText lookup.

    Lowercases the text and keeps only letters and whitespace, then strips
    leading/trailing whitespace. Digits, underscores and punctuation are
    removed.

    Letters are matched Unicode-aware, so diacritics such as č, ć, š, ž and đ
    survive. An ASCII-only letter filter deletes them mid-word ("čovjek"
    would become "ovjek"), which corrupts the FastText lookups for
    non-English (e.g. Croatian) reviews.

    NOTE(review): if the LSTM was trained on text cleaned with an ASCII-only
    filter, confirm the training preprocessing matches this one.

    Args:
        text: Raw input string.

    Returns:
        The cleaned, lowercased string.
    """
    text = text.lower()
    # For str patterns ``\w`` is Unicode-aware. It also covers digits and "_",
    # so those are removed explicitly, leaving only letters and whitespace.
    text = re.sub(r"[^\w\s]|[\d_]", "", text).strip()
    return text
def text_to_fasttext_tensor(text, max_len=200):
    """Embed *text* as a fixed-length, zero-padded FastText sequence.

    The text is cleaned with ``preprocess_text`` and split on whitespace. At
    most ``max_len`` tokens are embedded with ``ft_model.get_word_vector``.
    Any remaining rows stay zero, so padding is appended at the end.

    Args:
        text: Raw input string.
        max_len: Output sequence length. Longer inputs are truncated and
            shorter ones are zero-padded.

    Returns:
        A float32 tensor of shape ``(1, max_len, dim)`` on ``device``, where
        ``dim`` is the FastText vector size.
    """
    tokens = preprocess_text(text).split()[:max_len]
    # Fill one preallocated float32 array instead of building a list of
    # per-token ndarrays. torch.tensor() on such a list is very slow (PyTorch
    # warns about it), and the old padding rows were float64.
    embeddings = np.zeros((max_len, ft_model.get_dimension()), dtype=np.float32)
    for row, token in enumerate(tokens):
        embeddings[row] = ft_model.get_word_vector(token)
    # unsqueeze(0) adds the batch dimension the LSTM expects.
    return torch.from_numpy(embeddings).unsqueeze(0).to(device)
def predict_with_svm(text):
    """Classify *text* with the pickled vectorizer + SVM pair.

    The raw text is passed straight to ``vectorizer.transform``; no call to
    ``preprocess_text`` is made here.

    Returns:
        The SVM's predicted label for the single input, as a Python ``int``.
    """
    features = vectorizer.transform([text])
    label = svm_model.predict(features)[0]
    return int(label)
def predict_with_lstm(text):
    """Classify *text* with the FastText-embedding bidirectional LSTM.

    Returns:
        Index of the highest-scoring class, as a Python ``int``.
    """
    with torch.no_grad():
        logits = lstm_model(text_to_fasttext_tensor(text))
    return logits.argmax(dim=1).item()
def predict_with_bert(text):
    """Classify *text* with the BERTić model and remap the label.

    The model's argmax label is mapped onto the app's shared scheme
    (0 = positive, 1 = negative, 2 = neutral): labels <= 2 become 0, label 3
    becomes 1, and anything larger becomes 2. Input is truncated to 512
    tokens.

    NOTE(review): this remapping only makes sense if the checkpoint has more
    than three labels. If it has three (as the LSTM's ``num_classes=3``
    suggests for this task), every prediction collapses to 0. Confirm against
    ``bert_model.config.id2label``.
    """
    encoded = bert_tokenizer(
        [text],
        padding=True,
        truncation=True,
        max_length=512,
        return_tensors="pt",
    ).to(device)
    with torch.no_grad():
        logits = bert_model(**encoded).logits
    label = logits.argmax(dim=-1)[0].item()
    if label > 3:
        return 2
    if label == 3:
        return 1
    return 0
def _safe_predict(predict, text):
    """Return ``predict(text)``, or ``"Error: <message>"`` if it raises."""
    try:
        return predict(text)
    except Exception as e:
        # Deliberately broad: one failing model is reported in its own output
        # box instead of aborting the other two.
        return f"Error: {str(e)}"


def analyze_sentiment(text):
    """Run all three classifiers on *text* and format the results for the UI.

    Each integer label (0 = positive, 1 = negative, 2 = neutral) is rendered
    with ``convert_to_stars``. A model that raised is shown as
    ``"Error: <message>"`` instead.

    The statistics line reports the mean *star rating* of the models that
    succeeded. Averaging the raw labels would be meaningless because they are
    nominal, not ordered: positive (0) and neutral (2) would average to 1,
    which means "negative". Star ratings (5 / 3 / 1) are ordered, so their
    mean is interpretable.

    Returns:
        A 4-tuple of strings ``(svm, lstm, bert, stats)``, one per output box.
    """
    results = (
        _safe_predict(predict_with_svm, text),
        _safe_predict(predict_with_lstm, text),
        _safe_predict(predict_with_bert, text),
    )
    labels = [r for r in results if isinstance(r, int)]

    if labels:
        # Reuse convert_to_stars' mapping rather than duplicating it.
        star_counts = [convert_to_stars(label).count("★") for label in labels]
        stats = (
            f"Average rating: {np.mean(star_counts):.2f} / 5 ★ "
            f"({len(labels)} of {len(results)} models)\n"
        )
    else:
        stats = "No model produced a prediction.\n"

    formatted = [convert_to_stars(r) if isinstance(r, int) else r for r in results]
    return (*formatted, stats)
def convert_to_stars(score):
    """Render a sentiment label as a five-character star string.

    Mapping: 0 (positive) -> 5 stars, 1 (negative) -> 1 star,
    2 (neutral) -> 3 stars. Any other value falls back to 3 stars.
    """
    filled = {0: 5, 1: 1, 2: 3}.get(score, 3)
    return ("★" * filled).ljust(5, "☆")
def process_input(text):
    """Click handler: validate *text*, then run ``analyze_sentiment``.

    Returns:
        A 4-tuple of strings for the SVM, LSTM, BERTić and statistics boxes.
        Input that is ``None``, empty or whitespace-only yields three empty
        results plus a prompt. An unexpected error yields ``"error"`` in the
        three model boxes and the exception message in the statistics box.
    """
    # Guard against None as well as blank strings: calling .strip() on None
    # would raise AttributeError outside the try below.
    if text is None or not text.strip():
        return ("", "", "", "Please enter valid text.")
    try:
        return analyze_sentiment(text)
    except Exception as e:
        # Last-resort UI boundary: show the failure instead of crashing.
        error_message = f"Error during sentiment analysis:\n{str(e)}"
        return ("error", "error", "error", error_message)
with gr.Blocks() as demo:
    gr.Markdown("# Sentiment Analysis Demo")
    # Blank lines and "-" markers matter in Markdown. Without them the rating
    # guide becomes a lazy continuation of the last bullet, and its three
    # lines run together into one.
    gr.Markdown(
        "Enter a review and see how different models evaluate its sentiment! "
        "This app uses:\n"
        "\n"
        "- SVM for classic machine learning\n"
        "- LSTM for deep learning (using FastText)\n"
        "- BERTić for transformer-based analysis\n"
        "\n"
        "Rating guide:\n"
        "\n"
        "- 5 ★ → positive sentiment\n"
        "- 3 ★ → neutral sentiment\n"
        "- 1 ★ → negative sentiment\n"
    )
    with gr.Row():
        # Left column: input and trigger.
        with gr.Column():
            input_text = gr.Textbox(label="Enter your review:", lines=3)
            analyze_button = gr.Button("Analyze Sentiment")
        # Right column: one read-only box per model plus the summary.
        with gr.Column():
            svm_output = gr.Textbox(label="SVM", interactive=False)
            lstm_output = gr.Textbox(label="LSTM", interactive=False)
            bert_output = gr.Textbox(label="BERTić", interactive=False)
            stats_output = gr.Textbox(label="Statistics", interactive=False)
    # The output order must match the 4-tuple returned by process_input.
    analyze_button.click(
        process_input,
        inputs=[input_text],
        outputs=[svm_output, lstm_output, bert_output, stats_output],
    )
demo.launch()