Spaces:
Sleeping
Sleeping
| from typing import Dict, Any, Optional | |
| from langgraph_config.langgraph_state import OverallState | |
| from langgraph.config import get_stream_writer | |
# Hugging Face checkpoint used for Thai sentiment classification.
HF_MODEL_NAME = "FlukeTJ/distilbert-base-thai-sentiment"
# Minimum top-class probability required before the model's label is trusted.
CONF_THRESH = 0.50
# How close a polar class may trail "Neutral" and still override it.
NEU_MARGIN = 0.07
# Thai insults/profanity: any hit short-circuits to "Negative" before the model runs.
TOXIC_WORDS = {
    "โง่", "ควาย", "ห่วย", "มั่ว", "ไร้สาระ", "เลิกพูดมาก", "อย่ามากวน",
    "หน้าบูด", "ป้า", "ส้นตีน", "เหี้ย", "สัส", "ควย", "ปัญญาอ่อน", "ทุเรศ", "ขยะ",
}
# Praise/compliment cues: any hit short-circuits to "Positive".
POSITIVE_CUES = {
    "นับถือ", "ชื่นชม", "ภูมิใจ", "สุดยอด", "เก่งมาก", "ดีมาก", "ยอดเยี่ยม",
    "ขอบคุณ", "ขอบคุณมาก", "ขอบคุณครับ", "ขอบคุณค่ะ",
    "สวย", "สวยจัง", "สวยมาก", "เรียบหรู", "หรู", "สมกับ", "เหมาะกับ",
    "น่ารัก", "ใจดี", "อบอุ่น",
}
# Caring/concern phrases, also mapped to "Positive".
CARE_CUES = {
    "เดี๋ยวจะไม่สบาย", "ระวังจะไม่สบาย", "ระวังเป็นหวัด", "พักผ่อน", "ดูแลตัวเอง",
    "หนาวจัง", "เอาเสื้อคลุม", "คลุมไหล่", "ห่ม", "ดื่มน้ำอุ่น",
}
# Lazily-initialized model cache; populated once by get_model().
TOKENIZER = None
MODEL = None
DEVICE = None
def get_model():
    """Lazily load and cache the HF sentiment tokenizer, model, and device.

    Returns:
        A ``(tokenizer, model, device)`` triple. The model is in eval mode
        and placed on CUDA when available, otherwise CPU.

    Raises:
        RuntimeError: if ``transformers``/``torch`` are not installed.
    """
    global TOKENIZER, MODEL, DEVICE
    # BUGFIX: the original guard read the undefined name `_TOKENIZER`, which
    # raised NameError on every cache-hit call after the first load.
    if MODEL is not None and TOKENIZER is not None and DEVICE is not None:
        return TOKENIZER, MODEL, DEVICE
    try:
        import torch
        from transformers import AutoModelForSequenceClassification, PreTrainedTokenizerFast
    except Exception as e:
        raise RuntimeError("Missing deps. Install: pip install -U transformers torch") from e
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    TOKENIZER = PreTrainedTokenizerFast.from_pretrained(HF_MODEL_NAME)
    MODEL = AutoModelForSequenceClassification.from_pretrained(HF_MODEL_NAME).to(DEVICE)
    MODEL.eval()  # inference only — disables dropout
    return TOKENIZER, MODEL, DEVICE
def toxicity_gate(text: str) -> Optional[str]:
    """Return "Negative" if `text` contains any known toxic word, else None.

    Empty/None input is never flagged.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return None
    return "Negative" if any(word in cleaned for word in TOXIC_WORDS) else None
def positive_gate(text: str) -> Optional[str]:
    """Return "Positive" for praise/care cues or polite offers, else None.

    Matches any phrase from POSITIVE_CUES or CARE_CUES, then falls back to
    a heuristic for polite offers (question particle + helping verb).
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return None
    if any(cue in cleaned for cue in POSITIVE_CUES | CARE_CUES):
        return "Positive"
    # Polite-offer heuristic: e.g. "...ไหมครับ" paired with "ช่วย"/"ให้" etc.
    offer_marker = any(m in cleaned for m in ("ไหมครับ", "ไหมคะ", "ไหมค่ะ", "เดี๋ยว"))
    helping_verb = any(v in cleaned for v in ("ช่วย", "เอา", "ให้", "ไว้"))
    if offer_marker and helping_verb:
        return "Positive"
    return None
