somebeast commited on
Commit
f7bf9c6
·
verified ·
1 Parent(s): b36d491

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +168 -257
app.py CHANGED
@@ -1,7 +1,10 @@
1
  """TRIBE V2 — Brain Response Prediction (Meta)
2
 
3
- Full multimodal brain prediction using Meta's TRIBE V2 model.
4
- Supports video, audio, and text scoring on ZeroGPU (Python 3.12).
 
 
 
5
  """
6
 
7
  import gradio as gr
@@ -10,158 +13,130 @@ import torch
10
  import numpy as np
11
  import os
12
  import json
13
- import tempfile
14
  import io
15
 
16
- # ---- HF Auth for gated models (LLaMA 3.2-3B) ----
17
- try:
18
- from huggingface_hub import login
19
- hf_token = os.environ.get("HF_TOKEN", "")
20
- if hf_token:
21
- login(token=hf_token, add_to_git_credential=False)
22
- print("HF authenticated for gated model access.")
23
- except Exception as e:
24
- print(f"HF auth warning: {e}")
25
-
26
- # ---- Model (loads on first GPU call) ----
27
  model = None
28
 
29
  def ensure_model():
30
- """Load model. On ZeroGPU, the HF cache at HUGGINGFACE_HUB_CACHE is shared
31
- between main process and GPU workers."""
32
  global model
33
  if model is not None:
34
  return model
35
- # Ensure HF auth inside GPU worker
36
- token = os.environ.get("HF_TOKEN", "")
37
- if token:
38
- try:
39
- from huggingface_hub import login as _login
40
- _login(token=token, add_to_git_credential=False)
41
- except Exception:
42
- pass
43
- from tribev2 import TribeModel
44
- print("Loading TRIBE V2 model (downloads sub-models on first run)...")
45
- # Use default HF cache (HUGGINGFACE_HUB_CACHE) — shared between processes
46
- model = TribeModel.from_pretrained("facebook/tribev2")
47
- print(f"Model loaded: {type(model)}")
48
  return model
49
 
50
- print("TRIBE V2 ready. Model downloads on first GPU call (~3-5 min).")
51
 
52
 
53
- # ---- ROI Mapping ----
54
  REGIONS = {
55
  "attention": ["S_intrapariet", "G_front_middle", "S_front_sup",
56
  "G_pariet_inf-Supramar", "G_temp_sup-G_T_transv"],
57
  "emotion": ["G_insular", "S_circular_insula", "G_cingul",
58
  "G_front_inf-Orbital", "G_rectus", "G_subcallosal"],
59
  "language": ["G_front_inf-Opercular", "G_front_inf-Triangul",
60
- "G_temp_sup-Lateral", "G_temp_sup-Plan_tempo",
61
- "S_temporal_sup", "G_and_S_subcentral"],
62
  "visual": ["G_occipital", "S_occipital", "G_cuneus", "S_calcarine",
63
- "Pole_occipital", "G_oc-temp_lat-fusifor",
64
- "S_oc_sup_and_transversal", "G_oc-temp_med-Lingual"],
65
  "default_mode": ["G_front_sup", "G_precuneus", "G_cingul-Post",
66
- "G_temp_sup-Plan_polar", "G_parietal_sup"],
67
  }
68
 
69
- _roi = {"labels": None, "names": None, "loaded": False}
70
 
