# tribe-v2-api / app.py
# janrudolph's picture
# Fix: Use object.__setattr__ to force CPU device in HuggingFaceText extractor
# bd4b072
"""
Tribe V2 Brain Predictive Model – REST API Service
Deployed on HuggingFace Spaces (CPU)
"""
import os
import math
import warnings
import threading
import traceback
import numpy as np
from flask import Flask, request, jsonify
# Silence every Python warning process-wide (no category/module filter is given).
warnings.filterwarnings("ignore")
# Flask application object; the route decorators in this file register handlers on it.
app = Flask(__name__)
# ─── Global model state ──────────────────────────────────────────────────────
# Written by _load_model() while holding _model_lock; the route handlers read
# these flags without taking the lock.
_model = None  # loaded TribeModel instance, or None until a load succeeds
_model_loaded = False  # True once _load_model() has published a model
_model_loading = False  # True while a load is in progress
_model_error = None  # str(exception) of the most recent failed load, else None
_model_lock = threading.Lock()  # serialises updates to the four globals above
HF_TOKEN = os.environ.get("HF_TOKEN", "")  # Hugging Face access token; "" when unset
CKPT = os.environ.get("TRIBE_CKPT", "facebook/tribev2")  # checkpoint id passed to TribeModel.from_pretrained
# ─── ROI vertex indices ───────────────────────────────────────────────────────
# Maps each ROI name to the list of vertex (column) indices that _score_text
# averages over. The five ROIs are consecutive, non-overlapping bands of 800
# indices each, together covering 0..3999.
ROI_INDICES = {
    roi_name: list(range(band_start, band_start + 800))
    for band_start, roi_name in zip(
        range(0, 4000, 800),
        ("language", "visual", "attention", "emotion", "default"),
    )
}
def _load_model():
    """Load the Tribe V2 model once and publish it through the module globals.

    Meant to be run on a background thread. The call is a no-op when a model
    is already loaded or another load is in progress. On success ``_model`` is
    set and ``_model_loaded`` becomes True; on failure the exception text is
    stored in ``_model_error``. Nothing is returned and load errors are not
    re-raised, so callers observe the outcome through those globals.
    """
    global _model, _model_loaded, _model_loading, _model_error

    # Claim the "loading" slot under the lock so concurrent callers start
    # at most one load.
    with _model_lock:
        if _model_loaded or _model_loading:
            return
        _model_loading = True
        _model_error = None

    try:
        print("[Tribe V2] Starting model load...", flush=True)

        # Set HF token for gated model access (all methods)
        if HF_TOKEN:
            os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
            os.environ["HF_TOKEN"] = HF_TOKEN
            # Also login via huggingface_hub so transformers picks it up.
            # Best effort: a failed login is logged and the load continues.
            try:
                from huggingface_hub import login
                login(token=HF_TOKEN, add_to_git_credential=False)
                print(f"[Tribe V2] HF login OK (len={len(HF_TOKEN)})", flush=True)
            except Exception as login_err:
                print(f"[Tribe V2] HF login warning: {login_err}", flush=True)
        else:
            print("[Tribe V2] WARNING: No HF_TOKEN set!", flush=True)

        import torch
        print(f"[Tribe V2] PyTorch {torch.__version__}", flush=True)

        # ── CPU Patch ──────────────────────────────────────────────────────
        # The Tribe V2 config.yaml has device: cuda for the text extractor.
        # We patch HuggingFaceText._load_model to force device='cpu' via
        # object.__setattr__ (bypasses pydantic's immutability) BEFORE
        # the model weights are moved to device.
        try:
            from neuralset.extractors.text import HuggingFaceText

            # A retry after a failed load reaches this point again; without
            # the marker check every attempt would wrap the previous wrapper.
            if getattr(HuggingFaceText._load_model, "_tribe_cpu_patched", False):
                print("[Tribe V2] CPU patch already applied", flush=True)
            else:
                orig_load = HuggingFaceText._load_model

                def cpu_patched_load(self):
                    """Force CPU device before loading model weights."""
                    object.__setattr__(self, 'device', 'cpu')
                    return orig_load(self)

                cpu_patched_load._tribe_cpu_patched = True
                HuggingFaceText._load_model = cpu_patched_load
                print("[Tribe V2] CPU patch applied (object.__setattr__)", flush=True)
        except Exception as patch_err:
            # Best effort: continue without the patch and let the load decide.
            print(f"[Tribe V2] CPU patch warning: {patch_err}", flush=True)

        from tribev2 import TribeModel
        print("[Tribe V2] Loading TribeModel...", flush=True)
        model = TribeModel.from_pretrained(CKPT, device='cpu')
        print("[Tribe V2] TribeModel loaded!", flush=True)

        # Publish the model and flip both flags in one critical section.
        with _model_lock:
            _model = model
            _model_loaded = True
            _model_loading = False
        print("[Tribe V2] Ready!", flush=True)
    except Exception as e:
        err = traceback.format_exc()
        print(f"[Tribe V2] LOAD ERROR:\n{err}", flush=True)
        with _model_lock:
            _model_loading = False
            _model_error = str(e)
    finally:
        # Safety net: if something that is not an Exception subclass escapes
        # (e.g. SystemExit raised inside a dependency), the flag must still be
        # cleared, otherwise every later load attempt would be refused.
        with _model_lock:
            _model_loading = False
