somebeast committed on
Commit
9fe0a20
·
verified ·
1 Parent(s): a680bf0

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +250 -217
app.py CHANGED
@@ -1,12 +1,8 @@
1
  """TRIBE V2 — Brain Response Prediction (Meta)
2
- Predicts fMRI brain responses using Meta's TRIBE V2 model on ZeroGPU.
3
- """
4
- import subprocess, sys
5
 
6
- # Install tribev2 with --no-deps to bypass Python 3.11 requirement.
7
- # The actual code works on 3.10; the pyproject.toml constraint is conservative.
8
- # We install sub-deps separately to avoid version conflicts with ZeroGPU base image.
9
- # No pip installs at runtime — use only what's in requirements.txt + base image
10
 
11
  import gradio as gr
12
  import spaces
@@ -15,102 +11,111 @@ import numpy as np
15
  import os
16
  import json
17
  import tempfile
 
18
 
19
- # ---- Model ----
 
 
 
 
 
 
 
 
 
 
20
  model = None
21
 
22
  def ensure_model():
23
- """Load LLaMA 3.2-3B for text encoding. TRIBE v2's full pipeline requires
24
- Python 3.11, so we use the text encoder directly for brain-region scoring."""
25
  global model
26
  if model is not None:
27
  return model
28
- # Use a non-gated model for text encoding (LLaMA 3.2 is gated, fails on ZeroGPU workers)
29
- # microsoft/phi-2 is a strong 2.7B model that's fully open
30
- print("Loading text encoder (Phi-2 2.7B — open, no gating)...")
31
- from transformers import AutoModelForCausalLM, AutoTokenizer
32
- model_id = "microsoft/phi-2"
33
- model = {
34
- "tokenizer": AutoTokenizer.from_pretrained(model_id, trust_remote_code=True),
35
- "model": AutoModelForCausalLM.from_pretrained(
36
- model_id, torch_dtype=torch.float16,
37
- output_hidden_states=True, trust_remote_code=True,
38
- ),
39
- }
40
- print("Phi-2 loaded.")
41
  return model
42
 
43
- # Authenticate with HF to access gated models (LLaMA 3.2-3B)
44
- try:
45
- from huggingface_hub import login
46
- hf_token = os.environ.get("HF_TOKEN", "")
47
- if hf_token:
48
- login(token=hf_token, add_to_git_credential=False)
49
- print("HF authenticated.")
50
- else:
51
- print("WARNING: No HF_TOKEN set. Gated models (LLaMA) may fail to download.")
52
- except Exception as e:
53
- print(f"HF login warning: {e}")
54
-
55
- print("TRIBE V2 ready. Model loads on first GPU call.")
56
 
57
  # ---- ROI Mapping ----
58
  REGIONS = {
59
- "attention": ["S_intrapariet", "G_front_middle", "S_front_sup", "G_pariet_inf-Supramar"],
60
- "emotion": ["G_insular", "S_circular_insula", "G_cingul", "G_front_inf-Orbital"],
61
- "language": ["G_front_inf-Opercular", "G_front_inf-Triangul", "G_temp_sup-Lateral"],
62
- "visual": ["G_occipital", "S_occipital", "G_cuneus", "S_calcarine", "Pole_occipital"],
63
- "default_mode": ["G_front_sup", "G_precuneus", "G_cingul-Post"],
 
 
 
 
 
 
 
64
  }
65
- _roi = {"labels": None, "names": None}
 
66
 
67
  def _load_roi():
68
- if _roi["labels"] is not None:
69
  return _roi["labels"], _roi["names"]
70
  try:
71
  from nilearn import datasets
72
  d = datasets.fetch_atlas_surf_destrieux()
73
  _roi["labels"] = np.concatenate([d["labels_lh"], d["labels_rh"]])
74
  _roi["names"] = [n.decode() if isinstance(n, bytes) else str(n) for n in d["label_names"]]
75
- except Exception:
76
- pass
 
77
  return _roi["labels"], _roi["names"]
78
 
79
- def _sig(val, c=0.008, s=300.0):
80
- """Sigmoid normalization calibrated for projected hidden state magnitudes.
81
- Projected activations are very small (~0.001-0.02), so center and scale
82
- are tuned accordingly to map into 20-80 range."""
83
- return float(100.0 / (1.0 + np.exp(-s * (val - c))))
84
 
85
- def interpret(preds):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  if isinstance(preds, torch.Tensor):
87
  preds = preds.cpu().numpy()
