# NOTE: removed web-scrape residue (Hugging Face Spaces page chrome: "Spaces",
# "Sleeping", file size, commit hash, and a dumped line-number gutter) that was
# not part of the module source.
from typing import Dict, Any, Optional
from langgraph_config.langgraph_state import OverallState
from langgraph.config import get_stream_writer
# Hugging Face checkpoint used for Thai sentiment classification.
HF_MODEL_NAME = "FlukeTJ/distilbert-base-thai-sentiment"
# Minimum top-class probability required before the model's label is trusted.
CONF_THRESH = 0.50
# Margin for re-interpreting near-Neutral predictions (see analyze_sentiment_thai).
NEU_MARGIN = 0.07
# Thai insult/profanity keywords; any occurrence short-circuits to "Negative".
TOXIC_WORDS = {
    "โง่", "ควาย", "ห่วย", "มั่ว", "ไร้สาระ", "เลิกพูดมาก", "อย่ามากวน",
    "หน้าบูด", "ป้า", "ส้นตีน", "เหี้ย", "สัส", "ควย", "ปัญญาอ่อน", "ทุเรศ", "ขยะ",
}
# Thai praise/compliment keywords; any occurrence short-circuits to "Positive".
POSITIVE_CUES = {
    "นับถือ", "ชื่นชม", "ภูมิใจ", "สุดยอด", "เก่งมาก", "ดีมาก", "ยอดเยี่ยม",
    "ขอบคุณ", "ขอบคุณมาก", "ขอบคุณครับ", "ขอบคุณค่ะ",
    "สวย", "สวยจัง", "สวยมาก", "เรียบหรู", "หรู", "สมกับ", "เหมาะกับ",
    "น่ารัก", "ใจดี", "อบอุ่น",
}
# Thai caring/concern phrases (e.g. "take care of yourself"); also treated as Positive.
CARE_CUES = {
    "เดี๋ยวจะไม่สบาย", "ระวังจะไม่สบาย", "ระวังเป็นหวัด", "พักผ่อน", "ดูแลตัวเอง",
    "หนาวจัง", "เอาเสื้อคลุม", "คลุมไหล่", "ห่ม", "ดื่มน้ำอุ่น",
}
# Module-level lazy singletons populated by get_model() on first use.
TOKENIZER = None
MODEL = None
DEVICE = None
def get_model():
    """Lazily load and cache the HF tokenizer, model, and torch device.

    Populates the module-level singletons TOKENIZER / MODEL / DEVICE on first
    call; subsequent calls return the cached objects.

    Returns:
        tuple: (tokenizer, model, device) with the model in eval mode.

    Raises:
        RuntimeError: if `torch` or `transformers` is not installed.
    """
    global TOKENIZER, MODEL, DEVICE
    # Fast path: everything already loaded.
    # BUGFIX: original checked the undefined name `_TOKENIZER`, which raised
    # NameError on every call after the first successful load (the first call
    # survived only because `MODEL is not None` short-circuited the `and`).
    if MODEL is not None and TOKENIZER is not None and DEVICE is not None:
        return TOKENIZER, MODEL, DEVICE
    try:
        import torch
        from transformers import AutoModelForSequenceClassification, PreTrainedTokenizerFast
    except Exception as e:
        raise RuntimeError("Missing deps. Install: pip install -U transformers torch") from e
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    TOKENIZER = PreTrainedTokenizerFast.from_pretrained(HF_MODEL_NAME)
    MODEL = AutoModelForSequenceClassification.from_pretrained(HF_MODEL_NAME).to(DEVICE)
    MODEL.eval()  # inference only: disables dropout etc.
    return TOKENIZER, MODEL, DEVICE
def toxicity_gate(text: str) -> Optional[str]:
    """Keyword gate: "Negative" if *text* contains any known toxic word, else None.

    Blank or None input also yields None (no opinion).
    """
    stripped = (text or "").strip()
    if not stripped:
        return None
    return "Negative" if any(word in stripped for word in TOXIC_WORDS) else None
