# tribe-v2-cpu / app.py
# somebeast's picture
# Upload app.py with huggingface_hub
# 5470a27 verified
"""TRIBE V2 — Brain Response Prediction (Meta)
Predicts brain engagement using LLM-based text analysis with neuroscience-informed
scoring. Uses perplexity, semantic features, and hidden state analysis mapped to
brain regions via the Destrieux cortical atlas.
Running on CPU (unlimited, no quota).
"""
import gradio as gr
# import spaces # CPU mode
import torch
import numpy as np
import os
import json
import io
# ---- Model ----
# Process-wide cache: None until ensure_model() has loaded the bundle once.
model = None


def ensure_model():
    """Return the tokenizer/model bundle, loading it on the first call.

    The bundle is stored in the module-level ``model`` variable so later
    calls return the same objects without loading again.

    Returns:
        dict: ``{"tokenizer": <tokenizer>, "model": <causal LM>}`` for
        ``microsoft/phi-2``. The LM is loaded with ``torch.float16`` weights
        and ``output_hidden_states=True``.
    """
    global model
    if model is None:
        # Imported here so `transformers` is only needed once a model is
        # actually requested.
        from transformers import AutoModelForCausalLM, AutoTokenizer

        model_id = "microsoft/phi-2"
        print(f"Loading {model_id}...")
        tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
        causal_lm = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.float16,
            output_hidden_states=True,
            trust_remote_code=True,
        )
        # Publish the cache only after both loads succeeded.
        model = {"tokenizer": tokenizer, "model": causal_lm}
        print("Model loaded.")
    return model


print("TRIBE V2 ready.")
# ---- ROI Mapping (Destrieux Atlas) ----
# Maps each engagement dimension name to a list of cortical label strings.
# NOTE(review): no code visible in this file reads REGIONS — confirm whether
# it is consumed elsewhere or is documentation only.
REGIONS = {
    "attention": [
        "S_intrapariet",
        "G_front_middle",
        "S_front_sup",
        "G_pariet_inf-Supramar",
        "G_temp_sup-G_T_transv",
    ],
    "emotion": [
        "G_insular",
        "S_circular_insula",
        "G_cingul",
        "G_front_inf-Orbital",
        "G_rectus",
        "G_subcallosal",
    ],
    "language": [
        "G_front_inf-Opercular",
        "G_front_inf-Triangul",
        "G_temp_sup-Lateral",
        "G_temp_sup-Plan_tempo",
    ],
    "visual": [
        "G_occipital",
        "S_occipital",
        "G_cuneus",
        "S_calcarine",
        "Pole_occipital",
        "G_oc-temp_lat-fusifor",
    ],
    "default_mode": [
        "G_front_sup",
        "G_precuneus",
        "G_cingul-Post",
        "G_temp_sup-Plan_polar",
    ],
}
# ---- Text Prediction (CPU) ----
# Substrings whose presence in the lower-cased text counts as one urgency hit.
# Matching is by substring, so "now" also matches inside e.g. "know".
_URGENCY_MARKERS = (
    "now", "shock", "destroy", "change", "secret",
    "never", "always", "must", "urgent", "breaking", "exclusive", "free",
    "fastest", "cheapest", "worst", "best", "insane", "crazy",
)
# Whitespace-delimited, lower-cased words counted as personal references.
# Punctuation is not stripped, so "you," does not match.
_PERSONAL_WORDS = frozenset({"i", "me", "my", "you", "your", "we", "our"})


def _sigmoid_score(value, center=0.3, steepness=8.0):
    """Map a raw feature onto a 0-100 logistic scale.

    Args:
        value: raw feature; clamped to [0, 1] before scaling.
        center: clamped value that maps to a score of 50.
        steepness: slope of the logistic curve.

    Returns:
        float in the open interval (0, 100).
    """
    clamped = max(0, min(1, value))
    return float(100.0 / (1.0 + np.exp(-steepness * (clamped - center))))