88
  if preds.ndim == 1:
89
  preds = preds.reshape(1, -1)
 
90
  n_t, n_v = preds.shape
91
  labels, names = _load_roi()
92
 
93
- scores = {}
94
- for key, subs in REGIONS.items():
95
- if labels is not None and labels.shape[0] == n_v:
96
- ids = [i for i, n in enumerate(names) if any(s in n for s in subs)]
97
- mask = np.isin(labels, ids)
98
- scores[key] = float(np.abs(preds[:, mask]).mean()) if mask.any() else float(np.abs(preds).mean())
99
  else:
100
- scores[key] = float(np.abs(preds).mean())
101
 
102
- att = _sig(scores["attention"])
103
- emo = _sig(scores["emotion"])
104
- lang = _sig(scores["language"])
105
- vis = _sig(scores["visual"])
106
- dm = _sig(scores["default_mode"])
107
  overall = (att + emo + lang + vis + dm) / 5.0
108
  viral = att * 0.4 + emo * 0.4 + vis * 0.2
109
 
110
  temporal = [_sig(float(np.abs(preds[t]).mean())) for t in range(n_t)]
111
- hook = np.mean(temporal[:2]) if len(temporal) >= 2 else overall
112
- body = np.mean(temporal[2:]) if len(temporal) > 2 else overall
113
  retention = min(body / max(hook, 1) * 100, 100)
 
 
114
 
115
  return {
116
  "scores": {
@@ -121,189 +126,217 @@ def interpret(preds):
121
  "language_processing": round(lang, 1),
122
  "visual_imagery": round(vis, 1),
123
  "hook_effectiveness": round(hook, 1),
124
- "retention_prediction": round(retention, 1),
125
  },
126
  "raw": {
127
- "n_timesteps": n_t, "n_vertices": n_v,
 
 
128
  "temporal_profile": [round(v, 1) for v in temporal],
129
- "region_raw": {k: round(v, 4) for k, v in scores.items()},
 
130
  },
131
  }
132
 
133
- # ---- GPU Functions ----
134
- @spaces.GPU(duration=120)
135
- def predict_text_gpu(text):
136
- """Use LLM perplexity + semantic features to predict brain engagement.
137
-
138
- Instead of random projections (which produce uniform scores), we measure:
139
- - Perplexity: surprising/novel content → higher attention
140
- - Token entropy: diverse vocabulary higher language processing
141
- - Sentiment strength: emotional words → higher emotional valence
142
- - Specificity: numbers, names, concrete nouns → higher visual imagery
143
- """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  m = ensure_model()
145
- tokenizer = m["tokenizer"]
146
- llm = m["model"].cuda().half()
147
-
148
- inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512).to("cuda")
149
-
150
- with torch.inference_mode():
151
- outputs = llm(**inputs)
152
-
153
- logits = outputs.logits # (1, seq_len, vocab_size)
154
- hidden = outputs.hidden_states[-1] # last layer (1, seq_len, hidden)
155
-
156
- # --- Feature extraction ---
157
- # 1. Perplexity (how surprising is the text?)
158
- shift_logits = logits[:, :-1, :].contiguous()
159
- shift_labels = inputs["input_ids"][:, 1:].contiguous()
160
- loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
161
- token_losses = loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
162
- perplexity = float(torch.exp(token_losses.mean()).cpu())
163
- # Normalize: perplexity 1-10 = boring, 10-50 = interesting, 50+ = very novel
164
- attention_raw = min(perplexity / 30.0, 1.0) # 0-1 scale
165
-
166
- # 2. Token entropy (vocabulary diversity)
167
- token_ids = inputs["input_ids"][0].cpu().numpy()
168
- unique_ratio = len(set(token_ids.tolist())) / max(len(token_ids), 1)
169
- language_raw = unique_ratio # 0-1
170
-
171
- # 3. Emotional intensity (variance in hidden states = more expressive)
172
- hidden_np = hidden.squeeze().cpu().float().numpy() # (seq_len, hidden)
173
- token_norms = np.linalg.norm(hidden_np, axis=1)
174
- emotion_raw = float(np.std(token_norms) / (np.mean(token_norms) + 1e-8))
175
-
176
- # 4. Visual/specificity (presence of numbers, caps, punctuation = concrete)
177
- text_lower = text.lower()
178
- has_numbers = sum(1 for c in text if c.isdigit()) / max(len(text), 1)
179
- has_caps = sum(1 for c in text if c.isupper()) / max(len(text), 1)
180
- urgency_words = sum(1 for w in ["now", "shock", "destroy", "change", "secret",
181
- "never", "always", "must", "urgent", "breaking", "exclusive", "free"] if w in text_lower)
182
- visual_raw = min((has_numbers * 10 + has_caps * 5 + urgency_words * 0.15), 1.0)
183
-
184
- # 5. Default mode (self-referential: I, me, my, you, your)
185
- personal_words = sum(1 for w in text_lower.split() if w in
186
- ["i", "me", "my", "you", "your", "we", "our", "myself"])
187
- dm_raw = min(personal_words / max(len(text_lower.split()), 1) * 5, 1.0)
188
-
189
- # --- Map to 0-100 scores ---
190
- def to_score(val, center=0.3, steepness=8.0):
191
- clamped = max(0.0, min(1.0, val))
192
- return float(100.0 / (1.0 + np.exp(-steepness * (clamped - center))))
193
-
194
- scores = {
195
- "attention": to_score(attention_raw, 0.25, 6.0),
196
- "emotion": to_score(emotion_raw, 0.15, 10.0),
197
- "language": to_score(language_raw, 0.5, 8.0),
198
- "visual": to_score(visual_raw, 0.2, 8.0),
199
- "default_mode": to_score(dm_raw, 0.2, 6.0),
200
- }
201
 
