File size: 2,873 Bytes
b7f3196
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import json
from dataclasses import asdict
from typing import List, Dict, Any, Optional

from sentence_transformers import SentenceTransformer
from classifier.head import ClassifierHead
from classifier.infer import predict_query
from classifier.utils import get_models
from retriever import Retriever
from team.candidates import get_candidates, _available
from config import settings

class HealthQueryPipeline:
    """Classify a user query and, when it is labelled "medical", retrieve candidates.

    Models and the retriever are not loaded by the constructor; they are loaded
    on the first call to :meth:`predict` or explicitly via :meth:`initialize`.
    """

    def __init__(self, use_reranker: bool = False):
        """Create an uninitialized pipeline.

        Args:
            use_reranker: Forwarded to ``Retriever`` when it is built in
                :meth:`initialize`.
        """
        self.use_reranker = use_reranker
        self.embedding_model: Optional[SentenceTransformer] = None
        self.classifier: Optional[ClassifierHead] = None
        self.retriever: Optional[Retriever] = None
        self.is_initialized = False

    def initialize(self):
        """Load the models and build the retriever; a no-op once it has succeeded.

        Raises:
            RuntimeError: If ``_available(settings.CORPORA_CONFIG)`` yields no
                corpora.
        """
        if self.is_initialized:
            return

        # NOTE(review): the message names settings.MODEL_NAME while get_models is
        # called with settings.CLASSIFIER_NAME — presumably get_models resolves the
        # embedding model on its own; confirm the log line is accurate.
        print(f"Loading embedding model: {settings.MODEL_NAME}...")
        self.embedding_model, self.classifier = get_models(model_id=settings.CLASSIFIER_NAME)
        print("Model loaded.")

        print("Initializing retriever...")
        cfg = _available(settings.CORPORA_CONFIG)
        if not cfg:
            raise RuntimeError("No corpora files found in data/corpora. Build them first.")

        self.retriever = Retriever(
            corpora_config=cfg,
            use_reranker=self.use_reranker,
            embedding_model=self.embedding_model,
        )
        print("Retriever initialized.")
        # Set only after every step succeeded, so a failed attempt is retried
        # on the next call.
        self.is_initialized = True

    @staticmethod
    def _probability_row(probabilities: Any) -> List[Any]:
        """Return the per-category probability vector for a single query.

        ``predict_query`` is called with a batch of one, so its probabilities may
        arrive either as one flat vector or as a one-row batch (TODO confirm the
        exact contract of predict_query). Both shapes are accepted. Objects that
        expose ``tolist`` (e.g. NumPy arrays, torch tensors) are converted to plain
        Python values first.
        """
        rows = probabilities.tolist() if hasattr(probabilities, "tolist") else list(probabilities)
        if not rows:
            return rows
        first = rows[0]
        # A nested sequence / non-scalar array means we were handed a batch.
        if isinstance(first, (list, tuple)) or getattr(first, "ndim", 0) > 0:
            return first.tolist() if hasattr(first, "tolist") else list(first)
        return rows

    def predict(self, query: str, k: int = 10) -> Dict[str, Any]:
        """Run the full pipeline: classification, then retrieval if medical.

        Initializes the pipeline first if that has not happened yet.

        Args:
            query: The user's text.
            k: Forwarded to ``get_candidates`` as ``k_retrieve``.

        Returns:
            A dict with keys ``"query"``, ``"classification"`` (containing
            ``"prediction"`` and a ``"probabilities"`` mapping keyed by
            ``settings.CATEGORIES``) and ``"retrieval"`` (a list of
            ``asdict``-converted hits; empty unless the query was classified as
            ``"medical"``).
        """
        if not self.is_initialized:
            self.initialize()

        # predict_query operates on a batch; we send a batch of one query.
        classification = predict_query(
            text=[query],
            embedding_model=self.embedding_model,
            classifier_head=self.classifier,
        )

        label = classification["prediction"][0]
        probabilities = self._probability_row(classification["probabilities"])

        result = {
            "query": query,
            "classification": {
                "prediction": label,
                "probabilities": dict(zip(settings.CATEGORIES, probabilities)),
            },
            "retrieval": [],
        }

        # Gate retrieval on exactly the label reported above.
        if label == "medical":
            hits = get_candidates(
                query=query,
                retriever=self.retriever,
                k_retrieve=k,
            )
            result["retrieval"] = [asdict(hit) for hit in hits]

        return result

    def get_index_progress(self):
        """Return ``(current, total)`` of the underlying index, or ``(0, 0)`` if no retriever exists yet."""
        # Explicit None check: a valid Retriever that defines __len__/__bool__
        # could otherwise be mistaken for "not initialized".
        if self.retriever is None:
            return 0, 0
        return self.retriever.get_index_progress()