File size: 3,456 Bytes
bae0f63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8d1fac5
 
 
 
 
 
bae0f63
 
8d1fac5
bae0f63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
text_emotion_engine.py — DistilBERT Multi-Label Text Emotion Classifier
Uses: bhadresh-savani/distilbert-base-uncased-emotion
Output: top-N emotions with calibrated confidence scores.
Runs inference in asyncio.to_thread to avoid blocking the event loop.
"""
from __future__ import annotations

import asyncio
import logging
from typing import List, Optional

from app.schemas import EmotionLabel

logger = logging.getLogger(__name__)

_pipeline = None
_load_error: Optional[str] = None


def _load_pipeline(model_name: str) -> None:
    """Load the HuggingFace text-classification pipeline into module globals.

    Called once at startup (via ``initialize``). On success ``_pipeline`` is
    populated and ``_load_error`` is cleared; on failure ``_pipeline`` stays
    ``None`` and ``_load_error`` records the message so callers (e.g. health
    checks) can surface it instead of crashing the app.

    Args:
        model_name: HuggingFace hub model id, used as a fallback when the
            bundled local copy is not present on disk.
    """
    global _pipeline, _load_error
    try:
        import os

        from transformers import pipeline as hf_pipeline

        # Prefer locally bundled weights to avoid a network download at boot.
        local_path = os.path.join("app", "ml_assets", "distilbert_model")
        source = local_path if os.path.exists(local_path) else model_name

        # Log the source actually used (original always logged local_path,
        # which was misleading when falling back to the hub model).
        logger.info("Loading DistilBERT text emotion model from %s", source)
        _pipeline = hf_pipeline(
            "text-classification",
            model=source,
            top_k=None,           # Return scores for ALL labels, not just the best
            truncation=True,
            max_length=512,
        )
        # Clear any stale error left over from a previous failed attempt,
        # so is_loaded/load_error stay consistent after a successful reload.
        _load_error = None
        logger.info("✅ DistilBERT emotion model loaded successfully.")
    except Exception as exc:  # boundary: record failure, never crash startup
        _load_error = str(exc)
        logger.error("❌ Failed to load DistilBERT model: %s", exc)


def initialize(model_name: str) -> None:
    """Eagerly load the emotion model at app startup (pre-warm).

    Thin public entry point over :func:`_load_pipeline`; does not raise on
    load failure — the loader records the error in module state instead.

    Args:
        model_name: HuggingFace model id forwarded to the loader; used as a
            fallback when the bundled local weights are absent.
    """
    _load_pipeline(model_name)


class TextEmotionEngine:
    """
    Wraps the HuggingFace DistilBERT pipeline for async use in FastAPI.
    """

    def _classify_sync(self, text: str) -> List[EmotionLabel]:
        if _pipeline is None:
            return []
        try:
            results = _pipeline(text[:512])
            if not results:
                return []
            # pipeline returns list-of-list when top_k=None
            raw = results[0] if isinstance(results[0], list) else results
            labels = [
                EmotionLabel(label=item["label"].lower(), score=round(item["score"], 4))
                for item in raw
            ]
            # Sort descending by score
            return sorted(labels, key=lambda x: x.score, reverse=True)
        except Exception as exc:
            logger.error("DistilBERT inference error: %s", exc)
            return []

    async def classify(self, text: str) -> List[EmotionLabel]:
        """
        Async wrapper — runs CPU-bound inference in a thread pool.
        Returns list of EmotionLabel sorted by confidence desc.
        """
        return await asyncio.to_thread(self._classify_sync, text)

    async def top_emotion(self, text: str) -> str:
        """Returns the single dominant emotion label."""
        labels = await self.classify(text)
        return labels[0].label if labels else "neutral"

    def summary_string(self, labels: List[EmotionLabel], top_k: int = 3) -> str:
        """
        Formats top-k labels as a string for LLM prompt injection.
        Example: "sadness(0.87), fear(0.08), anger(0.03)"
        """
        return ", ".join(
            f"{lbl.label}({lbl.score:.2f})" for lbl in labels[:top_k]
        )

    @property
    def is_loaded(self) -> bool:
        return _pipeline is not None

    @property
    def load_error(self) -> Optional[str]:
        return _load_error


# Module-level singleton: one engine instance shared by all request handlers.
text_emotion_engine = TextEmotionEngine()