try:
    # Kept as the very first import: on Hugging Face ZeroGPU, `spaces` is
    # expected to be imported before packages that may initialize CUDA
    # (tribev2 below) — confirm before reordering.
    import spaces
except ImportError:

    class _Placeholder:
        """Local stand-in for the Hugging Face ``spaces`` package."""

        def GPU(self, fn=None, *args, **kwargs):
            """No-op replacement for ``spaces.GPU``.

            Supports both the bare form used in this file (``@spaces.GPU``)
            and the factory form (``@spaces.GPU(duration=...)``); in either
            case the decorated function is returned unchanged.
            """
            if callable(fn):
                return fn  # bare decorator: @spaces.GPU
            return lambda f: f  # factory: @spaces.GPU(...)

    spaces = _Placeholder()

import threading

import gradio as gr
import numpy as np
import plotly.graph_objects as go

from tribev2 import TribeModel

# ── Load model (thread-safe, lazy) ──────────────────────────────────────────
_model = None
_model_lock = threading.Lock()


def get_model():
    """Return the shared TRIBE v2 model, loading it on first use.

    The lock guarantees concurrent requests trigger at most one load.
    """
    global _model
    with _model_lock:
        if _model is None:
            _model = TribeModel.from_pretrained("facebook/tribev2", cache_folder="./cache")
        return _model


# fsaverage5 cortical region mapping (approximate vertex index ranges).
# NOTE(review): fsaverage5 vertex indices follow the icosahedral mesh, not
# anatomy, so contiguous index ranges are a placeholder grouping — confirm
# against a real surface atlas before trusting the zone labels.
ZONE_VERTEX_RANGES = {
    "Frontal": (0, 3500),
    "Action": (3500, 6500),
    "Attention": (6500, 10000),
    "Speech": (10000, 14500),
    "Visual": (14500, 20484),
}

ZONE_META = {
    "Frontal": {"desc": "Meaning, intent, narrative focus", "color": "#D4A017", "icon": "01"},
    "Action": {"desc": "Movement, gestures, body dynamics", "color": "#E8651A", "icon": "02"},
    "Attention": {"desc": "Gaze direction, spatial salience", "color": "#22C55E", "icon": "03"},
    "Speech": {"desc": "Voice, faces, object recognition", "color": "#EF4444", "icon": "04"},
    "Visual": {"desc": "Shape, contrast, color, motion", "color": "#3B82F6", "icon": "05"},
}

# Weight of each zone in the composite engagement score. The composite is
# min-max re-normalized afterwards, so only the ratios between weights matter.
ENGAGEMENT_WEIGHTS = {
    "Frontal": 0.25,
    "Action": 0.20,
    "Attention": 0.25,
    "Speech": 0.15,
    "Visual": 0.15,
}


# ── Analysis using TRIBE v2 ────────────────────────────────────────────────
def normalize(arr):
    """Min-max scale ``arr`` to 0-100; a (near-)constant signal maps to 50."""
    mn, mx = arr.min(), arr.max()
    if mx - mn < 1e-9:
        return np.full_like(arr, 50.0)
    return (arr - mn) / (mx - mn) * 100


def _segment_timestamps(segments, n_steps):
    """Return one start time (seconds) per prediction row.

    Accepts either a container exposing a ``start_time`` column (e.g. a
    DataFrame) or an iterable of objects that each have ``start_time``.
    Falls back to 1-second steps when neither yields exactly ``n_steps``
    values, so the x-axis always lines up with the prediction rows.
    """
    column = getattr(segments, "start_time", None)
    try:
        if column is not None:
            starts = np.asarray(column, dtype=float).ravel().tolist()
        else:
            starts = [float(seg.start_time) for seg in segments]
    except (AttributeError, TypeError, ValueError):
        starts = []
    if len(starts) == n_steps:
        return starts
    return [float(i) for i in range(n_steps)]


def _estimate_duration(timestamps):
    """Estimate the analyzed span: last start time plus one typical step."""
    if len(timestamps) < 2:
        return float(len(timestamps))
    step = float(np.median(np.diff(timestamps)))
    return timestamps[-1] + max(step, 0.0)