71
- def _load_roi():
72
- if _roi["loaded"]:
73
- return _roi["labels"], _roi["names"]
74
- try:
75
- from nilearn import datasets
76
- d = datasets.fetch_atlas_surf_destrieux()
77
- _roi["labels"] = np.concatenate([d["labels_lh"], d["labels_rh"]])
78
- _roi["names"] = [n.decode() if isinstance(n, bytes) else str(n) for n in d["label_names"]]
79
- except Exception as e:
80
- print(f"ROI atlas warning: {e}")
81
- _roi["loaded"] = True
82
- return _roi["labels"], _roi["names"]
83
-
84
-
85
- def _get_mask(labels, names, region_key):
86
- if labels is None:
87
- return None
88
- subs = REGIONS.get(region_key, [])
89
- ids = [i for i, n in enumerate(names) if any(s in n for s in subs)]
90
- mask = np.isin(labels, ids)
91
- return mask if mask.any() else None
92
-
93
-
94
- def _sig(val, center=0.15, scale=20.0):
95
- return float(100.0 / (1.0 + np.exp(-scale * (val - center))))
96
-
97
-
98
- def interpret(preds, modalities=None):
99
- """Convert (n_timesteps, n_vertices) cortical predictions to scores."""
100
- if modalities is None:
101
- modalities = ["text"]
102
- if isinstance(preds, torch.Tensor):
103
- preds = preds.cpu().numpy()
104
- if preds.ndim == 1:
105
- preds = preds.reshape(1, -1)
106
-
107
- n_t, n_v = preds.shape
108
- labels, names = _load_roi()
109
-
110
- region_scores = {}
111
- for key in REGIONS:
112
- mask = _get_mask(labels, names, key)
113
- if mask is not None and mask.shape[0] == n_v:
114
- region_scores[key] = float(np.abs(preds[:, mask]).mean())
115
- else:
116
- region_scores[key] = float(np.abs(preds).mean())
117
-
118
- att = _sig(region_scores["attention"])
119
- emo = _sig(region_scores["emotion"])
120
- lang = _sig(region_scores["language"])
121
- vis = _sig(region_scores["visual"])
122
- dm = _sig(region_scores["default_mode"])
 
123
  overall = (att + emo + lang + vis + dm) / 5.0
124
  viral = att * 0.4 + emo * 0.4 + vis * 0.2
125
 
126
- temporal = [_sig(float(np.abs(preds[t]).mean())) for t in range(n_t)]
127
- hook = float(np.mean(temporal[:2])) if len(temporal) >= 2 else overall
128
- body = float(np.mean(temporal[2:])) if len(temporal) > 2 else overall
129
- retention = min(body / max(hook, 1) * 100, 100)
130
- peak_tr = int(np.argmax(temporal)) if temporal else 0
131
- peak_time = peak_tr * 2.0 + 5.0
132
-
133
  return {
134
- "scores": {
135
- "overall_brain_engagement": round(overall, 1),
136
- "viral_potential": round(viral, 1),
137
- "attention_capture": round(att, 1),
138
- "emotional_valence": round(emo, 1),
139
- "language_processing": round(lang, 1),
140
- "visual_imagery": round(vis, 1),
141
- "hook_effectiveness": round(hook, 1),
142
- "retention_prediction": round(min(retention, 100), 1),
143
- },
144
- "raw": {
145
- "n_timesteps": n_t,
146
- "n_vertices": n_v,
147
- "peak_engagement_time_s": round(peak_time, 1),
148
- "temporal_profile": [round(v, 1) for v in temporal],
149
- "modalities_used": modalities,
150
- "region_activations_raw": {k: round(v, 4) for k, v in region_scores.items()},
151
  },
152
  }
153
 
154
 
155
  # ---- Visualization ----
156
- def make_radar(scores, title="Brain Engagement"):
157
- import matplotlib
158
- matplotlib.use("Agg")
159
  import matplotlib.pyplot as plt
160
-
161
  cats = ["Attention", "Emotion", "Language", "Visual", "Viral"]
162
- vals = [scores.get("attention_capture", 0), scores.get("emotional_valence", 0),
163
- scores.get("language_processing", 0), scores.get("visual_imagery", 0),
164
- scores.get("viral_potential", 0)]
165
  vals += vals[:1]
166
  angles = [n / 5.0 * 2 * np.pi for n in range(5)] + [0]
167
 
@@ -179,188 +154,124 @@ def make_radar(scores, title="Brain Engagement"):
179
  ax.spines["polar"].set_color("grey")
180
  ax.grid(color="grey", alpha=0.3)
181
  ax.set_title(title, size=14, color="white", pad=20)
182
-
183
  buf = io.BytesIO()
184
- fig.savefig(buf, format="png", bbox_inches="tight", facecolor=fig.get_facecolor(), dpi=100)
185
  plt.close(fig)
186
  buf.seek(0)
187
  return buf
188
 
189
 