202
- overall = np.mean(list(scores.values()))
203
- viral = scores["attention"] * 0.4 + scores["emotion"] * 0.4 + scores["visual"] * 0.2
204
- hook_score = scores["attention"] # attention IS the hook
205
- retention = min(scores["language"] / max(scores["attention"], 1) * 100, 100)
206
 
 
 
 
 
 
207
  torch.cuda.empty_cache()
 
208
 
209
- return {
210
- "scores": {
211
- "overall_brain_engagement": round(overall, 1),
212
- "viral_potential": round(viral, 1),
213
- "attention_capture": round(scores["attention"], 1),
214
- "emotional_valence": round(scores["emotion"], 1),
215
- "language_processing": round(scores["language"], 1),
216
- "visual_imagery": round(scores["visual"], 1),
217
- "hook_effectiveness": round(hook_score, 1),
218
- "retention_prediction": round(retention, 1),
219
- },
220
- "raw": {
221
- "perplexity": round(perplexity, 2),
222
- "token_unique_ratio": round(unique_ratio, 3),
223
- "hidden_state_variance": round(emotion_raw, 4),
224
- "specificity": round(visual_raw, 3),
225
- "personal_reference": round(dm_raw, 3),
226
- },
227
- }
228
 
229
  # ---- Handlers ----
230
- def score_text(text):
 
231
  if not text or not text.strip():
232
- return "Please enter text.", ""
233
  try:
234
- r = predict_text_gpu(text.strip())
235
  s = r["scores"]
236
- lines = [
237
- f"Overall Engagement: {s['overall_brain_engagement']}/100",
238
- f"Viral Potential: {s['viral_potential']}/100",
239
- f"Attention Capture: {s['attention_capture']}/100",
240
- f"Emotional Valence: {s['emotional_valence']}/100",
241
- f"Language Processing: {s['language_processing']}/100",
242
- f"Visual Imagery: {s['visual_imagery']}/100",
243
- f"Hook Effectiveness: {s['hook_effectiveness']}/100",
244
- f"Retention: {s['retention_prediction']}/100",
245
- ]
246
- # Summary
247
- o = s["overall_brain_engagement"]
248
- summary = f"{'Strong' if o >= 70 else 'Decent' if o >= 50 else 'Weak'} engagement ({o}/100). "
249
- if s["attention_capture"] < 40:
250
- summary += "Needs stronger opening hook. "
251
- if s["emotional_valence"] >= 70:
252
- summary += "Great emotional trigger. "
253
- elif s["emotional_valence"] < 40:
254
- summary += "Add personal stakes or urgency. "
255
- if s["hook_effectiveness"] >= 70 and s["retention_prediction"] < 50:
256
- summary += "Good hook but drops off mid-section. "
257
- return "\n".join(lines), summary
258
  except Exception as e:
259
  import traceback
260
- return f"Error: {type(e).__name__}: {e}\n{traceback.format_exc()}", ""
261
 
262
- def score_json(text):
263
- if not text or not text.strip():
264
- return '{"error": "No text provided"}'
 
