Spaces:
Running
Running
| """ | |
| Tribe V2 Brain Predictive Model — REST API Service | |
| Deployed on HuggingFace Spaces (CPU) | |
| """ | |
| import os | |
| import math | |
| import warnings | |
| import threading | |
| import traceback | |
| import numpy as np | |
| from flask import Flask, request, jsonify | |
# Suppress every warning process-wide (applies to all imported libraries,
# not just this module).
warnings.simplefilter("ignore")

# The Flask application object the route handlers attach to.
app = Flask(__name__)
# ─── Global model state ──────────────────────────────────────────────────────
# The loaded model and its status flags. _load_model() updates the flags
# while holding _model_lock; the route handlers read them.
_model = _model_error = None  # TribeModel instance / str of the last load error
_model_loaded = _model_loading = False
_model_lock = threading.Lock()

# Access token for the gated checkpoint and the checkpoint id to load.
HF_TOKEN = os.getenv("HF_TOKEN", "")
CKPT = os.getenv("TRIBE_CKPT", "facebook/tribev2")
# ─── ROI vertex indices ──────────────────────────────────────────────────────
# Five consecutive, non-overlapping blocks of 800 vertex indices each
# (0-799, 800-1599, ..., 3200-3999), in the order listed below.
# NOTE(review): these contiguous ranges look like placeholders rather than
# anatomically derived regions of the model's output space — confirm.
ROI_INDICES = {
    roi_name: list(range(slot * 800, (slot + 1) * 800))
    for slot, roi_name in enumerate(
        ("language", "visual", "attention", "emotion", "default")
    )
}
def _login_to_hf():
    """Expose HF_TOKEN to the Hugging Face libraries (env vars + hub login).

    A failed ``login`` call is only logged, so the model load is still
    attempted afterwards.
    """
    if not HF_TOKEN:
        print("[Tribe V2] WARNING: No HF_TOKEN set!", flush=True)
        return
    # Set the token under both env-var names so every access path sees it.
    os.environ["HUGGING_FACE_HUB_TOKEN"] = HF_TOKEN
    os.environ["HF_TOKEN"] = HF_TOKEN
    # Also log in via huggingface_hub so transformers picks the token up.
    try:
        from huggingface_hub import login
        login(token=HF_TOKEN, add_to_git_credential=False)
        print(f"[Tribe V2] HF login OK (len={len(HF_TOKEN)})", flush=True)
    except Exception as login_err:
        print(f"[Tribe V2] HF login warning: {login_err}", flush=True)


def _apply_cpu_patch():
    """Force neuralset's HuggingFaceText extractor to load onto the CPU.

    The Tribe V2 config.yaml sets ``device: cuda`` for the text extractor.
    ``HuggingFaceText._load_model`` is wrapped so that ``self.device`` is
    overwritten with 'cpu' (via ``object.__setattr__``, bypassing pydantic's
    immutability) before the original loader runs.

    Idempotent: a retry after a failed load must not wrap the already-patched
    method a second time. Patch failures are only logged.
    """
    try:
        from neuralset.extractors.text import HuggingFaceText
        if getattr(HuggingFaceText._load_model, "_tribe_cpu_patched", False):
            return
        orig_load = HuggingFaceText._load_model

        def cpu_patched_load(self):
            """Force CPU device before loading model weights."""
            object.__setattr__(self, 'device', 'cpu')
            return orig_load(self)

        cpu_patched_load._tribe_cpu_patched = True
        HuggingFaceText._load_model = cpu_patched_load
        print("[Tribe V2] CPU patch applied (object.__setattr__)", flush=True)
    except Exception as e:
        print(f"[Tribe V2] CPU patch warning: {e}", flush=True)


def _load_model():
    """Load the Tribe V2 checkpoint ``CKPT`` on the CPU into ``_model``.

    Returns immediately if a model is already loaded or a load is in
    progress; that check-and-set is done under ``_model_lock``. The slow
    load itself runs outside the lock.

    On success: sets ``_model`` and ``_model_loaded``. On any exception:
    prints the traceback, stores ``str(exc)`` in ``_model_error`` and
    clears ``_model_loading`` so a later call can retry.
    """
    global _model, _model_loaded, _model_loading, _model_error
    with _model_lock:
        if _model_loaded or _model_loading:
            return
        _model_loading = True
        _model_error = None
    try:
        print("[Tribe V2] Starting model load...", flush=True)
        _login_to_hf()
        import torch
        print(f"[Tribe V2] PyTorch {torch.__version__}", flush=True)
        # Must run before TribeModel moves the extractor weights to a device.
        _apply_cpu_patch()
        from tribev2 import TribeModel
        print("[Tribe V2] Loading TribeModel...", flush=True)
        model = TribeModel.from_pretrained(CKPT, device='cpu')
        print("[Tribe V2] TribeModel loaded!", flush=True)
        with _model_lock:
            _model = model
            _model_loaded = True
            _model_loading = False
        print("[Tribe V2] Ready!", flush=True)
    except Exception as e:
        err = traceback.format_exc()
        print(f"[Tribe V2] LOAD ERROR:\n{err}", flush=True)
        with _model_lock:
            _model_loading = False
            _model_error = str(e)
