File size: 1,104 Bytes
093d678
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import joblib
from typing import List, Dict

# Path of the serialized model bundle (a relative path, so resolved against the
# process's current working directory).
ARTIFACT_PATH = "model.joblib"
# Module-level cache for the loaded bundle; filled on the first load_model() call.
_model = None

def load_model():
    """Return the model bundle, reading it from ARTIFACT_PATH only once.

    The first call deserializes the artifact and stores it in the module-level
    cache; every later call returns that same cached object without touching
    the disk.
    """
    global _model
    if _model is not None:
        return _model
    # NOTE: joblib.load unpickles the file, which can execute arbitrary code —
    # only point ARTIFACT_PATH at artifacts from a trusted source.
    _model = joblib.load(ARTIFACT_PATH)
    return _model

def predict(inputs: List[str], *, threshold: float = 0.5) -> Dict:
    """
    Classify prompts into multi-label beats and a single style label.

    Args:
        inputs: list of prompts (strings). Any iterable of strings is accepted;
            it is materialized once because both models consume it.
        threshold: minimum probability (inclusive) for a beat to be listed
            under "beats". Defaults to 0.5, the previously hard-coded cutoff.

    Returns:
        {"results": [...]} with one dict per input, in input order, holding
        "input" (the prompt), "beats" (labels whose probability >= threshold),
        "style" (the style prediction converted to str) and "beats_proba"
        (label -> probability as float). An empty input yields
        {"results": []} without loading the model.

    Raises:
        ValueError: if the beats model returns a different number of
            probability columns than the artifact's "beats_classes" lists.
    """
    # Materialize so a one-shot iterator is not exhausted by the first model call.
    texts = list(inputs)
    if not texts:
        # Estimators may reject an empty batch; there is nothing to score anyway.
        return {"results": []}

    model = load_model()
    beats_clf = model["beats_model"]
    style_clf = model["style_model"]
    beats_classes = model["beats_classes"]

    # One row per input; columns are assumed to follow the order of
    # beats_classes (TODO confirm against how the artifact was built).
    beats_proba = beats_clf.predict_proba(texts)
    # Style label, one per input.
    style_pred = style_clf.predict(texts)

    # zip() below would silently drop labels or columns on a length mismatch.
    n_cols = len(beats_proba[0])
    if n_cols != len(beats_classes):
        raise ValueError(
            f"beats model returned {n_cols} probability columns but the "
            f"artifact lists {len(beats_classes)} beats_classes"
        )

    results = []
    for i, text in enumerate(texts):
        row = beats_proba[i]
        results.append({
            "input": text,
            "beats": [b for b, p in zip(beats_classes, row) if p >= threshold],
            "style": str(style_pred[i]),
            "beats_proba": {b: float(p) for b, p in zip(beats_classes, row)},
        })
    return {"results": results}