# @spaces.GPU # CPU mode
def _predict(text):
    """Score a piece of text on several engagement heuristics.

    The text is run through the language model once (truncated to 512 tokens)
    and five raw features are derived:

    * attention  - next-token perplexity divided by 30, capped at 1
    * language   - unique token ids / total token ids
    * emotion    - coefficient of variation of the per-token L2 norms of the
                   last hidden layer
    * visual     - digit ratio, upper-case ratio and urgency-marker hits
    * default mode - share of personal-reference words

    Each raw feature is passed through a logistic curve to a 0-100 score.

    Args:
        text: the text to score.

    Returns:
        dict with the rounded scores ("overall_brain_engagement",
        "viral_potential", "attention_capture", "emotional_valence",
        "language_processing", "visual_imagery", "hook_effectiveness",
        "retention_prediction") and a "_raw" sub-dict with the underlying
        feature values.

    Raises:
        ValueError: if the tokenizer produces no tokens for ``text``.
    """
    m = ensure_model()
    tok = m["tokenizer"]
    llm = m["model"].float()  # CPU mode: run the forward pass in float32
    inputs = tok(text, return_tensors="pt", truncation=True, max_length=512).to("cpu")
    input_ids = inputs["input_ids"]
    n_tokens = input_ids.shape[-1]
    if n_tokens == 0:
        raise ValueError("Text produced no tokens to score")

    with torch.inference_mode():
        outputs = llm(**inputs)
        logits = outputs.logits
        hidden = outputs.hidden_states[-1]

        # 1. Perplexity -> Attention
        if n_tokens > 1:
            shift_logits = logits[:, :-1, :].contiguous()
            shift_labels = input_ids[:, 1:].contiguous()
            losses = torch.nn.CrossEntropyLoss(reduction="none")(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
            perplexity = float(torch.exp(losses.mean()).cpu())
        else:
            # A single token has no next-token target; the mean of an empty
            # loss tensor would be NaN. Use exp(0) = 1.0, the minimum
            # perplexity, so the value stays finite and JSON-serialisable.
            perplexity = 1.0

        # Index the batch dimension explicitly: squeeze() would also drop a
        # length-1 sequence dimension and break the axis=1 norm below.
        token_states = hidden[0].cpu().float().numpy()
    attention_raw = min(perplexity / 30.0, 1.0)

    # 2. Token diversity -> Language
    ids = input_ids[0].cpu().tolist()
    language_raw = len(set(ids)) / max(len(ids), 1)

    # 3. Hidden state variance -> Emotion
    norms = np.linalg.norm(token_states, axis=1)
    emotion_raw = float(np.std(norms) / (np.mean(norms) + 1e-8))

    # 4. Specificity markers -> Visual
    tl = text.lower()
    n_chars = max(len(text), 1)
    nums = sum(c.isdigit() for c in text) / n_chars
    caps = sum(c.isupper() for c in text) / n_chars
    urgency = sum(1 for marker in _URGENCY_MARKERS if marker in tl)
    visual_raw = min(nums * 10 + caps * 5 + urgency * 0.15, 1.0)

    # 5. Personal references -> Default mode
    words = tl.split()
    personal = sum(1 for w in words if w in _PERSONAL_WORDS)
    dm_raw = min(personal / max(len(words), 1) * 5, 1.0)

    att = _sigmoid_score(attention_raw, 0.25, 6.0)
    emo = _sigmoid_score(emotion_raw, 0.15, 10.0)
    lang = _sigmoid_score(language_raw, 0.5, 8.0)
    vis = _sigmoid_score(visual_raw, 0.2, 8.0)
    dm = _sigmoid_score(dm_raw, 0.2, 6.0)
    overall = (att + emo + lang + vis + dm) / 5.0
    viral = att * 0.4 + emo * 0.4 + vis * 0.2

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    return {
        "overall_brain_engagement": round(overall, 1),
        "viral_potential": round(viral, 1),
        "attention_capture": round(att, 1),
        "emotional_valence": round(emo, 1),
        "language_processing": round(lang, 1),
        "visual_imagery": round(vis, 1),
        # Reported under its own key but identical to attention_capture.
        "hook_effectiveness": round(att, 1),
        "retention_prediction": round(min(lang / max(att, 1) * 100, 100), 1),
        "_raw": {
            "perplexity": round(perplexity, 2),
            "token_diversity": round(language_raw, 3),
            "hidden_variance": round(emotion_raw, 4),
            "specificity": round(visual_raw, 3),
            "personal_ref": round(dm_raw, 3),
        },
    }
# ---- Visualization ----
def _radar(scores, title="Brain Engagement"):
    """Draw five headline scores as a radar chart and return it as PNG data.

    Args:
        scores: mapping that provides "attention_capture",
            "emotional_valence", "language_processing", "visual_imagery"
            and "viral_potential".
        title: text placed above the chart.

    Returns:
        io.BytesIO holding the PNG image, rewound to offset 0.
    """
    import matplotlib
    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    background = "#0D1B2A"
    accent = "#FFD166"
    spokes = (
        ("Attention", "attention_capture"),
        ("Emotion", "emotional_valence"),
        ("Language", "language_processing"),
        ("Visual", "visual_imagery"),
        ("Viral", "viral_potential"),
    )
    labels = [label for label, _ in spokes]
    values = [scores[key] for _, key in spokes]
    spoke_angles = [n / 5.0 * 2 * np.pi for n in range(5)]

    # Repeat the first point at the end so the outline closes on itself.
    outline_values = values + values[:1]
    outline_angles = spoke_angles + [0]

    fig, ax = plt.subplots(figsize=(5, 5), subplot_kw=dict(polar=True))
    fig.patch.set_facecolor(background)
    ax.set_facecolor(background)

    ax.plot(outline_angles, outline_values, "o-", linewidth=2, color=accent)
    ax.fill(outline_angles, outline_values, alpha=0.25, color=accent)
    ax.set_ylim(0, 100)

    # The order of the tick/label calls below is kept as-is on purpose:
    # tick_params is applied after the label colours have been set.
    ax.set_xticks(spoke_angles)
    ax.set_xticklabels(labels, size=11, color="white")
    ax.set_yticks([25, 50, 75])
    ax.set_yticklabels(["25", "50", "75"], size=8, color="grey")
    ax.tick_params(colors="grey")
    ax.spines["polar"].set_color("grey")
    ax.grid(color="grey", alpha=0.3)
    ax.set_title(title, size=14, color="white", pad=20)

    png = io.BytesIO()
    fig.savefig(png, format="png", bbox_inches="tight", facecolor=background, dpi=100)
    plt.close(fig)  # release the figure so repeated calls do not accumulate
    png.seek(0)
    return png
def _fmt(s):
    """Format a score dict as one "<label>: <value>/100" line per metric.

    Args:
        s: mapping containing all eight score keys listed below.

    Returns:
        str: eight newline-separated lines, in a fixed order.
    """
    rows = (
        ("🎯 Overall", "overall_brain_engagement"),
        ("⚡ Viral", "viral_potential"),
        ("🧠 Attention", "attention_capture"),
        ("❤️ Emotion", "emotional_valence"),
        ("💬 Language", "language_processing"),
        ("👁️ Visual", "visual_imagery"),
        ("🎣 Hook", "hook_effectiveness"),
        ("📈 Retention", "retention_prediction"),
    )
    return "\n".join(f"{label}: {s[key]}/100" for label, key in rows)
def _insight(s):
    """Build a short, space-joined commentary from a score dict.

    Args:
        s: mapping with "overall_brain_engagement", "attention_capture",
            "emotional_valence" and "hook_effectiveness";
            "retention_prediction" is read only when the hook score is
            at least 70.

    Returns:
        str: a verdict sentence followed by zero or more hints.
    """
    overall = s["overall_brain_engagement"]
    if overall >= 70:
        verdict = "🔥 Strong"
    elif overall >= 50:
        verdict = "✅ Decent"
    else:
        verdict = "⚠️ Weak"
    notes = [f"{verdict} engagement ({overall}/100)."]

    attention = s["attention_capture"]
    if attention >= 70:
        notes.append("Great hook.")
    elif attention < 40:
        notes.append("Needs stronger opening.")

    emotion = s["emotional_valence"]
    if emotion >= 70:
        notes.append("Strong emotion.")
    elif emotion < 40:
        notes.append("Add urgency or stakes.")

    # Short-circuit keeps the retention lookup conditional on a strong hook.
    if s["hook_effectiveness"] >= 70 and s["retention_prediction"] < 50:
        notes.append("Hook is good but middle drops off.")

    return " ".join(notes)
# ---- Handlers ----
# @spaces.GPU # CPU mode
def _transcribe_and_score(video_path):
    """Extract audio, transcribe it with Whisper, then score the transcript.

    Args:
        video_path: path to the uploaded video file.

    Returns:
        tuple: ``(transcript, scores)`` where ``scores`` is the dict produced
        by ``_predict`` for the transcript, without its "_raw" entry.

    Raises:
        RuntimeError: if ffmpeg exits with an error or writes no audio file.
        subprocess.TimeoutExpired: if ffmpeg runs longer than 60 seconds.
        ValueError: if the transcript is empty or whitespace only.
    """
    import subprocess
    import tempfile

    import whisper

    # A private temporary directory gives every request its own WAV file (a
    # fixed name beside the upload would collide across concurrent requests)
    # and guarantees cleanup even when transcription raises.
    with tempfile.TemporaryDirectory() as workdir:
        audio_path = os.path.join(workdir, "audio_extract.wav")
        proc = subprocess.run(
            ["ffmpeg", "-i", video_path, "-vn", "-acodec", "pcm_s16le",
             "-ar", "16000", "-ac", "1", audio_path, "-y"],
            capture_output=True, timeout=60)
        if proc.returncode != 0 or not os.path.exists(audio_path):
            # Surface the tail of ffmpeg's stderr instead of letting Whisper
            # fail later on a missing file.
            detail = proc.stderr.decode("utf-8", errors="replace").strip()[-500:]
            raise RuntimeError(f"ffmpeg failed to extract audio: {detail}")

        whisper_model = whisper.load_model("base", device="cpu")
        result = whisper_model.transcribe(audio_path)

    transcript = result["text"]
    if not transcript or not transcript.strip():
        raise ValueError("No speech detected in video")

    # Same scoring as the text path; the video result omits the raw features.
    scores = _predict(transcript)
    scores.pop("_raw", None)
    return transcript, scores
def score_video_safe(video):
    """UI handler for the video tab; never raises.

    Args:
        video: path of the uploaded video, or None when nothing was uploaded.

    Returns:
        tuple: ``(scores_text, insight_text)``. On failure the first element
        carries the error message and traceback and the second is empty.
    """
    if video is None:
        return "Upload a video.", ""
    try:
        transcript, scores = _transcribe_and_score(video)
        # Show at most the first 300 characters of the transcript.
        preview = transcript[:300]
        if len(transcript) > 300:
            preview += "..."
        report = f"Transcript:\n{preview}\n\n{_fmt(scores)}"
        return report, _insight(scores)
    except Exception as e:
        import traceback
        return f"Error: {e}\n{traceback.format_exc()}", ""
def score_text_with_chart(text):
    """Score text and also render the radar chart; never raises.

    Args:
        text: the content to score.

    Returns:
        tuple: ``(scores_text, chart, insight_text)``. ``chart`` is the value
        returned by ``_radar``, or None for empty input and on failure.
    """
    if not (text and text.strip()):
        return "Enter text.", None, ""
    try:
        scores = _predict(text.strip())
        report = _fmt(scores)
        chart = _radar(scores)
        return report, chart, _insight(scores)
    except Exception as e:
        import traceback
        return f"Error: {e}\n{traceback.format_exc()}", None, ""
def score_text_safe(text):
    """UI handler for the text tab; never raises.

    Args:
        text: the content to score.

    Returns:
        tuple: ``(scores_text, insight_text)``. On failure the first element
        carries the error message and traceback and the second is empty.
    """
    if not (text and text.strip()):
        return "Enter text.", ""
    try:
        scores = _predict(text.strip())
        report = _fmt(scores)
        return report, _insight(scores)
    except Exception as e:
        import traceback
        return f"Error: {e}\n{traceback.format_exc()}", ""
def ab_test_safe(a, b):
    """Score two text versions and report which has the higher viral score.

    Args:
        a: text of version A.
        b: text of version B.

    Returns:
        str: the winner line followed by the formatted scores and insight for
        each version; "Enter both versions." when either input is missing or
        whitespace only; "Error: <message>" if scoring fails.
    """
    # Reject whitespace-only input as well: both texts are stripped before
    # scoring, so such input would otherwise reach the model as "".
    if not (a and a.strip() and b and b.strip()):
        return "Enter both versions."
    try:
        sa, sb = _predict(a.strip()), _predict(b.strip())
        va, vb = sa["viral_potential"], sb["viral_potential"]
        if va > vb:
            w = f"🏆 A wins ({va} vs {vb})"
        elif vb > va:
            w = f"🏆 B wins ({vb} vs {va})"
        else:
            w = "🤝 Tie"
        return f"{w}\n\n--- Version A ---\n{_fmt(sa)}\n{_insight(sa)}\n\n--- Version B ---\n{_fmt(sb)}\n{_insight(sb)}"
    except Exception as e:
        return f"Error: {e}"
def api_json(text):
    """Score text and return the result as a JSON string; never raises.

    Args:
        text: the content to score.

    Returns:
        str: ``{"scores": {...}, "raw": {...}}`` (indented) on success,
        ``{"error":"No text"}`` for missing or whitespace-only input, or
        ``{"error": "<message>"}`` if scoring fails.
    """
    # Whitespace-only input is treated as missing: the text is stripped before
    # scoring, so it would otherwise reach the model as "".
    if not text or not text.strip():
        return '{"error":"No text"}'
    try:
        scores = _predict(text.strip())
        # Split the raw features out before serialising so "scores" does not
        # also contain them.
        raw = scores.pop("_raw", {})
        return json.dumps({"scores": scores, "raw": raw}, indent=2)
    except Exception as e:
        return json.dumps({"error": str(e)})
# ---- UI ----
# Gradio app definition: four tabs, each wiring one button to one handler and
# exposing that handler under an explicit api_name.
with gr.Blocks(title="TRIBE V2 Brain Prediction", theme=gr.themes.Base(
    primary_hue="amber", secondary_hue="cyan", neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
)) as demo:
    gr.Markdown("# 🧠 TRIBE V2 — Brain Response Prediction\n"
                "Neuroscience-informed engagement scoring for your content.\n")
    # Text tab: score_text_safe returns (scores text, insight text).
    # score_text_with_chart is not wired to any component here.
    with gr.Tab("📝 Text"):
        t_in = gr.Textbox(label="Content", lines=5, placeholder="Paste script or hook...")
        t_btn = gr.Button("🧠 Analyze", variant="primary")
        t_out = gr.Textbox(label="Scores", lines=10)
        t_ins = gr.Textbox(label="💡 Insight")
        t_btn.click(score_text_safe, [t_in], [t_out, t_ins], api_name="predict")
    # Video tab: score_video_safe returns (transcript preview + scores, insight).
    with gr.Tab("🎬 Video"):
        gr.Markdown("Upload a video — audio is transcribed and scored. ~45-90s on CPU (no quota limit).")
        v_in = gr.Video(label="Upload Video")
        v_btn = gr.Button("🧠 Analyze Video", variant="primary")
        v_out = gr.Textbox(label="Scores", lines=12)
        v_ins = gr.Textbox(label="💡 Insight")
        v_btn.click(score_video_safe, [v_in], [v_out, v_ins], api_name="predict_video")
    # A/B tab: ab_test_safe returns a single report string.
    with gr.Tab("⚔️ A/B Test"):
        with gr.Row():
            a_in = gr.Textbox(label="Version A", lines=3)
            b_in = gr.Textbox(label="Version B", lines=3)
        ab_btn = gr.Button("⚔️ Compare", variant="primary")
        ab_out = gr.Textbox(label="Result", lines=12)
        ab_btn.click(ab_test_safe, [a_in, b_in], [ab_out], api_name="ab_test")
    # API tab: api_json returns a JSON string shown in a plain textbox.
    with gr.Tab("🔌 API"):
        gr.Markdown("Returns JSON for programmatic use.")
        api_in = gr.Textbox(label="Text", lines=3)
        api_btn = gr.Button("Get JSON")
        api_out = gr.Textbox(label="JSON", lines=15)
        api_btn.click(api_json, [api_in], [api_out], api_name="api_predict")
    gr.Markdown("---\n*Powered by [Meta TRIBE V2](https://github.com/facebookresearch/tribev2) methodology | "
                "CPU Basic (unlimited) | somebeast*")
# Runs at module level (no __main__ guard): executing this file starts the
# server with request queuing enabled.
demo.queue().launch()