def predict_with_probs(text: str):
    """Run the sentiment model on `text`.

    Returns a ``(label, distribution)`` pair where distribution maps
    "Positive"/"Neutral"/"Negative" to softmax probabilities. Blank input
    short-circuits to a fully-Neutral distribution without loading the model.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return "Neutral", {"Positive": 0.0, "Neutral": 1.0, "Negative": 0.0}
    import torch
    tokenizer, model, device = get_model()
    encoded = tokenizer(
        cleaned,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=256,
    )
    # DistilBERT-style models take no token_type_ids; drop them if the
    # tokenizer emitted any.
    encoded.pop("token_type_ids", None)
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}
    with torch.no_grad():
        logits = model(**encoded).logits
    scores = torch.softmax(logits, dim=-1)[0].detach().cpu().tolist()
    # NOTE(review): assumes the checkpoint's label order is
    # [Positive, Neutral, Negative] — confirm against the model config.
    dist = {
        "Positive": float(scores[0]),
        "Neutral": float(scores[1]),
        "Negative": float(scores[2]),
    }
    return max(dist, key=dist.get), dist
def analyze_sentiment_thai(text: str) -> str:
    """Classify Thai `text` as "Positive", "Neutral", or "Negative".

    Pipeline: toxicity word gate -> positive/care cue gate -> model
    prediction with a confidence floor (CONF_THRESH) and a margin-based
    tie-break around "Neutral" (NEU_MARGIN). Any model failure degrades
    to "Neutral" rather than raising.
    """
    cleaned = (text or "").strip()
    if not cleaned:
        return "Neutral"
    gated = toxicity_gate(cleaned) or positive_gate(cleaned)
    if gated:
        return gated
    try:
        _, dist = predict_with_probs(cleaned)
        top_label = max(dist, key=dist.get)
        if dist[top_label] < CONF_THRESH:
            return "Neutral"
        if top_label != "Neutral":
            return top_label
        pos_p, neu_p, neg_p = dist["Positive"], dist["Neutral"], dist["Negative"]
        # Near-tie with Neutral: lean toward the stronger polar class.
        if neg_p >= neu_p - NEU_MARGIN and neg_p > pos_p:
            return "Negative"
        if pos_p >= neu_p - NEU_MARGIN and pos_p > neg_p:
            return "Positive"
        return "Neutral"
    except Exception:
        # Best-effort fallback: model/download problems must not break the
        # surrounding pipeline.
        return "Neutral"
def score_delta_from_label(label: str) -> int:
    """Map a sentiment label to its score adjustment.

    "Positive" -> +2, "Negative" -> -3, anything else (incl. "Neutral") -> +1.
    """
    return {"Positive": 2, "Negative": -3}.get(label, 1)
class SentimentNode:
    """LangGraph node that scores the sentiment of the user's question."""

    def __init__(self):
        pass

    def run(self, state: OverallState) -> Dict[str, Any]:
        """Classify ``state["question"]``, emit a stream event, and return
        the partial state update.

        Returns:
            ``{"sentiment": label, "score": delta}`` where delta follows
            score_delta_from_label.
        """
        question = state.get("question", "")
        label = analyze_sentiment_thai(question)
        delta = score_delta_from_label(label)
        writer = get_stream_writer()
        if writer:
            event = {
                "type": "sentiment",
                "task_id": state.get("task_id", ""),
                "session_id": state.get("session_id", ""),
                "sentiment": label,
                "score_delta": delta,
            }
            writer(event)
        return {"sentiment": label, "score": int(delta)}