def _normalize_roi(val: float, overall: float) -> float:
    """Convert an ROI activation into a 0-100 score relative to the overall mean.

    The score is ``50 + 50 * tanh(2 * (val / overall - 1))`` rounded to one
    decimal, so an ROI equal to the overall mean scores 50.0.

    Args:
        val: Mean activation of the ROI.
        overall: Mean activation over all vertices.

    Returns:
        A score in [0.0, 100.0]. The neutral value 50.0 is returned when no
        meaningful ratio exists: ``overall`` is 0, or the ratio is NaN.
    """
    if overall == 0:
        return 50.0
    ratio = val / overall
    # NaN would otherwise fall through max(0.0, nan) and come out as a
    # misleading 0.0; treat it like the other "no information" case.
    if math.isnan(ratio):
        return 50.0
    score = 50 + 50 * math.tanh((ratio - 1.0) * 2)
    return round(min(100.0, max(0.0, score)), 1)
def _score_text(text: str) -> dict:
    """Run the loaded model on *text* and condense the prediction into scores.

    The text is split on whitespace and each token becomes one "Word" event
    lasting 0.5 s, laid out back to back. The events are passed to
    ``_model.predict`` and the prediction is read as a 2-D array of
    (timesteps, vertices); a 1-D result is treated as a single timestep.

    Args:
        text: The text to score.

    Returns:
        A dict with five ROI scores (each from ``_normalize_roi``), the
        weighted ``viral_potential``, a per-timestep mean-|activation|
        timeline, the array dimensions, the dominant hemisphere label,
        the word count and the overall mean |activation|.

    Raises:
        ValueError: If *text* contains no tokens.
    """
    import pandas as pd

    # str.split() without a separator never yields empty or blank tokens.
    tokens = text.split()
    if not tokens:
        raise ValueError("Empty text")

    word_duration = 0.5
    # neuralset Word fields: start (float), timeline (str), text, context, duration
    event_rows = []
    for position, token in enumerate(tokens):
        event_rows.append({
            "type": "Word",
            "text": token,
            "context": text,
            "start": position * word_duration,
            "duration": word_duration,
            "timeline": "default",
        })
    events_df = pd.DataFrame(event_rows)

    # TribeModel.predict() returns (preds_array, segments) as a tuple
    output = _model.predict(events_df, verbose=False)
    preds = output[0] if isinstance(output, tuple) else output
    # Objects exposing .numpy() (presumably torch tensors) are converted first.
    if hasattr(preds, "numpy"):
        preds = preds.numpy()
    preds = np.array(preds)
    if preds.ndim == 1:
        preds = preds.reshape(1, -1)
    n_timesteps, n_vertices = preds.shape

    overall_mean = float(np.abs(preds).mean())

    # Mean |activation| per ROI, using only the indices the prediction actually
    # has; an ROI entirely out of range falls back to the overall mean.
    roi_means = {}
    for roi_name in ("language", "visual", "attention", "emotion", "default"):
        usable = [idx for idx in ROI_INDICES[roi_name] if idx < n_vertices]
        if usable:
            roi_means[roi_name] = float(np.abs(preds[:, usable]).mean())
        else:
            roi_means[roi_name] = overall_mean

    language_score = _normalize_roi(roi_means["language"], overall_mean)
    visual_score = _normalize_roi(roi_means["visual"], overall_mean)
    attention_score = _normalize_roi(roi_means["attention"], overall_mean)
    emotion_score = _normalize_roi(roi_means["emotion"], overall_mean)
    engagement_score = _normalize_roi(roi_means["default"], overall_mean)

    viral_score = round(
        engagement_score*0.35 + emotion_score*0.30 + attention_score*0.20
        + language_score*0.10 + visual_score*0.05,
        1,
    )

    # One value per timestep: mean |activation| across all vertices.
    activation_timeline = [round(float(np.abs(row).mean()), 4) for row in preds]

    # Compare the first and second half of the vertex axis; the labels are
    # German for "left" / "right" and are part of the response contract.
    half = n_vertices // 2
    first_half_mean = float(np.abs(preds[:, :half]).mean())
    second_half_mean = float(np.abs(preds[:, half:]).mean())
    if first_half_mean >= second_half_mean:
        dominant = "links"
    else:
        dominant = "rechts"

    return {
        "language_processing": language_score,
        "visual_imagery": visual_score,
        "attention_capture": attention_score,
        "emotional_valence": emotion_score,
        "overall_brain_engagement": engagement_score,
        "viral_potential": viral_score,
        "activation_timeline": activation_timeline,
        "n_timesteps": n_timesteps,
        "n_vertices": n_vertices,
        "dominant_hemisphere": dominant,
        "word_count": len(tokens),
        "overall_mean_activation": round(overall_mean, 6),
    }