def _normalize_roi(val: float, overall: float) -> float:
    """Map an ROI activation onto a 0-100 score relative to the overall mean.

    A value equal to ``overall`` scores 50. Deviations of the ratio
    ``val / overall`` from 1 are pushed through ``tanh`` (gain 2), so
    above-average regions approach 100 and below-average ones approach 0.
    The result is clamped to [0, 100] and rounded to one decimal.
    A zero ``overall`` short-circuits to the neutral 50.0.
    """
    if overall == 0:
        return 50.0
    deviation = val / overall - 1.0
    raw = 50 + 50 * math.tanh(deviation * 2)
    clamped = min(100.0, max(0.0, raw))
    return round(clamped, 1)
def _prediction_matrix(result) -> np.ndarray:
    """Coerce ``TribeModel.predict()`` output into a 2-D (timesteps, vertices) array.

    ``predict()`` may return a ``(preds, segments)`` tuple; only ``preds`` is
    used. Tensor-like values are detached and moved to the CPU before
    conversion, since ``.numpy()`` raises on tensors that require grad.
    A 1-D result is treated as a single timestep.

    Raises:
        ValueError: if the predictions are empty or not 1-D/2-D.
    """
    preds = result[0] if isinstance(result, tuple) else result
    if hasattr(preds, "detach"):
        preds = preds.detach().cpu()
    if hasattr(preds, "numpy"):
        preds = preds.numpy()
    preds = np.array(preds)
    if preds.ndim == 1:
        preds = preds.reshape(1, -1)
    if preds.ndim != 2:
        raise ValueError(
            f"Unexpected prediction shape {preds.shape}; expected 1-D or 2-D"
        )
    if preds.size == 0:
        # Without this, every mean below is NaN and jsonify emits invalid JSON.
        raise ValueError("Model returned no predictions")
    return preds


def _score_text(text: str) -> dict:
    """Run the loaded model on *text* and summarise its predicted activations.

    Each whitespace-separated word becomes one "Word" event lasting 0.5 s.
    Events are placed back-to-back on the "default" timeline, and each
    carries the full text as its context.

    Returns a dict with:
      - 0-100 ROI scores (via _normalize_roi against the overall mean |activation|);
      - a weighted ``viral_potential`` (weights sum to 1);
      - the per-timestep mean |activation|;
      - the dominant half of the vertex array ("links"/"rechts");
      - bookkeeping counts.

    Requires ``_model`` to be loaded.

    Raises:
        ValueError: if *text* has no words or the model output is unusable.
    """
    import pandas as pd

    # str.split() with no argument already discards empty strings.
    words = text.split()
    if not words:
        raise ValueError("Empty text")
    duration = 0.5
    # neuralset Word fields: start (float), timeline (str), text, context, duration
    events_df = pd.DataFrame([
        {"type": "Word",
         "text": w,
         "context": text,
         "start": i * duration,
         "duration": duration,
         "timeline": "default"}
        for i, w in enumerate(words)
    ])

    # TribeModel.predict() returns (preds_array, segments) as a tuple.
    result = _model.predict(events_df, verbose=False)
    pred_array = _prediction_matrix(result)
    n_timesteps, n_vertices = pred_array.shape
    magnitude = np.abs(pred_array)  # every statistic below uses |activation|
    overall_mean = float(magnitude.mean())

    def roi_score(indices):
        # ROI indices beyond the model's vertex count are dropped; an ROI with
        # no in-range vertex falls back to the overall mean (i.e. scores 50).
        safe = [i for i in indices if i < n_vertices]
        return float(magnitude[:, safe].mean()) if safe else overall_mean

    lp = _normalize_roi(roi_score(ROI_INDICES["language"]), overall_mean)
    vi = _normalize_roi(roi_score(ROI_INDICES["visual"]), overall_mean)
    ac = _normalize_roi(roi_score(ROI_INDICES["attention"]), overall_mean)
    ev = _normalize_roi(roi_score(ROI_INDICES["emotion"]), overall_mean)
    be = _normalize_roi(roi_score(ROI_INDICES["default"]), overall_mean)
    viral = round(be*0.35 + ev*0.30 + ac*0.20 + lp*0.10 + vi*0.05, 1)

    timeline = [round(float(magnitude[t].mean()), 4) for t in range(n_timesteps)]

    # NOTE(review): assumes the first half of the vertex axis is the left
    # hemisphere and the second half the right ("links"/"rechts") — confirm.
    mid = n_vertices // 2
    left = float(magnitude[:, :mid].mean())
    right = float(magnitude[:, mid:].mean())
    dom = "links" if left >= right else "rechts"

    return {
        "language_processing": lp, "visual_imagery": vi,
        "attention_capture": ac, "emotional_valence": ev,
        "overall_brain_engagement": be, "viral_potential": viral,
        "activation_timeline": timeline, "n_timesteps": n_timesteps,
        "n_vertices": n_vertices, "dominant_hemisphere": dom,
        "word_count": len(words), "overall_mean_activation": round(overall_mean, 6),
    }
