# tribe2 / app.py
# Last change (commit 99b4437, by hmb, HF Staff):
#   Centre layout (900px max), add video upload background, centre tabs
try:
    import spaces
except ImportError:
    # Running outside a Hugging Face Space: provide a no-op stand-in so that
    # ``@spaces.GPU`` decorators in this module still work locally.
    class _Placeholder:
        """Minimal stand-in exposing a pass-through ``GPU`` decorator."""

        def GPU(self, fn=None, *a, **kw):
            """No-op replacement for ``spaces.GPU``.

            Supports both usages:
            - bare ``@spaces.GPU``: ``fn`` is the decorated function and is
              returned unchanged;
            - parameterized ``@spaces.GPU(duration=...)``: returns a
              decorator that returns its function unchanged.
            """
            if callable(fn):
                return fn

            def dec(inner):
                return inner

            return dec

    spaces = _Placeholder()
import gradio as gr
import numpy as np
import plotly.graph_objects as go
import threading
from tribev2 import TribeModel
# ── Load model (thread-safe, lazy) ──────────────────────────────────────────
_model = None  # TribeModel instance; created lazily by get_model()
_model_lock = threading.Lock()  # guards the one-time model load


def get_model():
    """Return the shared TRIBE v2 model, loading it on the first call.

    The check and the load both run while holding ``_model_lock``, so
    concurrent first callers cannot each construct their own copy. Every
    later call returns the cached instance.
    """
    global _model
    with _model_lock:
        if _model is not None:
            return _model
        _model = TribeModel.from_pretrained("facebook/tribev2", cache_folder="./cache")
        return _model
# fsaverage5 cortical region mapping (approximate vertex index ranges).
# Each zone owns the half-open vertex slice [start, end) of the model output.
# NOTE(review): these contiguous index ranges are a coarse heuristic. If the
# model orders vertices hemisphere by hemisphere (e.g. 10242 per side on
# fsaverage5), the ranges may split hemispheres rather than lobes. TODO:
# confirm against a surface atlas/parcellation before trusting zone labels.
ZONE_VERTEX_RANGES = dict(
    Frontal=(0, 3500),
    Action=(3500, 6500),
    Attention=(6500, 10000),
    Speech=(10000, 14500),
    Visual=(14500, 20484),
)

# Display metadata per zone, in the same order as ZONE_VERTEX_RANGES:
# a short description, an accent colour, and a two-digit index label.
ZONE_META = {
    zone: {"desc": desc, "color": color, "icon": icon}
    for zone, desc, color, icon in (
        ("Frontal", "Meaning, intent, narrative focus", "#D4A017", "01"),
        ("Action", "Movement, gestures, body dynamics", "#E8651A", "02"),
        ("Attention", "Gaze direction, spatial salience", "#22C55E", "03"),
        ("Speech", "Voice, faces, object recognition", "#EF4444", "04"),
        ("Visual", "Shape, contrast, color, motion", "#3B82F6", "05"),
    )
}
# ── Analysis using TRIBE v2 ────────────────────────────────────────────────
def normalize(arr):
    """Min-max rescale *arr* to the range [0, 100].

    Args:
        arr: Array-like of numbers (NumPy array, list, ...).

    Returns:
        A float ``np.ndarray`` with the same shape as *arr*: the minimum maps
        to 0 and the maximum to 100. A (near-)constant input, whose range is
        below 1e-9, maps to 50.0 everywhere because there is no variation to
        rescale. An empty input returns an empty float array.
    """
    # Coerce to float so both branches return the same dtype. On an int
    # array, np.full_like would otherwise produce ints.
    values = np.asarray(arr, dtype=float)
    if values.size == 0:
        return values
    lo, hi = values.min(), values.max()
    if hi - lo < 1e-9:
        return np.full_like(values, 50.0)
    return (values - lo) / (hi - lo) * 100
