"""TRIBE V2 — Brain Response Prediction (Meta).

Predicts brain engagement using LLM-based text analysis with
neuroscience-informed heuristic scoring. Microsoft Phi-2 supplies perplexity,
token diversity and last-layer hidden-state statistics; simple lexical markers
(digits, capitals, urgency words, personal pronouns) supply the rest. Each raw
feature is squashed onto 0-100 with a logistic curve and labelled with a
brain-region family. The ``REGIONS`` table (Destrieux cortical atlas labels)
documents that labelling only; it does not enter the arithmetic.

Running on CPU (unlimited, no quota).
"""

import io
import json
import os
import string
import subprocess
import tempfile
import traceback

import gradio as gr
# import spaces  # CPU mode; re-enable together with the @spaces.GPU markers below
import numpy as np
import torch

# ---- Model ----
model = None
_whisper_model = None


def ensure_model():
    """Lazily load Phi-2 and its tokenizer; return ``{"tokenizer", "model"}``.

    The result is cached in the module-level ``model`` so the weights are
    loaded at most once per process.
    """
    global model
    if model is not None:
        return model
    from transformers import AutoModelForCausalLM, AutoTokenizer

    model_id = "microsoft/phi-2"
    print(f"Loading {model_id}...")
    llm = AutoModelForCausalLM.from_pretrained(
        model_id,
        # Load in float32 directly: the CPU path needs full precision, and
        # loading float16 only to upcast with .float() on every request is waste.
        torch_dtype=torch.float32,
        output_hidden_states=True,
        trust_remote_code=True,
    )
    llm.eval()
    model = {
        "tokenizer": AutoTokenizer.from_pretrained(model_id, trust_remote_code=True),
        "model": llm,
    }
    print("Model loaded.")
    return model


def _ensure_whisper():
    """Lazily load and cache the Whisper "base" model on CPU.

    Reloading the checkpoint on every video request dominated latency.
    """
    global _whisper_model
    if _whisper_model is None:
        import whisper

        _whisper_model = whisper.load_model("base", device="cpu")
    return _whisper_model


print("TRIBE V2 ready.")

# ---- ROI Mapping (Destrieux Atlas) ----
# Descriptive only: documents which cortical labels each score family stands for.
REGIONS = {
    "attention": ["S_intrapariet", "G_front_middle", "S_front_sup", "G_pariet_inf-Supramar", "G_temp_sup-G_T_transv"],
    "emotion": ["G_insular", "S_circular_insula", "G_cingul", "G_front_inf-Orbital", "G_rectus", "G_subcallosal"],
    "language": ["G_front_inf-Opercular", "G_front_inf-Triangul", "G_temp_sup-Lateral", "G_temp_sup-Plan_tempo"],
    "visual": ["G_occipital", "S_occipital", "G_cuneus", "S_calcarine", "Pole_occipital", "G_oc-temp_lat-fusifor"],
    "default_mode": ["G_front_sup", "G_precuneus", "G_cingul-Post", "G_temp_sup-Plan_polar"],
}

# ---- Scoring ----
URGENCY_WORDS = (
    "now", "shock", "destroy", "change", "secret", "never", "always", "must",
    "urgent", "breaking", "exclusive", "free", "fastest", "cheapest", "worst",
    "best", "insane", "crazy",
)
PERSONAL_WORDS = frozenset({"i", "me", "my", "you", "your", "we", "our"})


def _sig(v, c=0.3, s=8.0):
    """Clamp *v* to [0, 1] and map it to 0-100 with a logistic centred at *c*, slope *s*."""
    v = max(0, min(1, v))
    return float(100.0 / (1.0 + np.exp(-s * (v - c))))