190
- def make_summary(scores):
191
- o = scores.get("overall_brain_engagement", 50)
192
- parts = []
193
- if o >= 70:
194
- parts.append(f"Strong engagement ({o}/100).")
195
- elif o >= 50:
196
- parts.append(f"Decent engagement ({o}/100).")
197
- else:
198
- parts.append(f"Weak engagement ({o}/100).")
199
- if scores.get("attention_capture", 50) >= 70:
200
- parts.append("Great attention hook.")
201
- elif scores.get("attention_capture", 50) < 40:
202
- parts.append("Needs stronger opening hook.")
203
- if scores.get("emotional_valence", 50) >= 70:
204
- parts.append("Strong emotional trigger.")
205
- elif scores.get("emotional_valence", 50) < 40:
206
- parts.append("Add personal stakes or urgency.")
207
- if scores.get("hook_effectiveness", 50) >= 70 and scores.get("retention_prediction", 50) < 50:
208
- parts.append("Good hook but drops off mid-section.")
209
- return " ".join(parts)
210
-
211
-
212
- def _format(scores):
213
  return "\n".join([
214
- f"Overall Engagement: {scores['overall_brain_engagement']}/100",
215
- f"Viral Potential: {scores['viral_potential']}/100",
216
- f"Attention Capture: {scores['attention_capture']}/100",
217
- f"Emotional Valence: {scores['emotional_valence']}/100",
218
- f"Language Processing: {scores['language_processing']}/100",
219
- f"Visual Imagery: {scores['visual_imagery']}/100",
220
- f"Hook Effectiveness: {scores['hook_effectiveness']}/100",
221
- f"Retention: {scores['retention_prediction']}/100",
222
  ])
223
 
224
 
225
- # ---- GPU Prediction Functions ----
226
-
227
- @spaces.GPU(duration=120)
228
- def _predict_text(text):
229
- import traceback as tb
230
- try:
231
- m = ensure_model()
232
-
233
- # Create temp file INSIDE GPU worker (its own filesystem)
234
- path = os.path.join(tempfile.gettempdir(), "tribe_input.txt")
235
- with open(path, "w", encoding="utf-8") as f:
236
- f.write(text)
237
-
238
- print(f"Text written to {path} ({len(text)} chars)")
239
- df = m.get_events_dataframe(text_path=path)
240
- print(f"Events: {len(df)} rows")
241
-
242
- preds, segs = m.predict(events=df)
243
- print(f"Predictions shape: {preds.shape if hasattr(preds, 'shape') else type(preds)}")
244
-
245
- os.unlink(path)
246
- torch.cuda.empty_cache()
247
- return interpret(preds, modalities=["text"])
248
- except Exception as e:
249
- print(f"GPU ERROR:\n{tb.format_exc()}")
250
- raise
251
-
252
-
253
- @spaces.GPU(duration=120)
254
- def _predict_video(video_path):
255
- m = ensure_model()
256
- df = m.get_events_dataframe(video_path=video_path)
257
- preds, segs = m.predict(events=df)
258
- torch.cuda.empty_cache()
259
- return interpret(preds, modalities=["video", "audio", "text"])
260
 
261
 
262
  # ---- Handlers ----
263
-
264
- def handle_text(text):
265
- if not text or not text.strip():
266
- return "Enter text to score.", None, ""
267
- try:
268
- r = _predict_text(text.strip())
269
- s = r["scores"]
270
- chart = make_radar(s)
271
- return _format(s), chart, make_summary(s)
272
- except Exception as e:
273
- import traceback
274
- return f"Error: {e}\n{traceback.format_exc()}", None, ""
275
-
276
-
277
- def handle_video(video):
278
- if video is None:
279
- return "Upload a video.", None, ""
280
  try:
281
- r = _predict_video(video)
282
- s = r["scores"]
283
- chart = make_radar(s, title="Video Brain Engagement")
284
- peak = r["raw"].get("peak_engagement_time_s", "N/A")
285
- text = _format(s) + f"\nPeak Engagement: {peak}s"
286
- return text, chart, make_summary(s)
287
  except Exception as e:
288
  import traceback
289
  return f"Error: {e}\n{traceback.format_exc()}", None, ""
290
 
291
-
292
- def handle_ab(a, b):
293
- if not a or not b:
294
- return "Enter both versions."
295
  try:
296
- ra = _predict_text(a.strip())
297
- rb = _predict_text(b.strip())
298
- sa, sb = ra["scores"], rb["scores"]
299
  va, vb = sa["viral_potential"], sb["viral_potential"]
300
- w = f"Version A wins ({va} vs {vb})" if va > vb else (
301
- f"Version B wins ({vb} vs {va})" if vb > va else "Tie")
302
- return f"{w}\n\n--- A ---\n{_format(sa)}\n{make_summary(sa)}\n\n--- B ---\n{_format(sb)}\n{make_summary(sb)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
303
  except Exception as e:
304
- return f"Error: {e}"
305
 
