File size: 2,306 Bytes
3c0333d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
"""
Modele NER medical base sur GLiNER-BioMed.
Detecte les concepts medicaux dans du texte clinique en zero-shot.
"""

import os
import logging
from typing import List, Dict, Any

from gliner import GLiNER

logger = logging.getLogger(__name__)

# Labels d entites a detecter.
# GLiNER est zero-shot : vous pouvez modifier cette liste librement.
DEFAULT_LABELS = [
    "Maladie",
    "Symptome",
    "Medicament",
    "Procedure medicale",
    "Partie du corps",
    "Examen de laboratoire",
]

# Seuil de confiance minimum (0.0 - 1.0)
DEFAULT_THRESHOLD = float(os.environ.get("NER_THRESHOLD", "0.4"))

# Modele a charger (variants disponibles : small, base, large)
MODEL_NAME = os.environ.get(
    "GLINER_MODEL", "Ihor/gliner-biomed-small-v1.0"
)


class MedicalNERModel:
    """Wrapper autour de GLiNER-BioMed pour la detection d entites medicales."""

    def __init__(self):
        logger.info("Chargement de %s ...", MODEL_NAME)
        self.model = GLiNER.from_pretrained(MODEL_NAME)
        self.labels = self._load_labels()
        self.threshold = DEFAULT_THRESHOLD
        logger.info("Modele charge. Labels: %s | Seuil: %.2f", self.labels, self.threshold)

    @staticmethod
    def _load_labels() -> List[str]:
        env_labels = os.environ.get("NER_LABELS", "")
        if env_labels:
            return [lb.strip() for lb in env_labels.split(",") if lb.strip()]
        return DEFAULT_LABELS

    def predict(
        self,
        text: str,
        labels: List[str] = None,
        threshold: float = None,
    ) -> List[Dict[str, Any]]:
        """
        Detecte les entites medicales dans le texte.

        Returns:
            Liste de dicts avec keys: start, end, text, label, score.
        """
        if not text or not text.strip():
            return []

        use_labels = labels or self.labels
        use_threshold = threshold if threshold is not None else self.threshold

        entities = self.model.predict_entities(
            text, use_labels, threshold=use_threshold
        )

        return [
            {
                "start": ent["start"],
                "end": ent["end"],
                "text": ent["text"],
                "label": ent["label"],
                "score": float(ent["score"]),
            }
            for ent in entities
        ]