def _llm_features(text):
    """Run Phi-2 over *text*; return ``(perplexity, token_diversity, hidden_variance)``.

    Raises ValueError if *text* tokenizes to nothing.
    """
    m = ensure_model()
    tok, llm = m["tokenizer"], m["model"]
    inputs = tok(text, return_tensors="pt", truncation=True, max_length=512).to("cpu")
    input_ids = inputs["input_ids"]
    n_tokens = input_ids.shape[1]
    if n_tokens == 0:
        raise ValueError("Text produced no tokens")

    with torch.inference_mode():
        outputs = llm(**inputs, output_hidden_states=True)
    logits = outputs.logits
    # Index the single batch item instead of squeeze(): squeeze() would also
    # drop the sequence axis for one-token input and break norm(axis=1).
    hidden = outputs.hidden_states[-1][0]

    if n_tokens > 1:
        shift_logits = logits[:, :-1, :]
        shift_labels = input_ids[:, 1:]
        loss = torch.nn.functional.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)), shift_labels.reshape(-1)
        )
        perplexity = float(torch.exp(loss))
    else:
        # No next-token predictions to evaluate; treat as minimally surprising
        # rather than propagating the NaN mean of an empty loss.
        perplexity = 1.0

    ids = input_ids[0].tolist()
    token_diversity = len(set(ids)) / max(len(ids), 1)

    norms = np.linalg.norm(hidden.float().numpy(), axis=1)
    hidden_variance = float(np.std(norms) / (np.mean(norms) + 1e-8))
    return perplexity, token_diversity, hidden_variance


def _lexical_features(text):
    """Return ``(specificity, personal_ref)`` raw features, each in [0, 1]."""
    n_chars = max(len(text), 1)
    nums = sum(c.isdigit() for c in text) / n_chars
    caps = sum(c.isupper() for c in text) / n_chars
    # Whole-word matching with punctuation stripped: plain substring tests
    # counted "now" in "know", "free" in "freedom", and missed "you,".
    words = [w.strip(string.punctuation) for w in text.lower().split()]
    vocab = set(words)
    urgency = sum(1 for w in URGENCY_WORDS if w in vocab)
    specificity = min(nums * 10 + caps * 5 + urgency * 0.15, 1.0)
    personal = sum(1 for w in words if w in PERSONAL_WORDS)
    personal_ref = min(personal / max(len(words), 1) * 5, 1.0)
    return specificity, personal_ref


def _score_text(text):
    """Compute the full score dict (0-100 scores plus ``_raw`` features) for *text*."""
    perplexity, language_raw, emotion_raw = _llm_features(text)
    attention_raw = min(perplexity / 30.0, 1.0)   # 1. Perplexity -> Attention
    visual_raw, dm_raw = _lexical_features(text)  # 4./5. Specificity, personal refs

    att = _sig(attention_raw, 0.25, 6.0)
    emo = _sig(emotion_raw, 0.15, 10.0)
    lang = _sig(language_raw, 0.5, 8.0)
    vis = _sig(visual_raw, 0.2, 8.0)
    dm = _sig(dm_raw, 0.2, 6.0)
    overall = (att + emo + lang + vis + dm) / 5.0
    viral = att * 0.4 + emo * 0.4 + vis * 0.2

    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return {
        "overall_brain_engagement": round(overall, 1),
        "viral_potential": round(viral, 1),
        "attention_capture": round(att, 1),
        "emotional_valence": round(emo, 1),
        "language_processing": round(lang, 1),
        "visual_imagery": round(vis, 1),
        "hook_effectiveness": round(att, 1),
        "retention_prediction": round(min(lang / max(att, 1) * 100, 100), 1),
        "_raw": {
            "perplexity": round(perplexity, 2),
            "token_diversity": round(language_raw, 3),
            "hidden_variance": round(emotion_raw, 4),
            "specificity": round(visual_raw, 3),
            "personal_ref": round(dm_raw, 3),
        },
    }


# ---- GPU Prediction ----
# @spaces.GPU  # CPU mode
def _predict(text):
    """Score *text*; return the score dict including the ``_raw`` feature values."""
    return _score_text(text)


# ---- Visualization ----
def _radar(scores, title="Brain Engagement"):
    """Render a radar chart of the main scores; return it as a PNG in a BytesIO."""
    import matplotlib; matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    cats = ["Attention", "Emotion", "Language", "Visual", "Viral"]
    vals = [scores["attention_capture"], scores["emotional_valence"],
            scores["language_processing"], scores["visual_imagery"], scores["viral_potential"]]
    vals += vals[:1]  # close the polygon
    angles = [n / 5.0 * 2 * np.pi for n in range(5)] + [0]
    fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
    fig.patch.set_facecolor("#0D1B2A")
    ax.set_facecolor("#0D1B2A")
    ax.plot(angles, vals, "o-", linewidth=2, color="#FFD166")
    ax.fill(angles, vals, alpha=0.25, color="#FFD166")
    ax.set_ylim(0, 100)
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(cats, size=11, color="white")
    ax.set_yticks([25, 50, 75])
    ax.set_yticklabels(["25", "50", "75"], size=8, color="grey")
    ax.tick_params(colors="grey")
    ax.spines["polar"].set_color("grey")
    ax.grid(color="grey", alpha=0.3)
    ax.set_title(title, size=14, color="white", pad=20)
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight", facecolor="#0D1B2A", dpi=100)
    plt.close(fig)
    buf.seek(0)
    return buf