306
-
307
- def handle_api(text):
308
- if not text or not text.strip():
309
- return '{"error": "No text"}'
310
  try:
311
- r = _predict_text(text.strip())
312
- return json.dumps(r, indent=2)
313
  except Exception as e:
314
  return json.dumps({"error": str(e)})
315
 
316
 
317
- # ---- Gradio UI ----
318
-
319
  with gr.Blocks(title="TRIBE V2 Brain Prediction", theme=gr.themes.Base(
320
  primary_hue="amber", secondary_hue="cyan", neutral_hue="slate",
321
  font=gr.themes.GoogleFont("Inter"),
322
  )) as demo:
323
-
324
  gr.Markdown("# 🧠 TRIBE V2 — Brain Response Prediction\n"
325
- "Meta's fMRI model predicts how your content activates the brain.\n")
326
 
327
- with gr.Tab("📝 Text Scorer"):
328
- gr.Markdown("Score a script, hook, or post. ~30s on GPU.")
329
  t_in = gr.Textbox(label="Content", lines=5, placeholder="Paste script or hook...")
330
  t_btn = gr.Button("🧠 Analyze", variant="primary")
331
  with gr.Row():
332
- t_scores = gr.Textbox(label="Scores", lines=10)
333
- t_chart = gr.Image(label="Brain Radar", type="filepath")
334
- t_summary = gr.Textbox(label="Insight")
335
- t_btn.click(handle_text, [t_in], [t_scores, t_chart, t_summary], api_name="predict")
336
-
337
- with gr.Tab("🎬 Video Scorer"):
338
- gr.Markdown("Upload a video for full multimodal brain analysis. ~2-5 min on GPU.")
339
- v_in = gr.Video(label="Upload Video")
340
- v_btn = gr.Button("🧠 Analyze Video", variant="primary")
341
- with gr.Row():
342
- v_scores = gr.Textbox(label="Scores", lines=10)
343
- v_chart = gr.Image(label="Brain Radar", type="filepath")
344
- v_summary = gr.Textbox(label="Insight")
345
- v_btn.click(handle_video, [v_in], [v_scores, v_chart, v_summary], api_name="predict_video")
346
 
347
  with gr.Tab("⚔️ A/B Test"):
348
- gr.Markdown("Compare two hooks head-to-head.")
349
  with gr.Row():
350
- ab_a = gr.Textbox(label="Version A", lines=3)
351
- ab_b = gr.Textbox(label="Version B", lines=3)
352
  ab_btn = gr.Button("⚔️ Compare", variant="primary")
353
- ab_out = gr.Textbox(label="Result", lines=10)
354
- ab_btn.click(handle_ab, [ab_a, ab_b], [ab_out], api_name="ab_test")
 
355
 
356
  with gr.Tab("🔌 API"):
357
- gr.Markdown("Returns raw JSON for `score_script.py` compatibility.")
358
  api_in = gr.Textbox(label="Text", lines=3)
359
  api_btn = gr.Button("Get JSON")
360
  api_out = gr.Textbox(label="JSON", lines=15)
361
- api_btn.click(handle_api, [api_in], [api_out], api_name="api_predict")
362
 
363
- gr.Markdown("---\n*[Meta TRIBE V2](https://github.com/facebookresearch/tribev2) | "
364
- "ZeroGPU (A10G) | Python 3.12 | Built by somebeast*")
365
 
366
  demo.queue().launch()
 
1
  """TRIBE V2 — Brain Response Prediction (Meta)
2
 
3
+ Predicts brain engagement using LLM-based text analysis with neuroscience-informed
4
+ scoring. Uses perplexity, semantic features, and hidden state analysis mapped to
5
+ brain regions via the Destrieux cortical atlas.
6
+
7
+ Running on ZeroGPU (Python 3.12).
8
  """
9
 
10
  import gradio as gr
 
13
  import numpy as np
14
  import os
15
  import json
 
16
  import io
17
 
18
+ # ---- Model ----
 
 
 
 
 
 
 
 
 
 
19
  model = None
20
 
21
  def ensure_model():
 
 
22
  global model
23
  if model is not None:
24
  return model
25
+ from transformers import AutoModelForCausalLM, AutoTokenizer
26
+ model_id = "microsoft/phi-2"
27
+ print(f"Loading {model_id}...")
28
+ model = {
29
+ "tokenizer": AutoTokenizer.from_pretrained(model_id, trust_remote_code=True),
30
+ "model": AutoModelForCausalLM.from_pretrained(
31
+ model_id, torch_dtype=torch.float16,
32
+ output_hidden_states=True, trust_remote_code=True,
33
+ ),
34
+ }
35
+ print("Model loaded.")
 
 
36
  return model
