""" Tribe V2 Brain Predictive Model – REST API Service Deployed on HuggingFace Spaces (CPU) """ import os import math import warnings import threading import traceback import numpy as np from flask import Flask, request, jsonify warnings.filterwarnings("ignore") app = Flask(__name__) # ─── Global model state ────────────────────────────────────────────────────── _model = None _model_loaded = False _model_loading = False _model_error = None _model_lock = threading.Lock() HF_TOKEN = os.environ.get("HF_TOKEN", "") CKPT = os.environ.get("TRIBE_CKPT", "facebook/tribev2") # ─── ROI vertex indices ─────────────────────────────────────────────────────── ROI_INDICES = { "language": list(range(0, 800)), "visual": list(range(800, 1600)), "attention": list(range(1600, 2400)), "emotion": list(range(2400, 3200)), "default": list(range(3200, 4000)), } def _load_model(): global _model, _model_loaded, _model_loading, _model_error with _model_lock: if _model_loaded or _model_loading: return _model_loading = True _model_error = None try: print("[Tribe V2] Starting model load...", flush=True) # Set HF token for gated model access (all methods) if HF_TOKEN: os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN os.environ["HF_TOKEN"] = HF_TOKEN # Also login via huggingface_hub so transformers picks it up try: from huggingface_hub import login login(token=HF_TOKEN, add_to_git_credential=False) print(f"[Tribe V2] HF login OK (len={len(HF_TOKEN)})", flush=True) except Exception as login_err: print(f"[Tribe V2] HF login warning: {login_err}", flush=True) else: print("[Tribe V2] WARNING: No HF_TOKEN set!", flush=True) import torch print(f"[Tribe V2] PyTorch {torch.__version__}", flush=True) # ── CPU Patch ────────────────────────────────────────────────────── # The Tribe V2 config.yaml has device: cuda for the text extractor. # We patch HuggingFaceText._load_model to force device='cpu' via # object.__setattr__ (bypasses pydantic's immutability) BEFORE # the model weights are moved to device. try: from neuralset.extractors.text import HuggingFaceText orig_load = HuggingFaceText._load_model def cpu_patched_load(self): """Force CPU device before loading model weights.""" object.__setattr__(self, 'device', 'cpu') return orig_load(self) HuggingFaceText._load_model = cpu_patched_load print("[Tribe V2] CPU patch applied (object.__setattr__)", flush=True) except Exception as e: print(f"[Tribe V2] CPU patch warning: {e}", flush=True) from tribev2 import TribeModel print("[Tribe V2] Loading TribeModel...", flush=True) model = TribeModel.from_pretrained(CKPT, device='cpu') print("[Tribe V2] TribeModel loaded!", flush=True) with _model_lock: _model = model _model_loaded = True _model_loading = False print("[Tribe V2] Ready!", flush=True) except Exception as e: err = traceback.format_exc() print(f"[Tribe V2] LOAD ERROR:\n{err}", flush=True) with _model_lock: _model_loading = False _model_error = str(e) def _normalize_roi(val: float, overall: float) -> float: if overall == 0: return 50.0 ratio = val / overall score = 50 + 50 * math.tanh((ratio - 1.0) * 2) return round(min(100.0, max(0.0, score)), 1) def _score_text(text: str) -> dict: import pandas as pd words = [w for w in text.split() if w.strip()] if not words: raise ValueError("Empty text") duration = 0.5 # neuralset Word-Felder: start (float), timeline (str), text, context, duration rows = [{"type": "Word", "text": w, "context": text, "start": i * duration, "duration": duration, "timeline": "default"} for i, w in enumerate(words)] events_df = pd.DataFrame(rows) # TribeModel.predict() gibt (preds_array, segments) als Tuple zurück result = _model.predict(events_df, verbose=False) pred_array = result[0] if isinstance(result, tuple) else result if hasattr(pred_array, "numpy"): pred_array = pred_array.numpy() pred_array = np.array(pred_array) if pred_array.ndim == 1: pred_array = pred_array.reshape(1, -1) n_timesteps, n_vertices = pred_array.shape overall_mean = float(np.abs(pred_array).mean()) def roi_score(indices): safe = [i for i in indices if i < n_vertices] return float(np.abs(pred_array[:, safe]).mean()) if safe else overall_mean lang = roi_score(ROI_INDICES["language"]) vis = roi_score(ROI_INDICES["visual"]) att = roi_score(ROI_INDICES["attention"]) emo = roi_score(ROI_INDICES["emotion"]) dmn = roi_score(ROI_INDICES["default"]) lp = _normalize_roi(lang, overall_mean) vi = _normalize_roi(vis, overall_mean) ac = _normalize_roi(att, overall_mean) ev = _normalize_roi(emo, overall_mean) be = _normalize_roi(dmn, overall_mean) viral = round(be*0.35 + ev*0.30 + ac*0.20 + lp*0.10 + vi*0.05, 1) timeline = [round(float(np.abs(pred_array[t]).mean()), 4) for t in range(n_timesteps)] mid = n_vertices // 2 dom = "links" if float(np.abs(pred_array[:, :mid]).mean()) >= float(np.abs(pred_array[:, mid:]).mean()) else "rechts" return { "language_processing": lp, "visual_imagery": vi, "attention_capture": ac, "emotional_valence": ev, "overall_brain_engagement": be, "viral_potential": viral, "activation_timeline": timeline, "n_timesteps": n_timesteps, "n_vertices": n_vertices, "dominant_hemisphere": dom, "word_count": len(words), "overall_mean_activation": round(overall_mean, 6), } # ─── Routes ────────────────────────────────────────────────────────────────── @app.route("/", methods=["GET"]) def index(): return jsonify({"service": "tribe-v2-api", "status": "ok", "endpoints": ["/health", "/warmup", "/predict"]}) @app.route("/health", methods=["GET"]) def health(): return jsonify({ "status": "ok" if _model_loaded else ("loading" if _model_loading else "offline"), "model": "tribe-v2", "model_loaded": _model_loaded, "model_loading": _model_loading, "error": _model_error, }) @app.route("/warmup", methods=["POST"]) def warmup(): if not _model_loaded and not _model_loading: threading.Thread(target=_load_model, daemon=True).start() return jsonify({"status": "warming_up", "model_loaded": _model_loaded, "model_loading": _model_loading}) @app.route("/predict", methods=["POST"]) def predict(): if not _model_loaded: return jsonify({"error": "Model not loaded. Call /warmup first.", "model_loading": _model_loading, "load_error": _model_error}), 503 data = request.get_json(force=True, silent=True) or {} text = (data.get("text") or "").strip()[:5000] if len(text) < 5: return jsonify({"error": "Text too short (min 5 chars)"}), 400 try: return jsonify({"scores": _score_text(text), "status": "ok"}) except Exception as e: return jsonify({"error": str(e), "trace": traceback.format_exc()}), 500 # ─── Entry point ───────────────────────────────────────────────────────────── if __name__ == "__main__": # Start model loading immediately on startup threading.Thread(target=_load_model, daemon=True).start() port = int(os.environ.get("PORT", 7860)) app.run(host="0.0.0.0", port=port, debug=False)