265
  try:
266
- r = predict_text_gpu(text.strip())
267
- return json.dumps(r, indent=2)
 
 
 
 
268
  except Exception as e:
269
- return json.dumps({"error": str(e)})
 
270
 
271
- def ab_test(a, b):
 
272
  if not a or not b:
273
  return "Enter both versions."
274
  try:
275
- ra = predict_text_gpu(a.strip())
276
- rb = predict_text_gpu(b.strip())
277
  sa, sb = ra["scores"], rb["scores"]
278
  va, vb = sa["viral_potential"], sb["viral_potential"]
279
- w = f"Version A wins ({va} vs {vb})" if va > vb else f"Version B wins ({vb} vs {va})" if vb > va else "Tie"
280
- return f"{w}\n\nA: engagement={sa['overall_brain_engagement']} viral={va} hook={sa['hook_effectiveness']}\nB: engagement={sb['overall_brain_engagement']} viral={vb} hook={sb['hook_effectiveness']}"
 
281
  except Exception as e:
282
  return f"Error: {e}"
283
 
284
- # ---- Gradio UI (Textbox only — avoids ZeroGPU schema bug) ----
285
- with gr.Blocks(title="TRIBE V2 Brain Prediction") as demo:
286
- gr.Markdown("# TRIBE V2 Brain Response Prediction\nMeta's fMRI model predicts brain engagement with your content.\n")
287
-
288
- with gr.Tab("Score Text"):
289
- inp = gr.Textbox(label="Content", lines=5, placeholder="Paste script or hook...")
290
- btn = gr.Button("Analyze", variant="primary")
291
- out = gr.Textbox(label="Scores", lines=10)
292
- summary = gr.Textbox(label="Insight")
293
- btn.click(score_text, [inp], [out, summary], api_name="predict")
294
-
295
- with gr.Tab("A/B Test"):
296
- a = gr.Textbox(label="Version A", lines=3)
297
- b = gr.Textbox(label="Version B", lines=3)
298
- btn2 = gr.Button("Compare", variant="primary")
299
- res = gr.Textbox(label="Result", lines=6)
300
- btn2.click(ab_test, [a, b], [res], api_name="ab_test")
301
-
302
- with gr.Tab("API (JSON)"):
303
- gr.Markdown("Returns raw JSON for programmatic use.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
304
  api_in = gr.Textbox(label="Text", lines=3)
305
  api_btn = gr.Button("Get JSON")
306
- api_out = gr.Textbox(label="Response", lines=15)
307
- api_btn.click(score_json, [api_in], [api_out], api_name="api_predict")
 
 
 
308
 
309
  demo.queue().launch()
 
1
  """TRIBE V2 — Brain Response Prediction (Meta)
 
 
 
2
 
3
+ Full multimodal brain prediction using Meta's TRIBE V2 model.
4
+ Supports video, audio, and text scoring on ZeroGPU (Python 3.12).
5
+ """
 
6
 
7
  import gradio as gr
8
  import spaces
 
11
  import os
12
  import json
13
  import tempfile
14
+ import io
15
 
16
+ # ---- HF Auth for gated models (LLaMA 3.2-3B) ----
17
# Best-effort Hugging Face login, run once at import time.
# The token is read from the HF_TOKEN environment variable; when it is unset
# or empty, no login is attempted and nothing is printed.
try:
    from huggingface_hub import login
    hf_token = os.environ.get("HF_TOKEN", "")
    if hf_token:
        login(token=hf_token, add_to_git_credential=False)
        print("HF authenticated for gated model access.")
except Exception as e:
    # Any failure (including a missing huggingface_hub package) is reported
    # and ignored so the app can still start.
    print(f"HF auth warning: {e}")
25
+
26
+ # ---- Model (loads on first GPU call) ----
27
  model = None
28
 
29
def ensure_model():
    """Return the shared TRIBE V2 model, loading it on the first call.

    The loaded instance is stored in the module-level ``model`` global, so
    later calls return it without importing or loading anything again.

    Returns:
        The object produced by ``TribeModel.from_pretrained("facebook/tribev2")``.
    """
    global model
    if model is None:
        print("Loading TRIBE V2 model...")
        # Imported here rather than at module level, so importing this file
        # does not require tribev2 until a prediction is actually requested.
        from tribev2 import TribeModel

        model = TribeModel.from_pretrained("facebook/tribev2")
        print(f"Model loaded: {type(model)}")
    return model
38
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  # ---- ROI Mapping ----
41
# Score category -> list of atlas label-name fragments.
# A region counts towards a category when any fragment is a substring of the
# region's label name (see the `any(s in n ...)` test in _get_mask), so a
# short fragment such as "G_cingul" also matches longer names that contain it.
REGIONS = {
    "attention": ["S_intrapariet", "G_front_middle", "S_front_sup",
                  "G_pariet_inf-Supramar", "G_temp_sup-G_T_transv"],
    "emotion": ["G_insular", "S_circular_insula", "G_cingul",
                "G_front_inf-Orbital", "G_rectus", "G_subcallosal"],
    "language": ["G_front_inf-Opercular", "G_front_inf-Triangul",
                 "G_temp_sup-Lateral", "G_temp_sup-Plan_tempo",
                 "S_temporal_sup", "G_and_S_subcentral"],
    "visual": ["G_occipital", "S_occipital", "G_cuneus", "S_calcarine",
               "Pole_occipital", "G_oc-temp_lat-fusifor",
               "S_oc_sup_and_transversal", "G_oc-temp_med-Lingual"],
    "default_mode": ["G_front_sup", "G_precuneus", "G_cingul-Post",
                     "G_temp_sup-Plan_polar", "G_parietal_sup"],
}
55
+
56
+ _roi = {"labels": None, "names": None, "loaded": False}
57
 
58
  def _load_roi():
59
+ if _roi["loaded"]:
60
  return _roi["labels"], _roi["names"]
61
  try:
62
  from nilearn import datasets
63
  d = datasets.fetch_atlas_surf_destrieux()
64
  _roi["labels"] = np.concatenate([d["labels_lh"], d["labels_rh"]])
65
  _roi["names"] = [n.decode() if isinstance(n, bytes) else str(n) for n in d["label_names"]]
66
+ except Exception as e:
67
+ print(f"ROI atlas warning: {e}")
68
+ _roi["loaded"] = True
69
  return _roi["labels"], _roi["names"]
70
 
 
 
 
 
 
71
 
72
+ def _get_mask(labels, names, region_key):
73
+ if labels is None:
74
+ return None
75
+ subs = REGIONS.get(region_key, [])
76
+ ids = [i for i, n in enumerate(names) if any(s in n for s in subs)]
77
+ mask = np.isin(labels, ids)
78
+ return mask if mask.any() else None
79
+
80
+
81
+ def _sig(val, center=0.15, scale=20.0):
82
+ return float(100.0 / (1.0 + np.exp(-scale * (val - center))))
83
+
84
+
85
+ def interpret(preds, modalities=None):
86
+ """Convert (n_timesteps, n_vertices) cortical predictions to scores."""
87
+ if modalities is None:
88
+ modalities = ["text"]
89
  if isinstance(preds, torch.Tensor):
90
  preds = preds.cpu().numpy()
91
  if preds.ndim == 1:
92
  preds = preds.reshape(1, -1)
93
+
94
  n_t, n_v = preds.shape
95
  labels, names = _load_roi()
96
 
97
+ region_scores = {}
98
+ for key in REGIONS:
99
+ mask = _get_mask(labels, names, key)
100
+ if mask is not None and mask.shape[0] == n_v:
101
+ region_scores[key] = float(np.abs(preds[:, mask]).mean())
 
102
  else:
103
+ region_scores[key] = float(np.abs(preds).mean())
104
 
105
+ att = _sig(region_scores["attention"])
106
+ emo = _sig(region_scores["emotion"])
107
+ lang = _sig(region_scores["language"])
108
+ vis = _sig(region_scores["visual"])
109
+ dm = _sig(region_scores["default_mode"])
110
  overall = (att + emo + lang + vis + dm) / 5.0
111
  viral = att * 0.4 + emo * 0.4 + vis * 0.2
112
 
113
  temporal = [_sig(float(np.abs(preds[t]).mean())) for t in range(n_t)]
114
+ hook = float(np.mean(temporal[:2])) if len(temporal) >= 2 else overall
115
+ body = float(np.mean(temporal[2:])) if len(temporal) > 2 else overall
116
  retention = min(body / max(hook, 1) * 100, 100)
117
+ peak_tr = int(np.argmax(temporal)) if temporal else 0
118
+ peak_time = peak_tr * 2.0 + 5.0
119
 
120
  return {
121
  "scores": {
 
126
  "language_processing": round(lang, 1),
127
  "visual_imagery": round(vis, 1),
128
  "hook_effectiveness": round(hook, 1),
129
+ "retention_prediction": round(min(retention, 100), 1),
130
  },
131
  "raw": {
132
+ "n_timesteps": n_t,
133
+ "n_vertices": n_v,
134
+ "peak_engagement_time_s": round(peak_time, 1),
135
  "temporal_profile": [round(v, 1) for v in temporal],
136
+ "modalities_used": modalities,
137
+ "region_activations_raw": {k: round(v, 4) for k, v in region_scores.items()},
138
  },
139
  }
140
 
141
+
142
+ # ---- Visualization ----
143
def make_radar(scores, title="Brain Engagement"):
    """Draw five headline scores as a radar chart and return it as PNG bytes.

    Args:
        scores: Mapping of score name to value; a missing score is drawn as 0.
        title: Heading placed above the chart.

    Returns:
        An ``io.BytesIO`` rewound to offset 0 containing the PNG image.
    """
    import matplotlib
    matplotlib.use("Agg")  # backend is selected before pyplot is imported
    import matplotlib.pyplot as plt

    background = "#0D1B2A"
    accent = "#FFD166"
    axes_spec = (
        ("Attention", "attention_capture"),
        ("Emotion", "emotional_valence"),
        ("Language", "language_processing"),
        ("Visual", "visual_imagery"),
        ("Viral", "viral_potential"),
    )
    cats = [label for label, _ in axes_spec]
    vals = [scores.get(key, 0) for _, key in axes_spec]
    # Repeat the first point (and angle 0) so the outline closes on itself.
    vals.append(vals[0])
    angles = [n / 5.0 * 2 * np.pi for n in range(5)] + [0]

    fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
    fig.patch.set_facecolor(background)
    ax.set_facecolor(background)
    ax.plot(angles, vals, "o-", linewidth=2, color=accent)
    ax.fill(angles, vals, alpha=0.25, color=accent)
    ax.set_ylim(0, 100)
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(cats, size=11, color="white")
    ax.set_yticks([25, 50, 75])
    ax.set_yticklabels(["25", "50", "75"], size=8, color="grey")
    ax.tick_params(colors="grey")
    ax.spines["polar"].set_color("grey")
    ax.grid(color="grey", alpha=0.3)
    ax.set_title(title, size=14, color="white", pad=20)

    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight",
                facecolor=fig.get_facecolor(), dpi=100)
    plt.close(fig)  # release the figure once it has been written out
    buf.seek(0)
    return buf