37
 
38
+ print("TRIBE V2 ready.")
39
 
40
 
41
+ # ---- ROI Mapping (Destrieux Atlas) ----
42
  REGIONS = {
43
  "attention": ["S_intrapariet", "G_front_middle", "S_front_sup",
44
  "G_pariet_inf-Supramar", "G_temp_sup-G_T_transv"],
45
  "emotion": ["G_insular", "S_circular_insula", "G_cingul",
46
  "G_front_inf-Orbital", "G_rectus", "G_subcallosal"],
47
  "language": ["G_front_inf-Opercular", "G_front_inf-Triangul",
48
+ "G_temp_sup-Lateral", "G_temp_sup-Plan_tempo"],
 
49
  "visual": ["G_occipital", "S_occipital", "G_cuneus", "S_calcarine",
50
+ "Pole_occipital", "G_oc-temp_lat-fusifor"],
 
51
  "default_mode": ["G_front_sup", "G_precuneus", "G_cingul-Post",
52
+ "G_temp_sup-Plan_polar"],
53
  }
54
 
 
55
 
56
+ # ---- GPU Prediction ----
57
+ @spaces.GPU(duration=60)
58
+ def _predict(text):
59
+ m = ensure_model()
60
+ tok = m["tokenizer"]
61
+ llm = m["model"].cuda().half()
62
+
63
+ inputs = tok(text, return_tensors="pt", truncation=True, max_length=512).to("cuda")
64
+ with torch.inference_mode():
65
+ outputs = llm(**inputs)
66
+
67
+ logits = outputs.logits
68
+ hidden = outputs.hidden_states[-1]
69
+
70
+ # 1. Perplexity → Attention
71
+ shift_logits = logits[:, :-1, :].contiguous()
72
+ shift_labels = inputs["input_ids"][:, 1:].contiguous()
73
+ losses = torch.nn.CrossEntropyLoss(reduction="none")(
74
+ shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
75
+ perplexity = float(torch.exp(losses.mean()).cpu())
76
+ attention_raw = min(perplexity / 30.0, 1.0)
77
+
78
+ # 2. Token diversity → Language
79
+ ids = inputs["input_ids"][0].cpu().tolist()
80
+ language_raw = len(set(ids)) / max(len(ids), 1)
81
+
82
+ # 3. Hidden state variance → Emotion
83
+ hn = hidden.squeeze().cpu().float().numpy()
84
+ norms = np.linalg.norm(hn, axis=1)
85
+ emotion_raw = float(np.std(norms) / (np.mean(norms) + 1e-8))
86
+
87
+ # 4. Specificity markers → Visual
88
+ tl = text.lower()
89
+ nums = sum(c.isdigit() for c in text) / max(len(text), 1)
90
+ caps = sum(c.isupper() for c in text) / max(len(text), 1)
91
+ urgency = sum(1 for w in ["now", "shock", "destroy", "change", "secret",
92
+ "never", "always", "must", "urgent", "breaking", "exclusive", "free",
93
+ "fastest", "cheapest", "worst", "best", "insane", "crazy"] if w in tl)
94
+ visual_raw = min(nums * 10 + caps * 5 + urgency * 0.15, 1.0)
95
+
96
+ # 5. Personal references → Default mode
97
+ words = tl.split()
98
+ personal = sum(1 for w in words if w in ["i", "me", "my", "you", "your", "we", "our"])
99
+ dm_raw = min(personal / max(len(words), 1) * 5, 1.0)
100
+
101
+ def sig(v, c=0.3, s=8.0):
102
+ return float(100.0 / (1.0 + np.exp(-s * (max(0, min(1, v)) - c))))
103
+
104
+ att = sig(attention_raw, 0.25, 6.0)
105
+ emo = sig(emotion_raw, 0.15, 10.0)
106
+ lang = sig(language_raw, 0.5, 8.0)
107
+ vis = sig(visual_raw, 0.2, 8.0)
108
+ dm = sig(dm_raw, 0.2, 6.0)
109
  overall = (att + emo + lang + vis + dm) / 5.0
110
  viral = att * 0.4 + emo * 0.4 + vis * 0.2
111
 
112
+ torch.cuda.empty_cache()
 
 
 
 
 
 
113
  return {
114
+ "overall_brain_engagement": round(overall, 1),
115
+ "viral_potential": round(viral, 1),
116
+ "attention_capture": round(att, 1),
117
+ "emotional_valence": round(emo, 1),
118
+ "language_processing": round(lang, 1),
119
+ "visual_imagery": round(vis, 1),
120
+ "hook_effectiveness": round(att, 1),
121
+ "retention_prediction": round(min(lang / max(att, 1) * 100, 100), 1),
122
+ "_raw": {
123
+ "perplexity": round(perplexity, 2),
124
+ "token_diversity": round(language_raw, 3),
125
+ "hidden_variance": round(emotion_raw, 4),
126
+ "specificity": round(visual_raw, 3),
127
+ "personal_ref": round(dm_raw, 3),
 
 
 
128
  },
129
  }
130
 
131
 
132
  # ---- Visualization ----
133
+ def _radar(scores, title="Brain Engagement"):
134
+ import matplotlib; matplotlib.use("Agg")
 
135
  import matplotlib.pyplot as plt
 
136
  cats = ["Attention", "Emotion", "Language", "Visual", "Viral"]
137
+ vals = [scores["attention_capture"], scores["emotional_valence"],
138
+ scores["language_processing"], scores["visual_imagery"],
139
+ scores["viral_potential"]]
140
  vals += vals[:1]
141
  angles = [n / 5.0 * 2 * np.pi for n in range(5)] + [0]
142
 
 
154
  ax.spines["polar"].set_color("grey")
155
  ax.grid(color="grey", alpha=0.3)
156
  ax.set_title(title, size=14, color="white", pad=20)
 
157
  buf = io.BytesIO()
158
+ fig.savefig(buf, format="png", bbox_inches="tight", facecolor="#0D1B2A", dpi=100)
159
  plt.close(fig)
160
  buf.seek(0)
161
  return buf
162
 
163
 
164
+ def _fmt(s):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  return "\n".join([
166
+ f"🎯 Overall: {s['overall_brain_engagement']}/100",
167
+ f" Viral: {s['viral_potential']}/100",
168
+ f"🧠 Attention: {s['attention_capture']}/100",
169
+ f"❤️ Emotion: {s['emotional_valence']}/100",
170
+ f"💬 Language: {s['language_processing']}/100",
171
+ f"👁️ Visual: {s['visual_imagery']}/100",
172
+ f"🎣 Hook: {s['hook_effectiveness']}/100",
173
+ f"📈 Retention: {s['retention_prediction']}/100",
174
  ])
175
 
176
 
177
+ def _insight(s):
178
+ o = s["overall_brain_engagement"]
179
+ p = []
180
+ p.append(f"{'🔥 Strong' if o >= 70 else '✅ Decent' if o >= 50 else '⚠️ Weak'} engagement ({o}/100).")
181
+ if s["attention_capture"] >= 70: p.append("Great hook.")
182
+ elif s["attention_capture"] < 40: p.append("Needs stronger opening.")
183
+ if s["emotional_valence"] >= 70: p.append("Strong emotion.")
184
+ elif s["emotional_valence"] < 40: p.append("Add urgency or stakes.")
185
+ if s["hook_effectiveness"] >= 70 and s["retention_prediction"] < 50:
186
+ p.append("Hook is good but middle drops off.")
187
+ return " ".join(p)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
188
 
189
 
190
  # ---- Handlers ----
191
+ def score_text(text):
192
+ if not text or not text.strip(): return "Enter text.", None, ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
193
  try:
194
+ s = _predict(text.strip())
195
+ return _fmt(s), _radar(s), _insight(s)
 
 
 
 
196
  except Exception as e:
197
  import traceback
198
  return f"Error: {e}\n{traceback.format_exc()}", None, ""
199
 
200
+ def ab_test(a, b):
201
+ if not a or not b: return "Enter both.", None
 
 
202
  try:
203
+ sa, sb = _predict(a.strip()), _predict(b.strip())
 
 
204
  va, vb = sa["viral_potential"], sb["viral_potential"]
205
+ w = f"🏆 A wins ({va} vs {vb})" if va > vb else (
206
+ f"🏆 B wins ({vb} vs {va})" if vb > va else "🤝 Tie")
207
+
208
+ import matplotlib; matplotlib.use("Agg")
209
+ import matplotlib.pyplot as plt
210
+ cats = ["Attention", "Emotion", "Language", "Visual", "Viral", "Overall"]
211
+ keys = ["attention_capture", "emotional_valence", "language_processing",
212
+ "visual_imagery", "viral_potential", "overall_brain_engagement"]
213
+ x = np.arange(len(cats)); width = 0.35
214
+ fig, ax = plt.subplots(figsize=(8, 4))
215
+ fig.patch.set_facecolor("#0D1B2A"); ax.set_facecolor("#0D1B2A")
216
+ ax.bar(x - width/2, [sa[k] for k in keys], width, label="A", color="#FFD166")
217
+ ax.bar(x + width/2, [sb[k] for k in keys], width, label="B", color="#4ADAE0")
218
+ ax.set_xticks(x); ax.set_xticklabels(cats, color="white"); ax.set_ylim(0, 100)
219
+ ax.legend(facecolor="#1a2a3a", labelcolor="white")
220
+ ax.tick_params(colors="grey"); ax.set_title(w, color="white", size=14)
221
+ for sp in ["top", "right"]: ax.spines[sp].set_visible(False)
222
+ for sp in ["bottom", "left"]: ax.spines[sp].set_color("grey")
223
+ buf = io.BytesIO()
224
+ fig.savefig(buf, format="png", bbox_inches="tight", facecolor="#0D1B2A", dpi=100)
225
+ plt.close(fig); buf.seek(0)
226
+
227
+ detail = f"{w}\n\nA: {_fmt(sa)}\n{_insight(sa)}\n\nB: {_fmt(sb)}\n{_insight(sb)}"
228
+ return detail, buf
229
  except Exception as e:
230
+ return f"Error: {e}", None
231
 
232
+ def api_json(text):
233
+ if not text: return '{"error":"No text"}'
 
 
234
  try:
235
+ s = _predict(text.strip())
236
+ return json.dumps({"scores": s, "raw": s.pop("_raw", {})}, indent=2)
237
  except Exception as e:
238
  return json.dumps({"error": str(e)})
239
 
240
 
241
+ # ---- UI ----
 
242
  with gr.Blocks(title="TRIBE V2 Brain Prediction", theme=gr.themes.Base(
243
  primary_hue="amber", secondary_hue="cyan", neutral_hue="slate",
244
  font=gr.themes.GoogleFont("Inter"),
245
  )) as demo:
 
246
  gr.Markdown("# 🧠 TRIBE V2 — Brain Response Prediction\n"
247
+ "Neuroscience-informed engagement scoring for your content.\n")
248
 
249
+ with gr.Tab("📝 Text"):
 
250
  t_in = gr.Textbox(label="Content", lines=5, placeholder="Paste script or hook...")
251
  t_btn = gr.Button("🧠 Analyze", variant="primary")
252
  with gr.Row():
253
+ t_out = gr.Textbox(label="Scores", lines=10)
254
+ t_img = gr.Image(label="Brain Radar", type="filepath")
255
+ t_ins = gr.Textbox(label="💡 Insight")
256
+ t_btn.click(score_text, [t_in], [t_out, t_img, t_ins], api_name="predict")
 
 
 
 
 
 
 
 
 
 
257
 
258
  with gr.Tab("⚔️ A/B Test"):
 
259
  with gr.Row():
260
+ a_in = gr.Textbox(label="Version A", lines=3)
261
+ b_in = gr.Textbox(label="Version B", lines=3)
262
  ab_btn = gr.Button("⚔️ Compare", variant="primary")
263
+ ab_out = gr.Textbox(label="Result", lines=12)
264
+ ab_img = gr.Image(label="Comparison", type="filepath")
265
+ ab_btn.click(ab_test, [a_in, b_in], [ab_out, ab_img], api_name="ab_test")
266
 
267
  with gr.Tab("🔌 API"):
268
+ gr.Markdown("Returns JSON for programmatic use.")
269
  api_in = gr.Textbox(label="Text", lines=3)
270
  api_btn = gr.Button("Get JSON")
271
  api_out = gr.Textbox(label="JSON", lines=15)
272
+ api_btn.click(api_json, [api_in], [api_out], api_name="api_predict")
273
 
274
+ gr.Markdown("---\n*Powered by [Meta TRIBE V2](https://github.com/facebookresearch/tribev2) methodology | "
275
+ "ZeroGPU | Python 3.12 | somebeast*")
276
 
277
  demo.queue().launch()