|
|
import joblib |
|
|
from typing import List, Dict |
|
|
|
|
|
ARTIFACT_PATH = "model.joblib" |
|
|
_model = None |
|
|
|
|
|
def load_model():
    """Return the model bundle, loading it from ``ARTIFACT_PATH`` on first use.

    The loaded object is stored in the module-level ``_model`` so later calls
    return the same object without reading the file again. If loading raises,
    nothing is cached and the next call tries again.

    Note: ``joblib.load`` is pickle-based, so only point ``ARTIFACT_PATH`` at
    artifacts from a trusted source.
    """
    global _model
    if _model is not None:
        # Fast path: bundle already loaded by an earlier call.
        return _model
    _model = joblib.load(ARTIFACT_PATH)
    return _model
|
|
|
|
|
def predict(inputs: List[str], *, threshold: float = 0.5, model=None) -> Dict:
    """Predict multi-label beats and a single style label for each prompt.

    Args:
        inputs: List of prompt strings to classify.
        threshold: Minimum probability for a beat label to be reported in
            ``"beats"``. Defaults to 0.5, the previously hard-coded cut-off.
        model: Optional already-loaded bundle: a mapping with the keys
            ``"beats_model"``, ``"style_model"`` and ``"beats_classes"``.
            When omitted, the cached bundle from ``load_model()`` is used.

    Returns:
        ``{"results": [...]}`` with one entry per input, in input order. Each
        entry holds the original ``"input"``, the ``"beats"`` whose
        probability is >= ``threshold``, the predicted ``"style"`` converted
        to ``str``, and ``"beats_proba"`` mapping every beat label to its
        probability as a float. An empty ``inputs`` yields
        ``{"results": []}`` without loading the model.
    """
    # Nothing to score: skip loading the model and avoid passing an empty
    # batch to the estimators, which may reject zero samples. len() is used
    # instead of truthiness so array-like inputs also work.
    if len(inputs) == 0:
        return {"results": []}

    if model is None:
        model = load_model()

    beats_clf = model["beats_model"]
    style_clf = model["style_model"]
    beats_classes = model["beats_classes"]

    # NOTE(review): assumes predict_proba returns one row of per-label
    # probabilities per input, column-aligned with beats_classes (as a
    # one-vs-rest multi-label estimator does) — confirm against training code.
    beats_proba = beats_clf.predict_proba(inputs)
    style_pred = style_clf.predict(inputs)

    results = []
    for i, text in enumerate(inputs):
        row = beats_proba[i]
        beats = [label for label, p in zip(beats_classes, row) if p >= threshold]
        results.append({
            "input": text,
            "beats": beats,
            "style": str(style_pred[i]),
            "beats_proba": {label: float(p) for label, p in zip(beats_classes, row)},
        })

    return {"results": results}
|
|
|