def _segment_timestamps(segments, n_steps):
    """Return ``n_steps`` start times in seconds, one per prediction row.

    Reads ``start_time`` from each element of *segments*. Falls back to a
    1-second grid (0.0, 1.0, 2.0, ...) in two cases: the elements do not
    expose a usable ``start_time``, or the count differs from *n_steps*.
    """
    try:
        starts = [float(seg.start_time) for seg in segments]
    except (AttributeError, TypeError, ValueError):
        starts = []
    if len(starts) != n_steps:
        # Keep the time axis the same length as the prediction rows.
        starts = [float(step) for step in range(n_steps)]
    return starts


@spaces.GPU
def analyze_video(video_path):
    """Run TRIBE v2 on a video and derive per-zone and composite signals.

    Args:
        video_path: Video file path handed over by the Gradio handlers.

    Returns:
        ``(timestamps, engagement, zones, duration)``:
        - timestamps: list of floats (seconds), one per prediction step.
        - engagement: weighted blend of the zone signals, min-max rescaled to
          0-100 within this video.
        - zones: dict zone name -> mean prediction over that zone's vertex
          range, min-max rescaled to 0-100 within this video.
        - duration: last start time plus one step width, in seconds.

    Raises:
        gr.Error: if the model yields fewer than two steps, or if the
            prediction has too few vertices to cover a zone's range.
    """
    model = get_model()
    events = model.get_events_dataframe(video_path=video_path)
    preds, segments = model.predict(events=events)
    n_steps = preds.shape[0]
    if n_steps < 2:
        raise gr.Error("Video too short to analyze.")

    # Read start_time per segment element. Checking for the attribute on the
    # container would miss element-level times, or crash when iterating a
    # container that happens to have such an attribute.
    timestamps = _segment_timestamps(segments, n_steps)
    # The last start time alone undercounts by one step. Add the final step's
    # width so N one-second steps report N seconds.
    duration = timestamps[-1] + max(timestamps[-1] - timestamps[-2], 0.0)

    n_vertices = preds.shape[1]
    zones = {}
    for zone_name, (start, end) in ZONE_VERTEX_RANGES.items():
        end = min(end, n_vertices)
        if start >= end:
            # An empty slice would average to NaN and poison every score.
            raise gr.Error(
                f"Model returned {n_vertices} vertices; too few for the {zone_name} zone."
            )
        zones[zone_name] = normalize(np.mean(preds[:, start:end], axis=1))

    # Composite weights (they sum to 1.0); the final normalize() rescales the blend.
    weights = {"Frontal": 0.25, "Action": 0.20, "Attention": 0.25, "Speech": 0.15, "Visual": 0.15}
    engagement = normalize(sum(weight * zones[name] for name, weight in weights.items()))
    return timestamps, engagement, zones, duration