@spaces.GPU
def analyze_video(video_path):
    """Run TRIBE v2 on a video and derive per-zone and composite signals.

    Args:
        video_path: Path of the uploaded video, as passed by ``gr.Video``.

    Returns:
        ``(timestamps, engagement, zones, duration)``. ``timestamps`` is a list
        of per-step start times in seconds. ``engagement`` is the weighted
        composite, re-normalized to 0-100. ``zones`` maps each zone name to its
        0-100 normalized signal. ``duration`` is the estimated span in seconds.

    Raises:
        gr.Error: if the model yields fewer than two time steps, or fewer
            vertices than a zone range needs.
    """
    model = get_model()
    events = model.get_events_dataframe(video_path=video_path)
    preds, segments = model.predict(events=events)

    n_steps = preds.shape[0]
    if n_steps < 2:
        raise gr.Error("Video too short to analyze.")

    timestamps = _segment_timestamps(segments, n_steps)
    duration = _estimate_duration(timestamps)

    n_vertices = preds.shape[1]
    zones = {}
    for zone_name, (start, end) in ZONE_VERTEX_RANGES.items():
        end = min(end, n_vertices)
        if start >= end:
            # An empty slice would make np.mean return NaN and silently
            # poison every downstream score.
            raise gr.Error(
                f"Model output has {n_vertices} vertices; "
                f"cannot compute the {zone_name} zone."
            )
        zones[zone_name] = normalize(np.mean(preds[:, start:end], axis=1))

    composite = sum(weight * zones[name] for name, weight in ENGAGEMENT_WEIGHTS.items())
    engagement = normalize(composite)
    return timestamps, engagement, zones, duration


# ── Chart ───────────────────────────────────────────────────────────────────
def _axis_style():
    """Shared styling for both chart axes (fresh dict per call)."""
    return dict(
        showgrid=True,
        gridcolor="rgba(255,255,255,0.04)",
        zeroline=False,
        color="#666",
        tickfont=dict(size=11, family="DM Mono, monospace"),
    )


def _chart_layout(height):
    """Dark, transparent Plotly layout shared by the single and A/B charts."""
    return dict(
        template="plotly_dark",
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        margin=dict(l=45, r=15, t=15, b=40),
        height=height,
        xaxis=dict(title="", **_axis_style()),
        yaxis=dict(title="", range=[0, 105], **_axis_style()),
        hoverlabel=dict(
            bgcolor="#1a1a1a",
            bordercolor="#D4A017",
            font=dict(color="#fff", family="DM Mono, monospace", size=12),
        ),
    )


def make_engagement_chart(timestamps, engagement):
    """Return a filled line chart of the 0-100 engagement curve over time."""
    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=timestamps,
        y=engagement,
        mode="lines",
        line=dict(color="#D4A017", width=2.5),
        fill="tozeroy",
        fillcolor="rgba(212,160,23,0.06)",
        # <extra></extra> suppresses the secondary "trace 0" hover box.
        hovertemplate="%{x:.1f}s<br>Response: %{y:.0f}/100<extra></extra>",
    ))
    fig.update_layout(**_chart_layout(height=260))
    return fig