175
+
176
+
177
def make_summary(scores):
    """Turn a score mapping into a short plain-English verdict.

    Any score missing from ``scores`` is treated as 50, which triggers none
    of the optional remarks.

    Args:
        scores: Mapping of score name to a 0-100 value.

    Returns:
        One or more sentences joined by single spaces.
    """
    overall = scores.get("overall_brain_engagement", 50)
    if overall >= 70:
        level = "Strong"
    elif overall >= 50:
        level = "Decent"
    else:
        level = "Weak"
    remarks = [f"{level} engagement ({overall}/100)."]

    attention = scores.get("attention_capture", 50)
    if attention >= 70:
        remarks.append("Great attention hook.")
    elif attention < 40:
        remarks.append("Needs stronger opening hook.")

    emotion = scores.get("emotional_valence", 50)
    if emotion >= 70:
        remarks.append("Strong emotional trigger.")
    elif emotion < 40:
        remarks.append("Add personal stakes or urgency.")

    strong_hook = scores.get("hook_effectiveness", 50) >= 70
    if strong_hook and scores.get("retention_prediction", 50) < 50:
        remarks.append("Good hook but drops off mid-section.")

    return " ".join(remarks)
197
+
198
+
199
+ def _format(scores):
200
+ return "\n".join([
201
+ f"Overall Engagement: {scores['overall_brain_engagement']}/100",
202
+ f"Viral Potential: {scores['viral_potential']}/100",
203
+ f"Attention Capture: {scores['attention_capture']}/100",
204
+ f"Emotional Valence: {scores['emotional_valence']}/100",
205
+ f"Language Processing: {scores['language_processing']}/100",
206
+ f"Visual Imagery: {scores['visual_imagery']}/100",
207
+ f"Hook Effectiveness: {scores['hook_effectiveness']}/100",
208
+ f"Retention: {scores['retention_prediction']}/100",
209
+ ])
210
+
211
+
212
+ # ---- GPU Prediction Functions ----
213
+
214
@spaces.GPU(duration=60)
def _predict_text(text):
    """Run the model on a piece of text and return interpreted scores.

    The text is written to a temporary UTF-8 ``.txt`` file because the model
    is given a ``text_path`` rather than the string itself.

    Args:
        text: Content to score.

    Returns:
        The dictionary produced by ``interpret`` for the predictions.
    """
    m = ensure_model()
    with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False,
                                     encoding="utf-8") as f:
        f.write(text)
        path = f.name
    try:
        df = m.get_events_dataframe(text_path=path)
        preds, _segs = m.predict(events=df)
    finally:
        # Clean up on both success and failure: remove the staged file and
        # release cached GPU memory even when prediction raises.
        os.unlink(path)
        torch.cuda.empty_cache()
    return interpret(preds, modalities=["text"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
 
 
 
 
 
228
 
229
@spaces.GPU(duration=120)
def _predict_video(video_path):
    """Run the model on a video file and return interpreted scores.

    NOTE(review): the Video Scorer tab advertises "~2-5 min on GPU", which is
    longer than the 120 s requested in the decorator above — confirm the
    allocation is sufficient for the videos users are expected to upload.

    Args:
        video_path: Path of the video file to analyze.

    Returns:
        The dictionary produced by ``interpret`` for the predictions.
    """
    m = ensure_model()
    try:
        df = m.get_events_dataframe(video_path=video_path)
        preds, _segs = m.predict(events=df)
    finally:
        # Release cached GPU memory even when prediction raises.
        torch.cuda.empty_cache()
    return interpret(preds, modalities=["video", "audio", "text"])
236
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
 
238
  # ---- Handlers ----
239
+
240
def handle_text(text):
    """Score text for the Text Scorer tab.

    Args:
        text: Content entered by the user.

    Returns:
        A ``(scores_text, chart_path, summary)`` tuple for the three output
        components. On empty input or on failure the first element carries
        the message, the chart is ``None`` and the summary is empty.
    """
    if not text or not text.strip():
        return "Enter text to score.", None, ""
    try:
        r = _predict_text(text.strip())
        s = r["scores"]
        chart = make_radar(s)
        # make_radar returns an in-memory PNG buffer, but a gr.Image output
        # only accepts a file path, a numpy array or a PIL image. Persist the
        # bytes to a temp file and hand Gradio the path instead.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            tmp.write(chart.getvalue())
        return _format(s), tmp.name, make_summary(s)
    except Exception as e:
        import traceback
        # Surface the failure in the Scores box rather than raising.
        return f"Error: {e}\n{traceback.format_exc()}", None, ""
251
 
252
+
253
def handle_video(video):
    """Score an uploaded video for the Video Scorer tab.

    Args:
        video: Value supplied by the video input, or ``None`` when nothing
            was uploaded.

    Returns:
        A ``(scores_text, chart_path, summary)`` tuple for the three output
        components. With no upload or on failure the first element carries
        the message, the chart is ``None`` and the summary is empty.
    """
    if video is None:
        return "Upload a video.", None, ""
    try:
        r = _predict_video(video)
        s = r["scores"]
        chart = make_radar(s, title="Video Brain Engagement")
        peak = r["raw"].get("peak_engagement_time_s", "N/A")
        text = _format(s) + f"\nPeak Engagement: {peak}s"
        # make_radar returns an in-memory PNG buffer, but a gr.Image output
        # only accepts a file path, a numpy array or a PIL image. Persist the
        # bytes to a temp file and hand Gradio the path instead.
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            tmp.write(chart.getvalue())
        return text, tmp.name, make_summary(s)
    except Exception as e:
        import traceback
        # Surface the failure in the Scores box rather than raising.
        return f"Error: {e}\n{traceback.format_exc()}", None, ""
266
 
267
+
268
def handle_ab(a, b):
    """Score two text variants and report which has the higher viral score.

    Args:
        a: Text of version A.
        b: Text of version B.

    Returns:
        A report naming the winner (or a tie) followed by the formatted
        scores and summary of each version; a prompt when either version is
        missing or blank; or an ``Error: ...`` message if scoring fails.
    """
    # Strip before validating so whitespace-only input is rejected here, the
    # same way the single-text handlers reject it, instead of being scored as
    # an empty string.
    a = (a or "").strip()
    b = (b or "").strip()
    if not a or not b:
        return "Enter both versions."
    try:
        ra = _predict_text(a)
        rb = _predict_text(b)
        sa, sb = ra["scores"], rb["scores"]
        va, vb = sa["viral_potential"], sb["viral_potential"]
        if va > vb:
            w = f"Version A wins ({va} vs {vb})"
        elif vb > va:
            w = f"Version B wins ({vb} vs {va})"
        else:
            w = "Tie"
        return f"{w}\n\n--- A ---\n{_format(sa)}\n{make_summary(sa)}\n\n--- B ---\n{_format(sb)}\n{make_summary(sb)}"
    except Exception as e:
        return f"Error: {e}"
281
 
282
+
283
def handle_api(text):
    """Score text and return the full result as a JSON string.

    Args:
        text: Content to score.

    Returns:
        The prediction dictionary serialized with two-space indentation, or a
        JSON object with a single ``"error"`` key when the input is blank or
        scoring fails.
    """
    cleaned = text.strip() if text else ""
    if not cleaned:
        return '{"error": "No text"}'
    try:
        result = _predict_text(cleaned)
        return json.dumps(result, indent=2)
    except Exception as e:
        return json.dumps({"error": str(e)})
291
+
292
+
293
+ # ---- Gradio UI ----
294
+
295
+ with gr.Blocks(title="TRIBE V2 Brain Prediction", theme=gr.themes.Base(
296
+ primary_hue="amber", secondary_hue="cyan", neutral_hue="slate",
297
+ font=gr.themes.GoogleFont("Inter"),
298
+ )) as demo:
299
+
300
+ gr.Markdown("# 🧠 TRIBE V2 — Brain Response Prediction\n"
301
+ "Meta's fMRI model predicts how your content activates the brain.\n")
302
+
303
+ with gr.Tab("📝 Text Scorer"):
304
+ gr.Markdown("Score a script, hook, or post. ~30s on GPU.")
305
+ t_in = gr.Textbox(label="Content", lines=5, placeholder="Paste script or hook...")
306
+ t_btn = gr.Button("🧠 Analyze", variant="primary")
307
+ with gr.Row():
308
+ t_scores = gr.Textbox(label="Scores", lines=10)
309
+ t_chart = gr.Image(label="Brain Radar", type="filepath")
310
+ t_summary = gr.Textbox(label="Insight")
311
+ t_btn.click(handle_text, [t_in], [t_scores, t_chart, t_summary], api_name="predict")
312
+
313
+ with gr.Tab("🎬 Video Scorer"):
314
+ gr.Markdown("Upload a video for full multimodal brain analysis. ~2-5 min on GPU.")
315
+ v_in = gr.Video(label="Upload Video")
316
+ v_btn = gr.Button("🧠 Analyze Video", variant="primary")
317
+ with gr.Row():
318
+ v_scores = gr.Textbox(label="Scores", lines=10)
319
+ v_chart = gr.Image(label="Brain Radar", type="filepath")
320
+ v_summary = gr.Textbox(label="Insight")
321
+ v_btn.click(handle_video, [v_in], [v_scores, v_chart, v_summary], api_name="predict_video")
322
+
323
+ with gr.Tab("⚔️ A/B Test"):
324
+ gr.Markdown("Compare two hooks head-to-head.")
325
+ with gr.Row():
326
+ ab_a = gr.Textbox(label="Version A", lines=3)
327
+ ab_b = gr.Textbox(label="Version B", lines=3)
328
+ ab_btn = gr.Button("⚔️ Compare", variant="primary")
329
+ ab_out = gr.Textbox(label="Result", lines=10)
330
+ ab_btn.click(handle_ab, [ab_a, ab_b], [ab_out], api_name="ab_test")
331
+
332
+ with gr.Tab("🔌 API"):
333
+ gr.Markdown("Returns raw JSON for `score_script.py` compatibility.")
334
  api_in = gr.Textbox(label="Text", lines=3)
335
  api_btn = gr.Button("Get JSON")
336
+ api_out = gr.Textbox(label="JSON", lines=15)
337
+ api_btn.click(handle_api, [api_in], [api_out], api_name="api_predict")
338
+
339
+ gr.Markdown("---\n*[Meta TRIBE V2](https://github.com/facebookresearch/tribev2) | "
340
+ "ZeroGPU (A10G) | Python 3.12 | Built by somebeast*")
341
 
342
  demo.queue().launch()