Update app.py
Browse files
app.py
CHANGED
|
@@ -6,11 +6,11 @@ import torch
|
|
| 6 |
from transformers import RobertaTokenizerFast, RobertaForSequenceClassification
|
| 7 |
|
| 8 |
# ---------------------------
|
| 9 |
-
# Load your model
|
| 10 |
# ---------------------------
|
| 11 |
-
|
| 12 |
-
tokenizer = RobertaTokenizerFast.from_pretrained(
|
| 13 |
-
model = RobertaForSequenceClassification.from_pretrained(
|
| 14 |
model.eval()
|
| 15 |
|
| 16 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
@@ -20,7 +20,13 @@ model.to(device)
|
|
| 20 |
# Util: classify a single text
|
| 21 |
# ---------------------------
|
| 22 |
def predict_label(text: str):
|
| 23 |
-
inputs = tokenizer(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
with torch.no_grad():
|
| 25 |
outputs = model(**inputs)
|
| 26 |
logits = outputs.logits
|
|
@@ -84,14 +90,14 @@ def clear_chat():
|
|
| 84 |
return render_chat([]), []
|
| 85 |
|
| 86 |
# ---------------------------
|
| 87 |
-
# Custom CSS (Pink + Blue ONLY
|
| 88 |
# ---------------------------
|
| 89 |
CSS = """
|
| 90 |
* { box-sizing: border-box; }
|
| 91 |
:root {
|
| 92 |
-
--bg-gradient: linear-gradient(135deg, #ff99cc, #66ccff);
|
| 93 |
-
--bubble-a: #ff66b2;
|
| 94 |
-
--bubble-b: #3399ff;
|
| 95 |
--text-light: #f9f9f9;
|
| 96 |
--chip-safe: #00e676;
|
| 97 |
--chip-unsafe: #ff5252;
|
|
|
|
| 6 |
from transformers import RobertaTokenizerFast, RobertaForSequenceClassification
|
| 7 |
|
| 8 |
# ---------------------------
|
| 9 |
+
# Load your model (conversation learning style)
|
| 10 |
# ---------------------------
|
| 11 |
+
MODEL_PATH = "Alifjo123/robertaBase_messaging_100k" # your HuggingFace model
|
| 12 |
+
tokenizer = RobertaTokenizerFast.from_pretrained(MODEL_PATH)
|
| 13 |
+
model = RobertaForSequenceClassification.from_pretrained(MODEL_PATH)
|
| 14 |
model.eval()
|
| 15 |
|
| 16 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 20 |
# Util: classify a single text
|
| 21 |
# ---------------------------
|
| 22 |
def predict_label(text: str):
|
| 23 |
+
inputs = tokenizer(
|
| 24 |
+
text,
|
| 25 |
+
truncation=True,
|
| 26 |
+
padding=True,
|
| 27 |
+
max_length=512,
|
| 28 |
+
return_tensors="pt"
|
| 29 |
+
).to(device)
|
| 30 |
with torch.no_grad():
|
| 31 |
outputs = model(**inputs)
|
| 32 |
logits = outputs.logits
|
|
|
|
| 90 |
return render_chat([]), []
|
| 91 |
|
| 92 |
# ---------------------------
|
| 93 |
+
# Custom CSS (Pink + Blue ONLY)
|
| 94 |
# ---------------------------
|
| 95 |
CSS = """
|
| 96 |
* { box-sizing: border-box; }
|
| 97 |
:root {
|
| 98 |
+
--bg-gradient: linear-gradient(135deg, #ff99cc, #66ccff);
|
| 99 |
+
--bubble-a: #ff66b2;
|
| 100 |
+
--bubble-b: #3399ff;
|
| 101 |
--text-light: #f9f9f9;
|
| 102 |
--chip-safe: #00e676;
|
| 103 |
--chip-unsafe: #ff5252;
|