Update app.py
Browse files
app.py
CHANGED
|
@@ -71,6 +71,9 @@ bert_tokenizer = AutoTokenizer.from_pretrained("my_finetuned_model")
|
|
| 71 |
bert_model = AutoModelForSequenceClassification.from_pretrained("my_finetuned_model")
|
| 72 |
bert_model.eval()
|
| 73 |
|
|
|
|
|
|
|
|
|
|
| 74 |
# --- Pretvaranje teksta u indekse za CNN i GRU ---
|
| 75 |
def text_to_indices(text, max_len=100):
|
| 76 |
tokens = text.lower().split()
|
|
@@ -92,7 +95,7 @@ def predict_svm(text):
|
|
| 92 |
proba = svm_pipeline.predict_proba([text])[0]
|
| 93 |
pred = svm_pipeline.classes_[proba.argmax()]
|
| 94 |
print(f"SVM predikcija: {pred}, povjerenje: {proba.max():.2f}")
|
| 95 |
-
return f"{pred} (p={proba.max():.2f})"
|
| 96 |
|
| 97 |
def predict_cnn(text):
|
| 98 |
print(f"Predikcija CNN za tekst: {text}")
|
|
@@ -104,7 +107,7 @@ def predict_cnn(text):
|
|
| 104 |
pred = torch.argmax(probs, dim=1).item()
|
| 105 |
confidence = probs[0][pred].item()
|
| 106 |
print(f"CNN predikcija: {pred}, povjerenje: {confidence:.2f}")
|
| 107 |
-
return f"{pred} (p={confidence:.2f})"
|
| 108 |
|
| 109 |
def predict_gru(text):
|
| 110 |
print(f"Predikcija GRU za tekst: {text}")
|
|
@@ -116,7 +119,7 @@ def predict_gru(text):
|
|
| 116 |
pred = torch.argmax(probs, dim=1).item()
|
| 117 |
confidence = probs[0][pred].item()
|
| 118 |
print(f"GRU predikcija: {pred}, povjerenje: {confidence:.2f}")
|
| 119 |
-
return f"{pred} (p={confidence:.2f})"
|
| 120 |
|
| 121 |
def predict_bert(text):
|
| 122 |
print(f"Predikcija BERTić za tekst: {text}")
|
|
@@ -128,7 +131,7 @@ def predict_bert(text):
|
|
| 128 |
pred = torch.argmax(probs, dim=1).item()
|
| 129 |
confidence = probs[0][pred].item()
|
| 130 |
print(f"BERTić predikcija: {pred}, povjerenje: {confidence:.2f}")
|
| 131 |
-
return f"{pred} (p={confidence:.2f})"
|
| 132 |
|
| 133 |
# --- Gradio sučelje ---
|
| 134 |
def predict_all(text):
|
|
@@ -149,10 +152,7 @@ demo = gr.Interface(
|
|
| 149 |
gr.Textbox(label="BERTić")
|
| 150 |
],
|
| 151 |
title="Demo klasifikacije teksta",
|
| 152 |
-
description=
|
| 153 |
-
"Predikcije koriste SVM, CNN, GRU i BERTić modele.\n\n"
|
| 154 |
-
"Napomena: 0 = pozitivno, 1 = neutralno, 2 = negativno."
|
| 155 |
-
)
|
| 156 |
)
|
| 157 |
|
| 158 |
if __name__ == "__main__":
|
|
|
|
| 71 |
bert_model = AutoModelForSequenceClassification.from_pretrained("my_finetuned_model")
|
| 72 |
bert_model.eval()
|
| 73 |
|
| 74 |
+
# --- Rječnik za mapiranje oznaka ---
|
| 75 |
+
label_names = {0: 'pozitivno', 1: 'neutralno', 2: 'negativno'}
|
| 76 |
+
|
| 77 |
# --- Pretvaranje teksta u indekse za CNN i GRU ---
|
| 78 |
def text_to_indices(text, max_len=100):
|
| 79 |
tokens = text.lower().split()
|
|
|
|
| 95 |
proba = svm_pipeline.predict_proba([text])[0]
|
| 96 |
pred = svm_pipeline.classes_[proba.argmax()]
|
| 97 |
print(f"SVM predikcija: {pred}, povjerenje: {proba.max():.2f}")
|
| 98 |
+
return f"{label_names[pred]} (p={proba.max():.2f})"
|
| 99 |
|
| 100 |
def predict_cnn(text):
|
| 101 |
print(f"Predikcija CNN za tekst: {text}")
|
|
|
|
| 107 |
pred = torch.argmax(probs, dim=1).item()
|
| 108 |
confidence = probs[0][pred].item()
|
| 109 |
print(f"CNN predikcija: {pred}, povjerenje: {confidence:.2f}")
|
| 110 |
+
return f"{label_names[pred]} (p={confidence:.2f})"
|
| 111 |
|
| 112 |
def predict_gru(text):
|
| 113 |
print(f"Predikcija GRU za tekst: {text}")
|
|
|
|
| 119 |
pred = torch.argmax(probs, dim=1).item()
|
| 120 |
confidence = probs[0][pred].item()
|
| 121 |
print(f"GRU predikcija: {pred}, povjerenje: {confidence:.2f}")
|
| 122 |
+
return f"{label_names[pred]} (p={confidence:.2f})"
|
| 123 |
|
| 124 |
def predict_bert(text):
|
| 125 |
print(f"Predikcija BERTić za tekst: {text}")
|
|
|
|
| 131 |
pred = torch.argmax(probs, dim=1).item()
|
| 132 |
confidence = probs[0][pred].item()
|
| 133 |
print(f"BERTić predikcija: {pred}, povjerenje: {confidence:.2f}")
|
| 134 |
+
return f"{label_names[pred]} (p={confidence:.2f})"
|
| 135 |
|
| 136 |
# --- Gradio sučelje ---
|
| 137 |
def predict_all(text):
|
|
|
|
| 152 |
gr.Textbox(label="BERTić")
|
| 153 |
],
|
| 154 |
title="Demo klasifikacije teksta",
|
| 155 |
+
description="Predikcije koriste SVM, CNN, GRU i BERTić modele."
|
|
|
|
|
|
|
|
|
|
| 156 |
)
|
| 157 |
|
| 158 |
if __name__ == "__main__":
|