Spaces:
Sleeping
Sleeping
| """TRIBE V2 — Brain Response Prediction (Meta) | |
| Predicts brain engagement using LLM-based text analysis with neuroscience-informed | |
| scoring. Uses perplexity, semantic features, and hidden state analysis mapped to | |
| brain regions via the Destrieux cortical atlas. | |
| Running on CPU (unlimited, no quota). | |
| """ | |
| import gradio as gr | |
| # import spaces # CPU mode | |
| import torch | |
| import numpy as np | |
| import os | |
| import json | |
| import io | |
# ---- Model ----
# Lazily-populated cache: {"tokenizer": ..., "model": ...} once loaded.
model = None


def ensure_model():
    """Load Phi-2 and its tokenizer on first use and cache them in ``model``.

    The Space runs on CPU only, so the weights are loaded directly in float32.
    Previously they were loaded in float16 and then cast in place by the first
    caller's ``.float()``, which made the first request pay for a full cast.

    Returns:
        dict with keys ``"tokenizer"`` and ``"model"``.
    """
    global model
    if model is not None:
        return model
    from transformers import AutoModelForCausalLM, AutoTokenizer
    model_id = "microsoft/phi-2"
    print(f"Loading {model_id}...")
    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
    llm = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float32,  # CPU mode: no fp16 -> fp32 round trip
        output_hidden_states=True,
        trust_remote_code=True,
    )
    # Publish the cache only after both pieces loaded, so a failed load retries.
    model = {"tokenizer": tokenizer, "model": llm}
    print("Model loaded.")
    return model


# Printed at import time; the model itself is loaded lazily by ensure_model().
print("TRIBE V2 ready.")
# ---- ROI Mapping (Destrieux Atlas) ----
# Cortical-region label names grouped by engagement dimension. The labels
# follow Destrieux-atlas naming, per the module docstring.
# NOTE(review): nothing in the visible source reads REGIONS; the scores are
# computed purely from text / LLM features in _predict.
REGIONS = dict(
    attention=[
        "S_intrapariet", "G_front_middle", "S_front_sup",
        "G_pariet_inf-Supramar", "G_temp_sup-G_T_transv",
    ],
    emotion=[
        "G_insular", "S_circular_insula", "G_cingul",
        "G_front_inf-Orbital", "G_rectus", "G_subcallosal",
    ],
    language=[
        "G_front_inf-Opercular", "G_front_inf-Triangul",
        "G_temp_sup-Lateral", "G_temp_sup-Plan_tempo",
    ],
    visual=[
        "G_occipital", "S_occipital", "G_cuneus", "S_calcarine",
        "Pole_occipital", "G_oc-temp_lat-fusifor",
    ],
    default_mode=[
        "G_front_sup", "G_precuneus", "G_cingul-Post",
        "G_temp_sup-Plan_polar",
    ],
)
# ---- GPU Prediction ----
# Cue words for the "visual"/specificity heuristic. They are matched as whole
# words; substring matching made "now" fire on "know" and "free" on "carefree".
_URGENCY_WORDS = frozenset({
    "now", "shock", "destroy", "change", "secret",
    "never", "always", "must", "urgent", "breaking", "exclusive", "free",
    "fastest", "cheapest", "worst", "best", "insane", "crazy",
})
# First/second-person pronouns for the "default mode" heuristic.
_PERSONAL_WORDS = frozenset({"i", "me", "my", "you", "your", "we", "our"})


def _sig(v, c=0.3, s=8.0):
    """Map a raw feature to a 0-100 score with a logistic curve.

    ``v`` is clamped to [0, 1] first; ``c`` is the curve midpoint and ``s``
    its steepness.
    """
    return float(100.0 / (1.0 + np.exp(-s * (max(0, min(1, v)) - c))))