# ── Feedback generation ─────────────────────────────────────────────────────
def generate_feedback(engagement, zones):
    """Summarize an engagement curve into headline stats and edit suggestions.

    Args:
        engagement: 1-D array of composite scores (0-100), one per time step.
        zones: Mapping of zone name -> 1-D array of 0-100 zone scores.

    Returns:
        ``(zone_avgs, strength, strength_color, avg, mx, mn, hook_avg,
        end_avg, weak_pct, rec_lines)`` where ``rec_lines`` is a Markdown
        bullet list.
    """
    # NOTE(review): `engagement` and every zone signal are min-max normalized
    # per video, so each spans 0-100. The averages below therefore describe
    # the *shape* of the response, not an absolute level, and the 72/50/55/45
    # thresholds are relative heuristics.
    avg = float(np.mean(engagement))
    mx = float(np.max(engagement))
    mn = float(np.min(engagement))

    # "Dead air": steps that fall below 65% of the video's own average.
    weak_mask = engagement < (avg * 0.65)
    weak_pct = np.sum(weak_mask) / len(engagement) * 100

    # Hook / ending windows are the first and last ~10% of steps (min. 1).
    hook_len = max(1, len(engagement) // 10)
    hook_avg = float(np.mean(engagement[:hook_len]))
    end_avg = float(np.mean(engagement[-hook_len:]))

    if avg >= 72:
        strength, strength_color = "STRONG", "#22C55E"
    elif avg >= 50:
        strength, strength_color = "AVERAGE", "#D4A017"
    else:
        strength, strength_color = "WEAK", "#EF4444"

    zone_avgs = {name: float(np.mean(sig)) for name, sig in zones.items()}
    weakest_zone = min(zone_avgs, key=zone_avgs.get)

    recs = []
    if hook_avg < 55:
        recs.append("**Weak hook** — first seconds don't grab. Lead with your strongest visual or a surprise.")
    if end_avg < 45:
        recs.append("**Flat ending** — viewers drop off. Add a payoff or cliffhanger in the final seconds.")
    if weak_pct > 30:
        recs.append(f"**{weak_pct:.0f}% dead air** — too many flat stretches. Cut or accelerate the lows.")

    zone_recs = {
        "Action": "More movement, gestures, or physical dynamics.",
        "Attention": "Vary shots — add cuts, zooms, angle changes.",
        "Speech": "More face time or recognizable elements.",
        "Visual": "Flat texture. Try contrast, color grading, composition.",
        "Frontal": "No clear thread. Give viewers something to follow.",
    }
    if zone_avgs[weakest_zone] < 55:
        recs.append(f"**{weakest_zone} zone is dragging** — {zone_recs.get(weakest_zone, '')}")

    if not recs:
        recs.append("Solid cut. Minor gains from tightening the middle third.")

    rec_lines = "\n".join(f"- {r}" for r in recs)
    return zone_avgs, strength, strength_color, avg, mx, mn, hook_avg, end_avg, weak_pct, rec_lines


# ── HTML builders ───────────────────────────────────────────────────────────
def make_stats_html(strength, strength_color, avg, mx, mn, hook_avg, end_avg, weak_pct):
    """Render the headline stat cards (styled by .stats-grid / .stat-card).

    ``end_avg`` is accepted for call-site compatibility but is not shown;
    the grid holds six cards (verdict plus five metrics).
    """
    def card(label, value):
        return (
            '<div class="stat-card">'
            f'<div class="stat-label">{label}</div>'
            f'<div class="stat-value">{value}</div>'
            '</div>'
        )

    verdict = (
        f'<div class="stat-card stat-verdict" style="border-color:{strength_color}">'
        '<div class="stat-label">Verdict</div>'
        f'<div class="stat-value" style="color:{strength_color}">{strength}</div>'
        '</div>'
    )
    cards = [
        verdict,
        card("Average", f"{avg:.0f}"),
        card("Peak", f"{mx:.0f}"),
        card("Floor", f"{mn:.0f}"),
        card("Hook", f"{hook_avg:.0f}"),
        card("Dead air", f"{weak_pct:.0f}%"),
    ]
    return f'<div class="stats-grid">{"".join(cards)}</div>'


def make_zone_html(zone_avgs):
    """Render one card per cortical zone with an activity label and a bar."""
    cards = []
    for zone_name, meta in ZONE_META.items():
        val = zone_avgs.get(zone_name, 0)
        color = meta["color"]
        if val >= 70:
            activity = "HIGH"
        elif val >= 40:
            activity = "MID"
        else:
            activity = "LOW"
        width = min(max(val, 0.0), 100.0)  # keep the bar inside its track
        cards.append(
            '<div class="zone-card">'
            '<div class="zone-header">'
            f'<span class="zone-idx" style="color:{color}">{meta["icon"]}</span>'
            f'<span class="zone-name">{zone_name}</span>'
            f'<span class="zone-activity" style="color:{color}">{activity}</span>'
            '</div>'
            f'<div class="zone-desc">{meta["desc"]}</div>'
            '<div class="zone-bar-track">'
            f'<div class="zone-bar-fill" style="width:{width:.0f}%;background:{color}"></div>'
            '</div>'
            f'<div class="zone-val" style="color:{color}">{val:.0f}</div>'
            '</div>'
        )
    return f'<div class="zone-stack">{"".join(cards)}</div>'


# ── Main handlers ───────────────────────────────────────────────────────────
def process_video(video):
    """Gradio handler for the Analyze tab: chart, stats, zones, feedback."""
    if video is None:
        raise gr.Error("Upload a video first.")
    timestamps, engagement, zones, duration = analyze_video(video)
    chart = make_engagement_chart(timestamps, engagement)
    zone_avgs, strength, strength_color, avg, mx, mn, hook_avg, end_avg, weak_pct, recs = generate_feedback(engagement, zones)
    stats_html = make_stats_html(strength, strength_color, avg, mx, mn, hook_avg, end_avg, weak_pct)
    zone_html = make_zone_html(zone_avgs)
    return chart, stats_html, zone_html, recs


def process_comparison(video_a, video_b):
    """Gradio handler for the A/B tab: overlaid curves and a Markdown table.

    NOTE(review): each video is normalized independently to 0-100, so the
    comparison contrasts response *shapes*, not absolute response levels.
    """
    if video_a is None or video_b is None:
        raise gr.Error("Upload both videos to compare.")

    ts_a, eng_a, zones_a, dur_a = analyze_video(video_a)
    ts_b, eng_b, zones_b, dur_b = analyze_video(video_b)

    fig = go.Figure()
    for label, ts, eng, color in (
        ("Video A", ts_a, eng_a, "#D4A017"),
        ("Video B", ts_b, eng_b, "#3B82F6"),
    ):
        fig.add_trace(go.Scatter(
            x=ts, y=eng, mode="lines", name=label,
            line=dict(color=color, width=2.5),
        ))
    fig.update_layout(
        **_chart_layout(height=300),
        legend=dict(
            font=dict(color="#999", family="DM Mono, monospace", size=11),
            bgcolor="rgba(0,0,0,0)",
            orientation="h",
            y=1.08,
        ),
    )

    avg_a, avg_b = float(np.mean(eng_a)), float(np.mean(eng_b))
    diff = abs(avg_a - avg_b)
    if round(diff) == 0:
        # Would otherwise print "wins by +0" and arbitrarily favor B on ties.
        verdict = "**No clear winner** — average response differs by less than one point."
    else:
        winner = "A" if avg_a > avg_b else "B"
        verdict = f"**Video {winner}** wins by **+{diff:.0f}** avg response."

    # Markdown tables need one row per line to render.
    summary = f"""| | Video A | Video B |
|--|---------|---------|
| **Average** | {avg_a:.0f} | {avg_b:.0f} |
| **Peak** | {np.max(eng_a):.0f} | {np.max(eng_b):.0f} |
| **Floor** | {np.min(eng_a):.0f} | {np.min(eng_b):.0f} |
| **Duration** | {dur_a:.1f}s | {dur_b:.1f}s |

{verdict}
"""
    return fig, summary


# ── Custom CSS ──────────────────────────────────────────────────────────────
CSS = """
@import url('https://fonts.googleapis.com/css2?family=DM+Mono:wght@300;400;500&family=Syne:wght@400;600;700;800&display=swap');

/* Global overrides */
.gradio-container {
    max-width: 900px !important;
    margin: 0 auto !important;
    font-family: 'Syne', sans-serif !important;
    background: #0A0A0B !important;
}
.dark .gradio-container {
    background: #0A0A0B !important;
}

/* Video upload area */
.video-upload {
    background: rgba(255,255,255,0.03) !important;
    border: 1px solid rgba(255,255,255,0.08) !important;
    border-radius: 14px !important;
    overflow: hidden;
}
.video-upload .wrap,
.video-upload video {
    border-radius: 14px !important;
}

/* Header */
.hero-header {
    text-align: center;
    padding: 48px 20px 32px;
    position: relative;
}
.hero-header::before {
    content: '';
    position: absolute;
    top: 0;
    left: 50%;
    transform: translateX(-50%);
    width: 400px;
    height: 400px;
    background: radial-gradient(circle, rgba(212,160,23,0.08) 0%, transparent 70%);
    pointer-events: none;
}
.hero-title {
    font-family: 'Syne', sans-serif;
    font-size: 56px;
    font-weight: 800;
    letter-spacing: -2px;
    color: #FAFAFA;
    margin: 0;
    line-height: 1;
}
.hero-title span {
    color: #D4A017;
}
.hero-sub {
    font-family: 'DM Mono', monospace;
    font-size: 13px;
    color: #666;
    margin-top: 12px;
    letter-spacing: 0.5px;
    line-height: 1.6;
    max-width: 600px;
    margin-left: auto;
    margin-right: auto;
}

/* Tabs */
.tab-nav {
    border: none !important;
    justify-content: center !important;
}
.tab-nav button {
    font-family: 'DM Mono', monospace !important;
    font-size: 12px !important;
    letter-spacing: 1px !important;
    text-transform: uppercase !important;
    color: #555 !important;
    border: none !important;
    background: transparent !important;
    padding: 10px 20px !important;
}
.tab-nav button.selected {
    color: #D4A017 !important;
    border-bottom: 2px solid #D4A017 !important;
    background: transparent !important;
}

/* Stat cards */
.stats-grid {
    display: grid;
    grid-template-columns: repeat(6, 1fr);
    gap: 10px;
    margin-top: 8px;
}
.stat-card {
    background: rgba(255,255,255,0.03);
    border: 1px solid rgba(255,255,255,0.06);
    border-radius: 10px;
    padding: 14px 16px;
    text-align: center;
}
.stat-card.stat-verdict {
    border-width: 2px;
}
.stat-label {
    font-family: 'DM Mono', monospace;
    font-size: 10px;
    text-transform: uppercase;
    letter-spacing: 1.5px;
    color: #555;
    margin-bottom: 6px;
}
.stat-value {
    font-family: 'Syne', sans-serif;
    font-size: 22px;
    font-weight: 700;
    color: #FAFAFA;
}

/* Zone cards */
.zone-stack {
    display: flex;
    flex-direction: column;
    gap: 8px;
}
.zone-card {
    background: rgba(255,255,255,0.02);
    border: 1px solid rgba(255,255,255,0.06);
    border-radius: 10px;
    padding: 14px 16px;
}
.zone-header {
    display: flex;
    align-items: center;
    gap: 10px;
    margin-bottom: 4px;
}
.zone-idx {
    font-family: 'DM Mono', monospace;
    font-size: 11px;
    font-weight: 500;
    opacity: 0.7;
}
.zone-name {
    font-family: 'Syne', sans-serif;
    font-size: 14px;
    font-weight: 600;
    color: #E0E0E0;
    flex: 1;
}
.zone-activity {
    font-family: 'DM Mono', monospace;
    font-size: 10px;
    letter-spacing: 1px;
    font-weight: 500;
}
.zone-desc {
    font-family: 'DM Mono', monospace;
    font-size: 11px;
    color: #555;
    margin-bottom: 10px;
    line-height: 1.4;
}
.zone-bar-track {
    height: 4px;
    background: rgba(255,255,255,0.06);
    border-radius: 2px;
    overflow: hidden;
    margin-bottom: 6px;
}
.zone-bar-fill {
    height: 100%;
    border-radius: 2px;
    transition: width 0.8s cubic-bezier(0.22, 1, 0.36, 1);
}
.zone-val {
    font-family: 'DM Mono', monospace;
    font-size: 12px;
    font-weight: 500;
    text-align: right;
}

/* Plot container */
.plot-container {
    border-radius: 12px;
    overflow: hidden;
}

/* Blocks overrides */
.block {
    background: transparent !important;
    border: none !important;
    box-shadow: none !important;
}
.label-wrap {
    display: none !important;
}
.prose { color: #999 !important; }
.prose h3 { color: #E0E0E0 !important; font-family: 'Syne', sans-serif !important; }
.prose strong { color: #FAFAFA !important; }
.prose table { border-color: rgba(255,255,255,0.08) !important; }
.prose th, .prose td { border-color: rgba(255,255,255,0.08) !important; color: #999 !important; }
.prose th { color: #ccc !important; }
.prose a { color: #D4A017 !important; }

/* Video component */
.video-container {
    border-radius: 12px;
    overflow: hidden;
}

/* Button */
button.primary {
    background: #D4A017 !important;
    border: none !important;
    color: #0A0A0B !important;
    font-family: 'Syne', sans-serif !important;
    font-weight: 700 !important;
    letter-spacing: 0.5px !important;
    border-radius: 10px !important;
}
button.primary:hover {
    background: #E8B12A !important;
}

/* Markdown feedback section */
.feedback-section .prose {
    font-family: 'DM Mono', monospace !important;
    font-size: 13px !important;
    line-height: 1.7 !important;
}

/* Responsive */
@media (max-width: 768px) {
    .stats-grid { grid-template-columns: repeat(3, 1fr); }
    .hero-title { font-size: 36px; }
}
"""

# ── UI ──────────────────────────────────────────────────────────────────────
theme = gr.themes.Base(
    primary_hue=gr.themes.colors.amber,
    neutral_hue=gr.themes.colors.zinc,
).set(
    body_background_fill="#0A0A0B",
    body_text_color="#999999",
    block_background_fill="transparent",
    block_border_width="0px",
    block_shadow="none",
    input_background_fill="#111113",
    input_border_color="rgba(255,255,255,0.08)",
    input_border_width="1px",
    button_primary_background_fill="#D4A017",
    button_primary_text_color="#0A0A0B",
)

with gr.Blocks(theme=theme, css=CSS, title="TRIBE2") as demo:
    gr.HTML("""
<div class="hero-header">
  <h1 class="hero-title">TRIBE<span>2</span></h1>
  <p class="hero-sub">Predict where your video loses the viewer. Powered by Meta's brain-encoding model — maps cortical response across 20,000 vertices in real time.</p>
</div>
""")

    with gr.Tab("Analyze"):
        video_input = gr.Video(label="", height=360, elem_classes=["video-upload"])
        analyze_btn = gr.Button("Run analysis", variant="primary", size="lg")
        stats_html = gr.HTML()
        chart = gr.Plot()
        with gr.Row():
            with gr.Column(scale=2, elem_classes=["feedback-section"]):
                feedback = gr.Markdown()
            with gr.Column(scale=1):
                zone_html = gr.HTML()
        analyze_btn.click(
            fn=process_video,
            inputs=[video_input],
            outputs=[chart, stats_html, zone_html, feedback],
        )

    with gr.Tab("A/B Compare"):
        with gr.Row():
            vid_a = gr.Video(label="Video A", elem_classes=["video-upload"])
            vid_b = gr.Video(label="Video B", elem_classes=["video-upload"])
        compare_btn = gr.Button("Compare", variant="primary", size="lg")
        compare_chart = gr.Plot()
        compare_md = gr.Markdown()
        compare_btn.click(
            fn=process_comparison,
            inputs=[vid_a, vid_b],
            outputs=[compare_chart, compare_md],
        )

    with gr.Tab("How it works"):
        gr.Markdown("""### The model

**TRIBE v2** is Meta's brain-encoding foundation model. It predicts fMRI-level cortical activation from video, audio, and text using three extractors:

V-JEPA2 (vision) + Wav2Vec-BERT 2.0 (audio) + LLaMA 3.2 (language)

The output is a prediction across ~20,000 cortical vertices on the fsaverage5 surface. We aggregate those into five zones and derive a composite engagement signal.

Predictions carry a ~5s hemodynamic offset inherent to fMRI. Treat each timestamp as an approximate editing window.

[Model card](https://huggingface.co/facebook/tribev2)
""")

if __name__ == "__main__":
    demo.launch()