# ─── Routes ──────────────────────────────────────────────────────────────────
@app.route("/", methods=["GET"])
def index():
    """Return a small JSON descriptor naming the service and its endpoints."""
    descriptor = {
        "service": "tribe-v2-api",
        "status": "ok",
        "endpoints": ["/health", "/warmup", "/predict"],
    }
    return jsonify(descriptor)
@app.route("/health", methods=["GET"])
def health():
    """Report the model's load state.

    ``status`` is "ok" once the model is loaded, "loading" while a load is in
    progress, and "offline" otherwise. The raw flags and the last load error
    (or None) are included alongside it.
    """
    if _model_loaded:
        state = "ok"
    elif _model_loading:
        state = "loading"
    else:
        state = "offline"

    report = {
        "status": state,
        "model": "tribe-v2",
        "model_loaded": _model_loaded,
        "model_loading": _model_loading,
        "error": _model_error,
    }
    return jsonify(report)
@app.route("/warmup", methods=["POST"])
def warmup():
    """Start a background model load unless one is loaded or already running.

    Always answers immediately with the current load flags; the load itself
    continues on a daemon thread.
    """
    load_needed = not (_model_loaded or _model_loading)
    if load_needed:
        loader = threading.Thread(target=_load_model, daemon=True)
        loader.start()

    reply = {
        "status": "warming_up",
        "model_loaded": _model_loaded,
        "model_loading": _model_loading,
    }
    return jsonify(reply)
@app.route("/predict", methods=["POST"])
def predict():
    """Score the ``text`` field of a JSON request body.

    Responses:
        200: ``{"scores": {...}, "status": "ok"}``.
        400: the body has no usable ``text`` (missing, not a string, or
             shorter than 5 characters after stripping).
        500: scoring raised; the error text and traceback are returned.
        503: the model is not loaded yet.
    """
    if not _model_loaded:
        return jsonify({"error": "Model not loaded. Call /warmup first.",
                        "model_loading": _model_loading,
                        "load_error": _model_error}), 503

    data = request.get_json(force=True, silent=True)
    # A syntactically valid JSON body need not be an object (it can be a list,
    # string, number, ...). Treat anything that is not a dict like a missing
    # body instead of failing on .get() with an unhandled AttributeError.
    if not isinstance(data, dict):
        data = {}

    raw_text = data.get("text")
    if raw_text is None:
        raw_text = ""
    if not isinstance(raw_text, str):
        # e.g. {"text": 123} - previously crashed on .strip() with a 500.
        return jsonify({"error": "Field 'text' must be a string"}), 400

    # Input is capped at 5000 characters after stripping.
    text = raw_text.strip()[:5000]
    if len(text) < 5:
        return jsonify({"error": "Text too short (min 5 chars)"}), 400

    try:
        return jsonify({"scores": _score_text(text), "status": "ok"})
    except Exception as e:
        # NOTE(review): the full traceback is sent to the client, which
        # exposes internal paths; kept for compatibility with existing
        # callers - consider logging it server-side only.
        return jsonify({"error": str(e), "trace": traceback.format_exc()}), 500
# ─── Entry point ─────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Begin loading the model in the background right away so the HTTP server
    # can come up without waiting for it.
    startup_loader = threading.Thread(target=_load_model, daemon=True)
    startup_loader.start()

    # Listen on all interfaces; the port comes from PORT, defaulting to 7860.
    listen_port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=listen_port, debug=False)