# ── Chart ───────────────────────────────────────────────────────────────────
def make_engagement_chart(timestamps, engagement):
    """Build the single-video response chart.

    Draws *engagement* against *timestamps* as a gold line with a faint fill
    down to zero. The layout is transparent and dark, and the y-axis is
    fixed to 0-105.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    mono_ticks = dict(size=11, family="DM Mono, monospace")
    axis_style = dict(
        title="",
        showgrid=True,
        gridcolor="rgba(255,255,255,0.04)",
        zeroline=False,
        color="#666",
        tickfont=mono_ticks,
    )

    # The response curve itself, with a subtle area fill beneath it.
    response_trace = go.Scatter(
        x=timestamps,
        y=engagement,
        mode="lines",
        line=dict(color="#D4A017", width=2.5),
        fill="tozeroy",
        fillcolor="rgba(212,160,23,0.06)",
        hovertemplate="<b>%{x:.1f}s</b><br>Response: %{y:.0f}/100<extra></extra>",
    )

    figure = go.Figure(data=[response_trace])
    figure.update_layout(
        template="plotly_dark",
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        margin=dict(l=45, r=15, t=15, b=40),
        height=260,
        xaxis=axis_style,
        yaxis={**axis_style, "range": [0, 105]},
        hoverlabel=dict(
            bgcolor="#1a1a1a",
            bordercolor="#D4A017",
            font=dict(color="#fff", family="DM Mono, monospace", size=12),
        ),
    )
    return figure
# ── Feedback generation ─────────────────────────────────────────────────────
def generate_feedback(engagement, zones):
    """Score an engagement curve and derive editing recommendations.

    Args:
        engagement: 1-D sequence of composite response scores
            (``analyze_video`` produces them on a 0-100 scale).
        zones: Mapping of zone name -> 1-D sequence of that zone's scores.

    Returns:
        ``(zone_avgs, strength, strength_color, avg, mx, mn, hook_avg,
        end_avg, weak_pct, rec_lines)``:

        - zone_avgs: dict zone name -> mean score (float).
        - strength / strength_color: "STRONG" (avg >= 72), "AVERAGE"
          (avg >= 50) or "WEAK", with its display colour.
        - avg, mx, mn: mean, max and min of *engagement*.
        - hook_avg / end_avg: mean of the first / last tenth of the curve
          (at least one sample each).
        - weak_pct: percentage of samples below 65% of the mean ("dead air").
        - rec_lines: Markdown bullet list of recommendations.

    Raises:
        ValueError: if *engagement* is empty or *zones* is empty.
    """
    engagement = np.asarray(engagement, dtype=float)
    if engagement.size == 0:
        raise ValueError("generate_feedback() needs at least one engagement sample")

    avg = float(np.mean(engagement))
    mx = float(np.max(engagement))
    mn = float(np.min(engagement))
    # "Dead air": samples well below this video's own mean.
    weak_mask = engagement < (avg * 0.65)
    weak_pct = float(np.count_nonzero(weak_mask) / engagement.size * 100)

    # Opening and closing windows: a tenth of the curve, at least one sample.
    window = max(1, len(engagement) // 10)
    hook_avg = float(np.mean(engagement[:window]))
    end_avg = float(np.mean(engagement[-window:]))

    if avg >= 72:
        strength, strength_color = "STRONG", "#22C55E"
    elif avg >= 50:
        strength, strength_color = "AVERAGE", "#D4A017"
    else:
        strength, strength_color = "WEAK", "#EF4444"

    zone_avgs = {name: float(np.mean(sig)) for name, sig in zones.items()}
    weakest_zone = min(zone_avgs, key=zone_avgs.get)  # ValueError if zones is empty

    recs = []
    if hook_avg < 55:
        recs.append("**Weak hook** — first seconds don't grab. Lead with your strongest visual or a surprise.")
    if end_avg < 45:
        recs.append("**Flat ending** — viewers drop off. Add a payoff or cliffhanger in the final seconds.")
    if weak_pct > 30:
        recs.append(f"**{weak_pct:.0f}% dead air** — too many flat stretches. Cut or accelerate the lows.")
    zone_recs = {
        "Action": "More movement, gestures, or physical dynamics.",
        "Attention": "Vary shots — add cuts, zooms, angle changes.",
        "Speech": "More face time or recognizable elements.",
        "Visual": "Flat texture. Try contrast, color grading, composition.",
        "Frontal": "No clear thread. Give viewers something to follow.",
    }
    if zone_avgs[weakest_zone] < 55:
        recs.append(f"**{weakest_zone} zone is dragging** — {zone_recs.get(weakest_zone, '')}")
    if not recs:
        recs.append("Solid cut. Minor gains from tightening the middle third.")

    rec_lines = "\n".join(f"- {r}" for r in recs)
    return zone_avgs, strength, strength_color, avg, mx, mn, hook_avg, end_avg, weak_pct, rec_lines
# ── HTML builders ───────────────────────────────────────────────────────────
def make_stats_html(strength, strength_color, avg, mx, mn, hook_avg, end_avg, weak_pct):
    """Render the row of summary stat cards as an HTML snippet.

    The first card is the verdict, tinted with *strength_color*. Five
    numeric cards follow (Average, Peak, Floor, Hook, Dead air), each
    rounded to a whole number. *end_avg* is accepted for signature
    compatibility but is not displayed.

    Returns:
        An HTML string wrapped in ``<div class="stats-grid">``, with a
        leading and a trailing newline.
    """
    numeric_cards = (
        ("Average", f"{avg:.0f}"),
        ("Peak", f"{mx:.0f}"),
        ("Floor", f"{mn:.0f}"),
        ("Hook", f"{hook_avg:.0f}"),
        ("Dead air", f"{weak_pct:.0f}%"),
    )
    lines = [
        '<div class="stats-grid">',
        f'<div class="stat-card stat-verdict" style="border-color: {strength_color}40;">',
        '<div class="stat-label">Verdict</div>',
        f'<div class="stat-value" style="color: {strength_color};">{strength}</div>',
        '</div>',
    ]
    for label, shown in numeric_cards:
        lines.extend((
            '<div class="stat-card">',
            f'<div class="stat-label">{label}</div>',
            f'<div class="stat-value">{shown}</div>',
            '</div>',
        ))
    lines.append('</div>')
    return "\n" + "\n".join(lines) + "\n"
def make_zone_html(zone_avgs):
    """Render one card per brain zone as a vertical HTML stack.

    Cards follow ``ZONE_META`` order. Each card shows:
    - the zone's index label and name;
    - an activity label: HIGH (>= 70), MID (>= 40), otherwise LOW;
    - the zone description;
    - a bar whose width is the score as a percentage;
    - the rounded score.

    Args:
        zone_avgs: Mapping of zone name -> average score. Zones missing from
            the mapping are shown as 0.

    Returns:
        An HTML string wrapped in ``<div class="zone-stack">``.
    """
    def activity_label(score):
        if score >= 70:
            return "HIGH"
        if score >= 40:
            return "MID"
        return "LOW"

    rendered = []
    for zone_name, meta in ZONE_META.items():
        score = zone_avgs.get(zone_name, 0)
        tint = meta["color"]
        rendered.append("\n".join((
            "",
            '<div class="zone-card">',
            '<div class="zone-header">',
            f'<span class="zone-idx" style="color: {tint};">{meta["icon"]}</span>',
            f'<span class="zone-name">{zone_name}</span>',
            f'<span class="zone-activity" style="color: {tint};">{activity_label(score)}</span>',
            '</div>',
            f'<div class="zone-desc">{meta["desc"]}</div>',
            '<div class="zone-bar-track">',
            f'<div class="zone-bar-fill" style="width:{score:.0f}%;background:{tint};"></div>',
            '</div>',
            f'<div class="zone-val" style="color: {tint};">{score:.0f}</div>',
            '</div>',
        )))
    return '<div class="zone-stack">' + "".join(rendered) + "</div>"
# ── Main handlers ───────────────────────────────────────────────────────────
def process_video(video):
    """Gradio handler for the "Analyze" tab.

    Args:
        video: Value of the ``gr.Video`` input, passed to ``analyze_video``,
            or ``None`` when nothing has been uploaded.

    Returns:
        ``(chart, stats_html, zone_html, recs)``, in the order of the click
        outputs: a Plotly figure, two HTML snippets and a Markdown list.

    Raises:
        gr.Error: if no video was provided.
    """
    if video is None:
        raise gr.Error("Upload a video first.")
    timestamps, engagement, zones, _duration = analyze_video(video)  # duration not shown here
    figure = make_engagement_chart(timestamps, engagement)
    # The middle eight fields of the feedback tuple are exactly make_stats_html's parameters.
    zone_avgs, *stat_fields, recs = generate_feedback(engagement, zones)
    return (
        figure,
        make_stats_html(*stat_fields),
        make_zone_html(zone_avgs),
        recs,
    )
def _comparison_summary(eng_a, eng_b, dur_a, dur_b):
    """Return the Markdown stats table and verdict for an A/B comparison.

    A blank line separates the table from the verdict. Without it,
    GFM-style renderers treat the verdict line as one more table row.
    If the two averages round to the same whole number, a tie is reported
    instead of "wins by +0".
    """
    avg_a, avg_b = float(np.mean(eng_a)), float(np.mean(eng_b))
    table = "\n".join((
        "| | Video A | Video B |",
        "|--|---------|---------|",
        f"| **Average** | {avg_a:.0f} | {avg_b:.0f} |",
        f"| **Peak** | {np.max(eng_a):.0f} | {np.max(eng_b):.0f} |",
        f"| **Floor** | {np.min(eng_a):.0f} | {np.min(eng_b):.0f} |",
        f"| **Duration** | {dur_a:.1f}s | {dur_b:.1f}s |",
    ))
    diff = abs(avg_a - avg_b)
    if round(diff) == 0:
        verdict = "**Tie** — both videos average the same response."
    else:
        winner = "A" if avg_a > avg_b else "B"
        verdict = f"**Video {winner}** wins by **+{diff:.0f}** avg response."
    return f"{table}\n\n{verdict}\n"


def process_comparison(video_a, video_b):
    """Gradio handler for the "A/B Compare" tab.

    Analyzes both clips, overlays their engagement curves, and summarizes
    average / peak / floor / duration in a Markdown table.

    Note: ``analyze_video`` rescales each clip's engagement to 0-100 within
    that clip. The averages therefore compare each curve relative to its own
    range, not absolute response levels.

    Returns:
        ``(fig, summary)``: a Plotly figure and a Markdown string.

    Raises:
        gr.Error: if either video is missing.
    """
    if video_a is None or video_b is None:
        raise gr.Error("Upload both videos to compare.")
    ts_a, eng_a, _zones_a, dur_a = analyze_video(video_a)
    ts_b, eng_b, _zones_b, dur_b = analyze_video(video_b)

    fig = go.Figure()
    for label, ts, eng, color in (
        ("Video A", ts_a, eng_a, "#D4A017"),
        ("Video B", ts_b, eng_b, "#3B82F6"),
    ):
        fig.add_trace(go.Scatter(
            x=ts, y=eng, mode="lines", name=label,
            line=dict(color=color, width=2.5),
            # Same hover format as the single-video chart, tagged with the clip name.
            hovertemplate="<b>%{x:.1f}s</b><br>Response: %{y:.0f}/100<extra>" + label + "</extra>",
        ))
    fig.update_layout(
        template="plotly_dark",
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        margin=dict(l=45, r=15, t=15, b=40),
        height=300,
        xaxis=dict(showgrid=True, gridcolor="rgba(255,255,255,0.04)", zeroline=False, color="#666",
                   tickfont=dict(size=11, family="DM Mono, monospace")),
        yaxis=dict(showgrid=True, gridcolor="rgba(255,255,255,0.04)", zeroline=False, range=[0, 105], color="#666",
                   tickfont=dict(size=11, family="DM Mono, monospace")),
        legend=dict(font=dict(color="#999", family="DM Mono, monospace", size=11),
                    bgcolor="rgba(0,0,0,0)", orientation="h", y=1.08),
        hoverlabel=dict(bgcolor="#1a1a1a", bordercolor="#D4A017",
                        font=dict(color="#fff", family="DM Mono, monospace", size=12)),
    )
    return fig, _comparison_summary(eng_a, eng_b, dur_a, dur_b)
# ── Custom CSS ──────────────────────────────────────────────────────────────
CSS = """
@import url('https://fonts.googleapis.com/css2?family=DM+Mono:wght@300;400;500&family=Syne:wght@400;600;700;800&display=swap');
/* Global overrides */
.gradio-container {
max-width: 900px !important;
margin: 0 auto !important;
font-family: 'Syne', sans-serif !important;
background: #0A0A0B !important;
}
.dark .gradio-container { background: #0A0A0B !important; }
/* Video upload area */
.video-upload {
background: rgba(255,255,255,0.03) !important;
border: 1px solid rgba(255,255,255,0.08) !important;
border-radius: 14px !important;
overflow: hidden;
}
.video-upload .wrap, .video-upload video {
border-radius: 14px !important;
}
/* Header */
.hero-header {
text-align: center;
padding: 48px 20px 32px;
position: relative;
}
.hero-header::before {
content: '';
position: absolute;
top: 0; left: 50%;
transform: translateX(-50%);
width: 400px; height: 400px;
background: radial-gradient(circle, rgba(212,160,23,0.08) 0%, transparent 70%);
pointer-events: none;
}
.hero-title {
font-family: 'Syne', sans-serif;
font-size: 56px;
font-weight: 800;
letter-spacing: -2px;
color: #FAFAFA;
margin: 0;
line-height: 1;
}
.hero-title span { color: #D4A017; }
.hero-sub {
font-family: 'DM Mono', monospace;
font-size: 13px;
color: #666;
margin-top: 12px;
letter-spacing: 0.5px;
line-height: 1.6;
max-width: 600px;
margin-left: auto;
margin-right: auto;
}
/* Tabs */
.tab-nav { border: none !important; justify-content: center !important; }
.tab-nav button {
font-family: 'DM Mono', monospace !important;
font-size: 12px !important;
letter-spacing: 1px !important;
text-transform: uppercase !important;
color: #555 !important;
border: none !important;
background: transparent !important;
padding: 10px 20px !important;
}
.tab-nav button.selected {
color: #D4A017 !important;
border-bottom: 2px solid #D4A017 !important;
background: transparent !important;
}
/* Stat cards */
.stats-grid {
display: grid;
grid-template-columns: repeat(6, 1fr);
gap: 10px;
margin-top: 8px;
}
.stat-card {
background: rgba(255,255,255,0.03);
border: 1px solid rgba(255,255,255,0.06);
border-radius: 10px;
padding: 14px 16px;
text-align: center;
}
.stat-card.stat-verdict {
border-width: 2px;
}
.stat-label {
font-family: 'DM Mono', monospace;
font-size: 10px;
text-transform: uppercase;
letter-spacing: 1.5px;
color: #555;
margin-bottom: 6px;
}
.stat-value {
font-family: 'Syne', sans-serif;
font-size: 22px;
font-weight: 700;
color: #FAFAFA;
}
/* Zone cards */
.zone-stack {
display: flex;
flex-direction: column;
gap: 8px;
}
.zone-card {
background: rgba(255,255,255,0.02);
border: 1px solid rgba(255,255,255,0.06);
border-radius: 10px;
padding: 14px 16px;
}
.zone-header {
display: flex;
align-items: center;
gap: 10px;
margin-bottom: 4px;
}
.zone-idx {
font-family: 'DM Mono', monospace;
font-size: 11px;
font-weight: 500;
opacity: 0.7;
}
.zone-name {
font-family: 'Syne', sans-serif;
font-size: 14px;
font-weight: 600;
color: #E0E0E0;
flex: 1;
}
.zone-activity {
font-family: 'DM Mono', monospace;
font-size: 10px;
letter-spacing: 1px;
font-weight: 500;
}
.zone-desc {
font-family: 'DM Mono', monospace;
font-size: 11px;
color: #555;
margin-bottom: 10px;
line-height: 1.4;
}
.zone-bar-track {
height: 4px;
background: rgba(255,255,255,0.06);
border-radius: 2px;
overflow: hidden;
margin-bottom: 6px;
}
.zone-bar-fill {
height: 100%;
border-radius: 2px;
transition: width 0.8s cubic-bezier(0.22, 1, 0.36, 1);
}
.zone-val {
font-family: 'DM Mono', monospace;
font-size: 12px;
font-weight: 500;
text-align: right;
}
/* Plot container */
.plot-container {
border-radius: 12px;
overflow: hidden;
}
/* Blocks overrides */
.block { background: transparent !important; border: none !important; box-shadow: none !important; }
.label-wrap { display: none !important; }
.prose { color: #999 !important; }
.prose h3 { color: #E0E0E0 !important; font-family: 'Syne', sans-serif !important; }
.prose strong { color: #FAFAFA !important; }
.prose table { border-color: rgba(255,255,255,0.08) !important; }
.prose th, .prose td { border-color: rgba(255,255,255,0.08) !important; color: #999 !important; }
.prose th { color: #ccc !important; }
.prose a { color: #D4A017 !important; }
/* Video component */
.video-container { border-radius: 12px; overflow: hidden; }
/* Button */
button.primary {
background: #D4A017 !important;
border: none !important;
color: #0A0A0B !important;
font-family: 'Syne', sans-serif !important;
font-weight: 700 !important;
letter-spacing: 0.5px !important;
border-radius: 10px !important;
}
button.primary:hover {
background: #E8B12A !important;
}
/* Markdown feedback section */
.feedback-section .prose {
font-family: 'DM Mono', monospace !important;
font-size: 13px !important;
line-height: 1.7 !important;
}
/* Responsive */
@media (max-width: 768px) {
.stats-grid { grid-template-columns: repeat(3, 1fr); }
.hero-title { font-size: 36px; }
}
"""
# ── UI ──────────────────────────────────────────────────────────────────────
# Dark base theme. The overrides reuse the palette hard-coded in the CSS
# (#0A0A0B page background, #D4A017 gold accent).
theme = gr.themes.Base(
    neutral_hue=gr.themes.colors.zinc,
    primary_hue=gr.themes.colors.amber,
)
theme = theme.set(
    # Page and block surfaces
    body_background_fill="#0A0A0B",
    body_text_color="#999999",
    block_background_fill="transparent",
    block_border_width="0px",
    block_shadow="none",
    # Form inputs
    input_background_fill="#111113",
    input_border_color="rgba(255,255,255,0.08)",
    input_border_width="1px",
    # Primary buttons
    button_primary_background_fill="#D4A017",
    button_primary_text_color="#0A0A0B",
)
with gr.Blocks(theme=theme, css=CSS, title="TRIBE2") as demo:
    # Hero banner. The dash is an HTML entity, so it renders correctly however
    # this source file's encoding is interpreted.
    gr.HTML("""
<div class="hero-header">
<h1 class="hero-title">TRIBE<span>2</span></h1>
<p class="hero-sub">
Predict where your video loses the viewer. Powered by Meta's brain-encoding
model &mdash; maps cortical response across 20,000 vertices in real time.
</p>
</div>
""")

    # Single-video analysis: stats row, response chart, then feedback + zone cards.
    with gr.Tab("Analyze"):
        video_input = gr.Video(label="", height=360, elem_classes=["video-upload"])
        analyze_btn = gr.Button("Run analysis", variant="primary", size="lg")
        stats_html = gr.HTML()
        chart = gr.Plot()
        with gr.Row():
            with gr.Column(scale=2, elem_classes=["feedback-section"]):
                feedback = gr.Markdown()
            with gr.Column(scale=1):
                zone_html = gr.HTML()
        analyze_btn.click(
            fn=process_video,
            inputs=[video_input],
            outputs=[chart, stats_html, zone_html, feedback],
        )

    # Side-by-side comparison of two clips.
    with gr.Tab("A/B Compare"):
        with gr.Row():
            vid_a = gr.Video(label="Video A", elem_classes=["video-upload"])
            vid_b = gr.Video(label="Video B", elem_classes=["video-upload"])
        compare_btn = gr.Button("Compare", variant="primary", size="lg")
        compare_chart = gr.Plot()
        compare_md = gr.Markdown()
        compare_btn.click(
            fn=process_comparison,
            inputs=[vid_a, vid_b],
            outputs=[compare_chart, compare_md],
        )

    # Static explainer. Blank lines separate the Markdown paragraphs, which
    # would otherwise run together into one.
    with gr.Tab("How it works"):
        gr.Markdown("""### The model

**TRIBE v2** is Meta's brain-encoding foundation model. It predicts fMRI-level cortical
activation from video, audio, and text using three extractors:

V-JEPA2 (vision) + Wav2Vec-BERT 2.0 (audio) + LLaMA 3.2 (language)

The output is a prediction across ~20,000 cortical vertices on the fsaverage5 surface.
We aggregate those into five zones and derive a composite engagement signal.

Predictions carry a ~5s hemodynamic offset inherent to fMRI.
Treat each timestamp as an approximate editing window.

[Model card](https://huggingface.co/facebook/tribev2)
""")

if __name__ == "__main__":
    demo.launch()