# @spaces.GPU # CPU mode
def _predict(text):
    """Score ``text`` with Phi-2 features plus lexical heuristics.

    Raw features:
      1. perplexity of the text under the LLM      -> attention
      2. unique-token ratio                        -> language
      3. coefficient of variation of the last-layer
         hidden-state norms across tokens          -> emotion
      4. digit/uppercase ratios + urgency cue words -> visual
      5. personal-pronoun ratio                    -> default mode
    Each raw feature goes through a logistic curve (``_sig``) to a 0-100 score.

    Returns:
        dict of rounded scores, plus ``"_raw"`` holding the raw features.

    Raises:
        ValueError: if the text tokenizes to fewer than 2 tokens. Perplexity
            needs at least one (context, next-token) pair. Previously such
            input produced NaN perplexity and then crashed in ``np.linalg.norm``
            with an axis error.
    """
    import re

    m = ensure_model()
    tok = m["tokenizer"]
    llm = m["model"].float()  # CPU mode (no-op if weights are already float32)
    inputs = tok(text, return_tensors="pt", truncation=True, max_length=512).to("cpu")
    input_ids = inputs["input_ids"]
    if input_ids.shape[1] < 2:
        raise ValueError("Text is too short to score (need at least 2 tokens).")
    with torch.inference_mode():
        # Request hidden states explicitly rather than relying only on the
        # load-time config flag.
        outputs = llm(**inputs, output_hidden_states=True)
    logits = outputs.logits
    hidden = outputs.hidden_states[-1]

    # 1. Perplexity -> Attention: mean next-token cross-entropy, exponentiated.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = input_ids[:, 1:].contiguous()
    losses = torch.nn.CrossEntropyLoss(reduction="none")(
        shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    perplexity = float(torch.exp(losses.mean()).cpu())
    attention_raw = min(perplexity / 30.0, 1.0)

    # 2. Token diversity -> Language
    ids = input_ids[0].cpu().tolist()
    language_raw = len(set(ids)) / max(len(ids), 1)

    # 3. Hidden-state norm variation -> Emotion. Index the single batch item
    # instead of squeeze(), which would also drop the sequence axis.
    hn = hidden[0].cpu().float().numpy()
    norms = np.linalg.norm(hn, axis=1)
    emotion_raw = float(np.std(norms) / (np.mean(norms) + 1e-8))

    # 4. Specificity markers -> Visual
    tl = text.lower()
    # Word tokens with punctuation stripped, so "you?" counts as "you".
    words = re.findall(r"[\w']+", tl)
    n_chars = max(len(text), 1)
    nums = sum(c.isdigit() for c in text) / n_chars
    caps = sum(c.isupper() for c in text) / n_chars
    urgency = len(_URGENCY_WORDS.intersection(words))  # each cue counted once
    visual_raw = min(nums * 10 + caps * 5 + urgency * 0.15, 1.0)

    # 5. Personal references -> Default mode
    personal = sum(1 for w in words if w in _PERSONAL_WORDS)
    dm_raw = min(personal / max(len(words), 1) * 5, 1.0)

    att = _sig(attention_raw, 0.25, 6.0)
    emo = _sig(emotion_raw, 0.15, 10.0)
    lang = _sig(language_raw, 0.5, 8.0)
    vis = _sig(visual_raw, 0.2, 8.0)
    dm = _sig(dm_raw, 0.2, 6.0)
    overall = (att + emo + lang + vis + dm) / 5.0
    viral = att * 0.4 + emo * 0.4 + vis * 0.2
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return {
        "overall_brain_engagement": round(overall, 1),
        "viral_potential": round(viral, 1),
        "attention_capture": round(att, 1),
        "emotional_valence": round(emo, 1),
        "language_processing": round(lang, 1),
        "visual_imagery": round(vis, 1),
        "hook_effectiveness": round(att, 1),
        "retention_prediction": round(min(lang / max(att, 1) * 100, 100), 1),
        "_raw": {
            "perplexity": round(perplexity, 2),
            "token_diversity": round(language_raw, 3),
            "hidden_variance": round(emotion_raw, 4),
            "specificity": round(visual_raw, 3),
            "personal_ref": round(dm_raw, 3),
        },
    }
# ---- Visualization ----
def _radar(scores, title="Brain Engagement"):
    """Render five headline scores as a dark-themed polar (radar) chart.

    Args:
        scores: score dict as produced by ``_predict``. Reads the attention,
            emotion, language, visual and viral keys.
        title: chart title.

    Returns:
        io.BytesIO holding the PNG image, rewound to position 0.
    """
    import matplotlib
    matplotlib.use("Agg")  # headless backend; no display on the server
    import matplotlib.pyplot as plt

    background = "#0D1B2A"
    accent = "#FFD166"
    spokes = [
        ("Attention", "attention_capture"),
        ("Emotion", "emotional_valence"),
        ("Language", "language_processing"),
        ("Visual", "visual_imagery"),
        ("Viral", "viral_potential"),
    ]
    labels = [label for label, _ in spokes]
    radii = [scores[key] for _, key in spokes]
    theta = [i / 5.0 * 2 * np.pi for i in range(len(spokes))]
    # Repeat the first point so the plotted outline closes on itself.
    radii.append(radii[0])
    theta.append(0)

    fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
    fig.patch.set_facecolor(background)
    ax.set_facecolor(background)
    ax.plot(theta, radii, "o-", linewidth=2, color=accent)
    ax.fill(theta, radii, alpha=0.25, color=accent)
    ax.set_ylim(0, 100)
    ax.set_xticks(theta[:-1])
    ax.set_xticklabels(labels, size=11, color="white")
    ax.set_yticks([25, 50, 75])
    ax.set_yticklabels(["25", "50", "75"], size=8, color="grey")
    ax.tick_params(colors="grey")
    ax.spines["polar"].set_color("grey")
    ax.grid(color="grey", alpha=0.3)
    ax.set_title(title, size=14, color="white", pad=20)

    png = io.BytesIO()
    fig.savefig(png, format="png", bbox_inches="tight", facecolor=background, dpi=100)
    plt.close(fig)
    png.seek(0)
    return png
| def _fmt(s): | |
| return "\n".join([ | |
| f"🎯 Overall: {s['overall_brain_engagement']}/100", | |
| f"⚡ Viral: {s['viral_potential']}/100", | |
| f"🧠 Attention: {s['attention_capture']}/100", | |
| f"❤️ Emotion: {s['emotional_valence']}/100", | |
| f"💬 Language: {s['language_processing']}/100", | |
| f"👁️ Visual: {s['visual_imagery']}/100", | |
| f"🎣 Hook: {s['hook_effectiveness']}/100", | |
| f"📈 Retention: {s['retention_prediction']}/100", | |
| ]) | |
| def _insight(s): | |
| o = s["overall_brain_engagement"] | |
| p = [] | |
| p.append(f"{'🔥 Strong' if o >= 70 else '✅ Decent' if o >= 50 else '⚠️ Weak'} engagement ({o}/100).") | |
| if s["attention_capture"] >= 70: p.append("Great hook.") | |
| elif s["attention_capture"] < 40: p.append("Needs stronger opening.") | |
| if s["emotional_valence"] >= 70: p.append("Strong emotion.") | |
| elif s["emotional_valence"] < 40: p.append("Add urgency or stakes.") | |
| if s["hook_effectiveness"] >= 70 and s["retention_prediction"] < 50: | |
| p.append("Hook is good but middle drops off.") | |
| return " ".join(p) | |
# ---- Handlers ----
# Cached Whisper model; loading it on every request was the slowest step.
_whisper_model = None


def _get_whisper():
    """Return the Whisper "base" model on CPU, loading it once on first use."""
    global _whisper_model
    if _whisper_model is None:
        import whisper
        _whisper_model = whisper.load_model("base", device="cpu")
    return _whisper_model


# @spaces.GPU # CPU mode
def _transcribe_and_score(video_path):
    """Extract audio, transcribe with Whisper, then score with Phi-2.

    The audio is written to a private temporary directory. A fixed
    ``audio_extract.wav`` next to the upload would collide between concurrent
    requests, and it leaked whenever transcription raised.

    Args:
        video_path: filesystem path of the uploaded video.

    Returns:
        (transcript, scores) where ``scores`` is ``_predict``'s dict without
        its ``"_raw"`` entry.

    Raises:
        RuntimeError: if ffmpeg cannot extract the audio track.
        ValueError: if the transcript is empty.
        subprocess.TimeoutExpired: if ffmpeg runs longer than 60 seconds.
    """
    import subprocess
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        audio_path = os.path.join(tmp_dir, "audio_extract.wav")
        proc = subprocess.run(
            ["ffmpeg", "-y", "-i", video_path, "-vn", "-acodec", "pcm_s16le",
             "-ar", "16000", "-ac", "1", audio_path],
            stdin=subprocess.DEVNULL,  # never block on an interactive prompt
            capture_output=True, timeout=60)
        if proc.returncode != 0:
            stderr = proc.stderr.decode("utf-8", errors="replace").strip()
            raise RuntimeError(f"ffmpeg failed to extract audio: {stderr[-500:]}")
        # fp16=False: running on CPU, where Whisper would fall back to fp32 anyway.
        result = _get_whisper().transcribe(audio_path, fp16=False)
    transcript = result["text"]
    if not transcript or not transcript.strip():
        raise ValueError("No speech detected in video")
    # Reuse the text scorer instead of a duplicated copy of its pipeline.
    scores = _predict(transcript)
    scores.pop("_raw", None)
    return transcript, scores
def score_video_safe(video):
    """Gradio handler for the Video tab.

    Returns:
        (scores_text, insight_text). Errors are reported in ``scores_text``
        together with a traceback instead of being raised.
    """
    if video is None:
        return "Upload a video.", ""
    try:
        transcript, scores = _transcribe_and_score(video)
        ellipsis = "..." if len(transcript) > 300 else ""
        summary = f"Transcript:\n{transcript[:300]}{ellipsis}\n\n{_fmt(scores)}"
        return summary, _insight(scores)
    except Exception as err:
        import traceback
        return f"Error: {err}\n{traceback.format_exc()}", ""
def score_text_with_chart(text):
    """Score text and also render the radar chart.

    Returns:
        (scores_text, chart_png_buffer_or_None, insight_text). On error the
        chart slot is ``None`` and the message (with traceback) is in
        ``scores_text``.
    """
    if not (text and text.strip()):
        return "Enter text.", None, ""
    try:
        scores = _predict(text.strip())
        return _fmt(scores), _radar(scores), _insight(scores)
    except Exception as err:
        import traceback
        return f"Error: {err}\n{traceback.format_exc()}", None, ""
def score_text_safe(text):
    """Gradio handler for the Text tab.

    Returns:
        (scores_text, insight_text). Blank input yields a prompt. Errors are
        reported, with a traceback, in ``scores_text``.
    """
    cleaned = text.strip() if text else ""
    if not cleaned:
        return "Enter text.", ""
    try:
        scores = _predict(cleaned)
        return _fmt(scores), _insight(scores)
    except Exception as err:
        import traceback
        return f"Error: {err}\n{traceback.format_exc()}", ""
def ab_test_safe(a, b):
    """Gradio handler for the A/B tab: compare two texts by viral potential.

    Both inputs must contain non-whitespace text. Previously a whitespace-only
    version slipped past ``not a`` and was scored as an empty string, unlike
    the other handlers, which check ``.strip()``.

    Returns:
        A single string: the verdict line followed by each version's scores
        and insight, or an ``"Error: ..."`` message.
    """
    if not a or not a.strip() or not b or not b.strip():
        return "Enter both versions."
    try:
        sa, sb = _predict(a.strip()), _predict(b.strip())
        va, vb = sa["viral_potential"], sb["viral_potential"]
        if va > vb:
            verdict = f"🏆 A wins ({va} vs {vb})"
        elif vb > va:
            verdict = f"🏆 B wins ({vb} vs {va})"
        else:
            verdict = "🤝 Tie"
        return (f"{verdict}\n\n--- Version A ---\n{_fmt(sa)}\n{_insight(sa)}"
                f"\n\n--- Version B ---\n{_fmt(sb)}\n{_insight(sb)}")
    except Exception as e:
        return f"Error: {e}"
def api_json(text):
    """Gradio handler for the API tab: return scores as a JSON string.

    Output shape: ``{"scores": {...}, "raw": {...}}``. ``"raw"`` holds the
    raw features that ``_predict`` returns under ``"_raw"``, and they are
    removed from ``"scores"``. Blank or whitespace-only input yields
    ``{"error":"No text"}``. Failures yield ``{"error": "<message>"}``.
    """
    if not text or not text.strip():
        return '{"error":"No text"}'
    try:
        scores = _predict(text.strip())
        # Pop explicitly before building the payload, instead of relying on
        # dict-display evaluation order to strip "_raw" from "scores".
        raw = scores.pop("_raw", {})
        return json.dumps({"scores": scores, "raw": raw}, indent=2)
    except Exception as e:
        return json.dumps({"error": str(e)})
# ---- UI ----
# Build the Gradio interface at import time. Each tab wires one button to one
# handler defined above; ``api_name`` also exposes that handler as a named
# endpoint for programmatic clients.
# NOTE(review): score_text_with_chart / _radar are not wired into any tab.
with gr.Blocks(title="TRIBE V2 Brain Prediction", theme=gr.themes.Base(
    primary_hue="amber", secondary_hue="cyan", neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
)) as demo:
    gr.Markdown("# 🧠 TRIBE V2 — Brain Response Prediction\n"
                "Neuroscience-informed engagement scoring for your content.\n")
    # Text tab: one text box -> score_text_safe -> (scores, insight).
    with gr.Tab("📝 Text"):
        t_in = gr.Textbox(label="Content", lines=5, placeholder="Paste script or hook...")
        t_btn = gr.Button("🧠 Analyze", variant="primary")
        t_out = gr.Textbox(label="Scores", lines=10)
        t_ins = gr.Textbox(label="💡 Insight")
        t_btn.click(score_text_safe, [t_in], [t_out, t_ins], api_name="predict")
    # Video tab: upload -> score_video_safe (ffmpeg + Whisper + Phi-2).
    with gr.Tab("🎬 Video"):
        gr.Markdown("Upload a video — audio is transcribed and scored. ~45-90s on CPU (no quota limit).")
        v_in = gr.Video(label="Upload Video")
        v_btn = gr.Button("🧠 Analyze Video", variant="primary")
        v_out = gr.Textbox(label="Scores", lines=12)
        v_ins = gr.Textbox(label="💡 Insight")
        v_btn.click(score_video_safe, [v_in], [v_out, v_ins], api_name="predict_video")
    # A/B tab: two texts side by side -> ab_test_safe -> one result string.
    with gr.Tab("⚔️ A/B Test"):
        with gr.Row():
            a_in = gr.Textbox(label="Version A", lines=3)
            b_in = gr.Textbox(label="Version B", lines=3)
        ab_btn = gr.Button("⚔️ Compare", variant="primary")
        ab_out = gr.Textbox(label="Result", lines=12)
        ab_btn.click(ab_test_safe, [a_in, b_in], [ab_out], api_name="ab_test")
    # API tab: text -> api_json -> JSON string shown in a text box.
    with gr.Tab("🔌 API"):
        gr.Markdown("Returns JSON for programmatic use.")
        api_in = gr.Textbox(label="Text", lines=3)
        api_btn = gr.Button("Get JSON")
        api_out = gr.Textbox(label="JSON", lines=15)
        api_btn.click(api_json, [api_in], [api_out], api_name="api_predict")
    gr.Markdown("---\n*Powered by [Meta TRIBE V2](https://github.com/facebookresearch/tribev2) methodology | "
                "CPU Basic (unlimited) | somebeast*")
# Enable request queuing and start the server; this runs whenever the module
# is executed or imported (there is no __main__ guard).
demo.queue().launch()