def positive_gate(text: str) -> Optional[str]:
    """Keyword gate: "Positive" for praise/care cues or a polite offer, else None.

    Blank or None input yields None (no opinion).
    """
    stripped = (text or "").strip()
    if not stripped:
        return None
    # Direct praise or caring phrases.
    for cue_set in (POSITIVE_CUES, CARE_CUES):
        if any(cue in stripped for cue in cue_set):
            return "Positive"
    # Polite-offer heuristic: a Thai question particle (or "เดี๋ยว") combined
    # with a helping verb reads as a kind offer.
    offer_markers = ("ไหมครับ", "ไหมคะ", "ไหมค่ะ", "เดี๋ยว")
    helper_words = ("ช่วย", "เอา", "ให้", "ไว้")
    has_offer = any(marker in stripped for marker in offer_markers)
    if has_offer and any(helper in stripped for helper in helper_words):
        return "Positive"
    return None
def predict_with_probs(text: str):
    """Run the HF sentiment classifier on *text*.

    Returns:
        tuple: (label, dist) where label is the argmax of dist and dist maps
        "Positive"/"Neutral"/"Negative" to softmax probabilities.
    """
    t = (text or "").strip()
    if not t:
        # Empty input: degenerate distribution with all mass on Neutral.
        return "Neutral", {"Positive": 0.0, "Neutral": 1.0, "Negative": 0.0}
    import torch
    tokenizer, model, device = get_model()
    inputs = tokenizer(
        t,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=256,
    )
    # DistilBERT takes no token_type_ids; drop them if the tokenizer emitted any.
    inputs.pop("token_type_ids", None)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    with torch.no_grad():
        out = model(**inputs)
    # Single-example batch -> take row 0 of the softmaxed logits.
    probs = torch.softmax(out.logits, dim=-1)[0].detach().cpu().tolist()
    # NOTE(review): assumes the checkpoint's label order is
    # [Positive, Neutral, Negative] — verify against the model's id2label config.
    dist = {"Positive": float(probs[0]), "Neutral": float(probs[1]), "Negative": float(probs[2])}
    label = max(dist, key=dist.get)
    return label, dist
def analyze_sentiment_thai(text: str) -> str:
    """Classify Thai *text* as "Positive", "Neutral", or "Negative".

    Precedence: toxic-keyword gate, then positive-cue gate, then the HF model
    with confidence thresholding and a near-Neutral tie-break. Any model
    failure (or empty input) degrades to "Neutral" — best-effort by design.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return "Neutral"
    # Keyword gates win over the model; toxicity is checked first.
    gated = toxicity_gate(cleaned) or positive_gate(cleaned)
    if gated:
        return gated
    try:
        _, dist = predict_with_probs(cleaned)
        top_label, top_prob = max(dist.items(), key=lambda kv: kv[1])
        if top_prob < CONF_THRESH:
            return "Neutral"
        if top_label != "Neutral":
            return top_label
        # Near-Neutral tie-break: promote Positive/Negative when it trails
        # Neutral by at most NEU_MARGIN and beats the opposite polarity.
        pos_p = dist["Positive"]
        neu_p = dist["Neutral"]
        neg_p = dist["Negative"]
        if neg_p >= neu_p - NEU_MARGIN and neg_p > pos_p:
            return "Negative"
        if pos_p >= neu_p - NEU_MARGIN and pos_p > neg_p:
            return "Positive"
        return "Neutral"
    except Exception:
        # Deliberate best-effort fallback: never let model errors break the node.
        return "Neutral"
def score_delta_from_label(label: str) -> int:
    """Map a sentiment label to a score adjustment.

    "Positive" -> +2, "Negative" -> -3, anything else (incl. "Neutral") -> +1.
    """
    return {"Positive": 2, "Negative": -3}.get(label, 1)
class SentimentNode:
    """LangGraph node that scores the incoming question with Thai sentiment analysis."""

    def __init__(self):
        # Stateless node: all model state lives at module level.
        pass

    def run(self, state: OverallState) -> Dict[str, Any]:
        """Analyze state["question"], emit a "sentiment" stream event, and return results.

        Returns:
            dict with "sentiment" (label) and "score" (int delta for that label).
        """
        label = analyze_sentiment_thai(state.get("question", ""))
        delta = score_delta_from_label(label)
        writer = get_stream_writer()
        if writer:
            event = {
                "type": "sentiment",
                "task_id": state.get("task_id", ""),
                "session_id": state.get("session_id", ""),
                "sentiment": label,
                "score_delta": delta,
            }
            writer(event)
        return {"sentiment": label, "score": int(delta)}