def _fmt(s):
    """Format a score dict as one labelled line per score."""
    return "\n".join([
        f"🎯 Overall: {s['overall_brain_engagement']}/100",
        f"⚡ Viral: {s['viral_potential']}/100",
        f"🧠 Attention: {s['attention_capture']}/100",
        f"❤️ Emotion: {s['emotional_valence']}/100",
        f"💬 Language: {s['language_processing']}/100",
        f"👁️ Visual: {s['visual_imagery']}/100",
        f"🎣 Hook: {s['hook_effectiveness']}/100",
        f"📈 Retention: {s['retention_prediction']}/100",
    ])


def _insight(s):
    """Build a short threshold-based advice sentence from a score dict."""
    o = s["overall_brain_engagement"]
    p = []
    p.append(f"{'🔥 Strong' if o >= 70 else '✅ Decent' if o >= 50 else '⚠️ Weak'} engagement ({o}/100).")
    if s["attention_capture"] >= 70:
        p.append("Great hook.")
    elif s["attention_capture"] < 40:
        p.append("Needs stronger opening.")
    if s["emotional_valence"] >= 70:
        p.append("Strong emotion.")
    elif s["emotional_valence"] < 40:
        p.append("Add urgency or stakes.")
    if s["hook_effectiveness"] >= 70 and s["retention_prediction"] < 50:
        p.append("Hook is good but middle drops off.")
    return " ".join(p)


# ---- Handlers ----
# @spaces.GPU  # CPU mode
def _transcribe_and_score(video_path):
    """Extract audio, transcribe with Whisper, then score with Phi-2.

    Returns ``(transcript, scores)``.

    Raises:
        ValueError: ffmpeg could not extract audio, or no speech was detected.
        subprocess.TimeoutExpired: ffmpeg ran longer than 60 seconds.
    """
    # Unique temp file: a fixed name next to the upload could collide between
    # concurrent requests.
    fd, audio_path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)
    try:
        proc = subprocess.run(
            # -y first so ffmpeg overwrites the (pre-created) temp file
            # without prompting.
            ["ffmpeg", "-y", "-i", video_path, "-vn", "-acodec", "pcm_s16le",
             "-ar", "16000", "-ac", "1", audio_path],
            capture_output=True, stdin=subprocess.DEVNULL, timeout=60,
        )
        if proc.returncode != 0:
            err_lines = proc.stderr.decode("utf-8", errors="replace").strip().splitlines()
            detail = err_lines[-1] if err_lines else f"ffmpeg exit code {proc.returncode}"
            raise ValueError(f"Audio extraction failed: {detail}")
        result = _ensure_whisper().transcribe(audio_path)
    finally:
        # Always remove the temp audio, even if ffmpeg or Whisper raised.
        if os.path.exists(audio_path):
            os.unlink(audio_path)

    transcript = (result.get("text") or "").strip()
    if not transcript:
        raise ValueError("No speech detected in video")
    return transcript, _score_text(transcript)


def score_video_safe(video):
    """Gradio handler: return (transcript preview + scores, insight) or an error string."""
    if video is None:
        return "Upload a video.", ""
    try:
        transcript, s = _transcribe_and_score(video)
        preview = transcript[:300] + ("..." if len(transcript) > 300 else "")
        return f"Transcript:\n{preview}\n\n{_fmt(s)}", _insight(s)
    except Exception as e:
        return f"Error: {e}\n{traceback.format_exc()}", ""