# ─── Routes ──────────────────────────────────────────────────────────────────
@app.route("/")
def index():
    """Describe the service and list its endpoints.

    Registered on "/": without the route decorator Flask never exposes the view.
    """
    return jsonify({"service": "tribe-v2-api", "status": "ok",
                    "endpoints": ["/health", "/warmup", "/predict"]})
@app.route("/health")
def health():
    """Report the model's load state.

    ``status`` is "ok" once the model is loaded and "loading" while a load is
    in progress. Otherwise it is "offline": the load never started or the last
    attempt failed, in which case ``error`` holds the message.
    """
    # Snapshot all three flags together so the reported combination is one
    # that _load_model actually produced.
    with _model_lock:
        loaded, loading, error = _model_loaded, _model_loading, _model_error
    return jsonify({
        "status": "ok" if loaded else ("loading" if loading else "offline"),
        "model": "tribe-v2",
        "model_loaded": loaded,
        "model_loading": loading,
        "error": error,
    })
@app.route("/warmup", methods=["GET", "POST"])
def warmup():
    """Start loading the model on a daemon thread unless loaded or loading.

    Returns immediately; poll /health for progress. Duplicate starts from
    concurrent requests are harmless because _load_model re-checks the flags
    under _model_lock.

    The reported flags are read right after the thread is started, so
    ``model_loading`` may still be False on the first call.
    """
    if not _model_loaded and not _model_loading:
        threading.Thread(target=_load_model, daemon=True).start()
    return jsonify({"status": "warming_up", "model_loaded": _model_loaded,
                    "model_loading": _model_loading})
@app.route("/predict", methods=["POST"])
def predict():
    """Score the ``text`` field of a JSON request body.

    Body: ``{"text": "<string>"}``. The text is stripped, truncated to 5000
    characters, and must still be at least 5 characters long.

    Responses:
        200 ``{"scores": {...}, "status": "ok"}``
        400 if ``text`` is missing, not a string, or too short
        503 while the model is not loaded
        500 if scoring fails (the traceback is logged server-side only)
    """
    if not _model_loaded:
        return jsonify({"error": "Model not loaded. Call /warmup first.",
                        "model_loading": _model_loading,
                        "load_error": _model_error}), 503
    data = request.get_json(force=True, silent=True)
    if not isinstance(data, dict):
        # A JSON array/string/number body has no "text" field.
        data = {}
    raw_text = data.get("text") or ""
    if not isinstance(raw_text, str):
        return jsonify({"error": "Field 'text' must be a string"}), 400
    text = raw_text.strip()[:5000]
    if len(text) < 5:
        return jsonify({"error": "Text too short (min 5 chars)"}), 400
    try:
        scores = _score_text(text)
    except Exception as e:
        # Keep the stack trace in the server log; returning it would leak
        # file paths and library internals to anonymous clients.
        print(f"[Tribe V2] PREDICT ERROR:\n{traceback.format_exc()}", flush=True)
        return jsonify({"error": str(e)}), 500
    return jsonify({"scores": scores, "status": "ok"})
# ─── Entry point ─────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Begin loading the model right away on a daemon thread so the HTTP
    # server comes up immediately; /health reports load progress.
    # NOTE(review): this only runs when the file is executed directly; a WSGI
    # server importing `app` would skip it and depend on /warmup instead.
    loader = threading.Thread(target=_load_model, daemon=True)
    loader.start()
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)), debug=False)