StoryboardBeats-Mini-0.1 / inference.py
maldons77's picture
Upload 6 files
093d678 verified
import joblib
from typing import List, Dict
# Path of the serialized model bundle that load_model() reads with joblib.load.
# It is relative, so it resolves against the process's current working directory.
ARTIFACT_PATH = "model.joblib"
# Module-level cache for the loaded bundle; remains None until load_model()
# has been called for the first time.
_model = None
def load_model():
    """Return the model bundle, reading it from disk only on the first call.

    The bundle is loaded from ``ARTIFACT_PATH`` with ``joblib.load`` and kept
    in the module-level ``_model`` variable, so later calls return the same
    object without touching the filesystem again.
    """
    global _model
    # Fast path: a previous call already populated the cache.
    if _model is not None:
        return _model
    # NOTE: joblib artifacts are pickle-based; only load files from a trusted source.
    _model = joblib.load(ARTIFACT_PATH)
    return _model
def predict(inputs: List[str], *, threshold: float = 0.5) -> Dict:
    """Predict beats (multi-label) and a style label for each prompt.

    Args:
        inputs: Prompts to classify, one string per item.
        threshold: Minimum predicted probability at which a beat is listed
            under ``"beats"``. Defaults to 0.5.

    Returns:
        A dict ``{"results": [...]}`` holding one entry per input, in input
        order. Each entry contains:

        * ``"input"``: the original prompt,
        * ``"beats"``: the beat labels whose probability is >= ``threshold``,
        * ``"style"``: the predicted style label converted with ``str``,
        * ``"beats_proba"``: a mapping of every beat label to its probability
          as a Python ``float``.

        An empty ``inputs`` yields ``{"results": []}``.
    """
    # Nothing to score: answer immediately instead of loading the model and
    # handing an empty batch to the estimators. ``len`` is used rather than
    # truthiness so array-like inputs are handled unambiguously.
    if len(inputs) == 0:
        return {"results": []}

    model = load_model()
    beats_clf = model["beats_model"]
    style_clf = model["style_model"]
    beats_classes = model["beats_classes"]

    # Assumes predict_proba returns one row per input with one column per
    # entry of beats_classes, as an array supporting element-wise ``>=``
    # -- TODO confirm against the training script.
    beats_proba = beats_clf.predict_proba(inputs)
    beats_mask = beats_proba >= threshold

    style_pred = style_clf.predict(inputs)

    results = []
    for text, mask_row, proba_row, style in zip(
        inputs, beats_mask, beats_proba, style_pred
    ):
        results.append({
            "input": text,
            "beats": [b for b, hit in zip(beats_classes, mask_row) if hit],
            "style": str(style),
            # float() keeps the values plain Python numbers (e.g. for JSON output).
            "beats_proba": {
                b: float(p) for b, p in zip(beats_classes, proba_row)
            },
        })
    return {"results": results}