def score_text_with_chart(text):
    """Return (scores, radar PNG buffer, insight) for *text*.

    NOTE(review): not wired into the UI; a BytesIO is likely not accepted
    directly by gr.Image — confirm before wiring.
    """
    if not text or not text.strip():
        return "Enter text.", None, ""
    try:
        s = _predict(text.strip())
        return _fmt(s), _radar(s), _insight(s)
    except Exception as e:
        return f"Error: {e}\n{traceback.format_exc()}", None, ""


def score_text_safe(text):
    """Gradio handler: return (formatted scores, insight) or an error string."""
    if not text or not text.strip():
        return "Enter text.", ""
    try:
        s = _predict(text.strip())
        return _fmt(s), _insight(s)
    except Exception as e:
        return f"Error: {e}\n{traceback.format_exc()}", ""


def ab_test_safe(a, b):
    """Gradio handler: compare two texts by viral potential and report both scores."""
    # Reject whitespace-only input too; it would otherwise reach the model empty.
    if not a or not a.strip() or not b or not b.strip():
        return "Enter both versions."
    try:
        sa, sb = _predict(a.strip()), _predict(b.strip())
        va, vb = sa["viral_potential"], sb["viral_potential"]
        w = f"🏆 A wins ({va} vs {vb})" if va > vb else (
            f"🏆 B wins ({vb} vs {va})" if vb > va else "🤝 Tie")
        return f"{w}\n\n--- Version A ---\n{_fmt(sa)}\n{_insight(sa)}\n\n--- Version B ---\n{_fmt(sb)}\n{_insight(sb)}"
    except Exception as e:
        return f"Error: {e}"


def api_json(text):
    """Gradio handler: return ``{"scores": ..., "raw": ...}`` as a JSON string."""
    if not text or not text.strip():
        return '{"error":"No text"}'
    try:
        s = _predict(text.strip())
        # Pop explicitly before serialising rather than relying on dict-literal
        # evaluation order to keep "_raw" out of "scores".
        raw = s.pop("_raw", {})
        return json.dumps({"scores": s, "raw": raw}, indent=2)
    except Exception as e:
        return json.dumps({"error": str(e)})


# ---- UI ----
with gr.Blocks(title="TRIBE V2 Brain Prediction", theme=gr.themes.Base(
    primary_hue="amber", secondary_hue="cyan", neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
)) as demo:
    gr.Markdown("# 🧠 TRIBE V2 — Brain Response Prediction\n"
                "Neuroscience-informed engagement scoring for your content.\n")

    with gr.Tab("📝 Text"):
        t_in = gr.Textbox(label="Content", lines=5, placeholder="Paste script or hook...")
        t_btn = gr.Button("🧠 Analyze", variant="primary")
        t_out = gr.Textbox(label="Scores", lines=10)
        t_ins = gr.Textbox(label="💡 Insight")
        t_btn.click(score_text_safe, [t_in], [t_out, t_ins], api_name="predict")

    with gr.Tab("🎬 Video"):
        gr.Markdown("Upload a video — audio is transcribed and scored. ~45-90s on CPU (no quota limit).")
        v_in = gr.Video(label="Upload Video")
        v_btn = gr.Button("🧠 Analyze Video", variant="primary")
        v_out = gr.Textbox(label="Scores", lines=12)
        v_ins = gr.Textbox(label="💡 Insight")
        v_btn.click(score_video_safe, [v_in], [v_out, v_ins], api_name="predict_video")

    with gr.Tab("⚔️ A/B Test"):
        with gr.Row():
            a_in = gr.Textbox(label="Version A", lines=3)
            b_in = gr.Textbox(label="Version B", lines=3)
        ab_btn = gr.Button("⚔️ Compare", variant="primary")
        ab_out = gr.Textbox(label="Result", lines=12)
        ab_btn.click(ab_test_safe, [a_in, b_in], [ab_out], api_name="ab_test")

    with gr.Tab("🔌 API"):
        gr.Markdown("Returns JSON for programmatic use.")
        api_in = gr.Textbox(label="Text", lines=3)
        api_btn = gr.Button("Get JSON")
        api_out = gr.Textbox(label="JSON", lines=15)
        api_btn.click(api_json, [api_in], [api_out], api_name="api_predict")

    gr.Markdown("---\n*Powered by [Meta TRIBE V2](https://github.com/facebookresearch/tribev2) methodology | "
                "CPU Basic (unlimited) | somebeast*")

demo.queue().launch()