try:
    import spaces
except ImportError:
    class _Placeholder:
        """No-op stand-in used when the ``spaces`` package is not installed."""

        def GPU(self, *args, **kwargs):
            """Act as a pass-through decorator in both supported call forms.

            ``@spaces.GPU`` (bare) passes the decorated function as the only
            positional argument, so it must be returned unchanged. The
            previous fallback always returned a decorator factory, which
            replaced the decorated function with that factory.

            ``@spaces.GPU(...)`` (called) receives configuration arguments,
            which are ignored here, and must return a decorator.
            """
            # Bare form: the single positional argument is the function itself.
            if len(args) == 1 and not kwargs and callable(args[0]):
                return args[0]

            def decorator(fn):
                return fn

            return decorator

    spaces = _Placeholder()
|
|
| import gradio as gr |
| import numpy as np |
| import plotly.graph_objects as go |
| import threading |
| from tribev2 import TribeModel |
|
|
| |
|
|
# Module-level cache for the loaded model; filled by the first get_model() call.
_model_lock = threading.Lock()
_model = None


def get_model():
    """Return the shared TribeModel instance, loading it on first use.

    The lock is held across both the ``None`` check and the load, so
    concurrent first callers wait for a single ``from_pretrained`` call
    instead of each starting their own.
    """
    global _model
    with _model_lock:
        if _model is not None:
            return _model
        _model = TribeModel.from_pretrained(
            "facebook/tribev2",
            cache_folder="./cache",
        )
        return _model
|
|
| |
# Zone name -> (start, end) column bounds into the prediction matrix.
# The bounds are used as slice limits (start inclusive, end exclusive) and
# the five ranges are contiguous, covering indices 0 through 20483.
ZONE_VERTEX_RANGES = dict(
    Frontal=(0, 3500),
    Action=(3500, 6500),
    Attention=(6500, 10000),
    Speech=(10000, 14500),
    Visual=(14500, 20484),
)
|
|
# Display metadata per zone: short description, accent colour (hex) and the
# two-digit index label shown on the zone card.
ZONE_META = {
    zone: {"desc": desc, "color": color, "icon": icon}
    for zone, desc, color, icon in (
        ("Frontal", "Meaning, intent, narrative focus", "#D4A017", "01"),
        ("Action", "Movement, gestures, body dynamics", "#E8651A", "02"),
        ("Attention", "Gaze direction, spatial salience", "#22C55E", "03"),
        ("Speech", "Voice, faces, object recognition", "#EF4444", "04"),
        ("Visual", "Shape, contrast, color, motion", "#3B82F6", "05"),
    )
}
|
|
|
|
| |
|
|
def normalize(arr):
    """Min-max rescale *arr* onto the 0-100 range.

    Args:
        arr: Array-like of numbers (ndarray, list, tuple, ...).

    Returns:
        A floating-point ndarray with the same shape as the input. An input
        whose spread is below 1e-9 maps to all 50.0, because there is no
        range to stretch. An empty input is returned as an empty float array
        instead of raising.
    """
    values = np.asarray(arr)
    # Integer input must not leak an integer result on the constant path
    # while the regular path returns floats.
    if not np.issubdtype(values.dtype, np.floating):
        values = values.astype(float)
    if values.size == 0:
        return values

    low, high = values.min(), values.max()
    span = high - low
    if span < 1e-9:
        return np.full_like(values, 50.0)
    return (values - low) / span * 100
|
|
|
|
def _segment_timestamps(segments, n_steps):
    """Return one start time per prediction step, falling back to a 1 s grid."""
    # The attribute lives on the individual segments, so read it from the
    # elements; checking the container itself never matches for a plain list.
    # NOTE(review): ``start_time`` is the attribute name the original code
    # read; confirm it against the objects TribeModel.predict returns.
    try:
        stamps = [float(seg.start_time) for seg in segments]
    except (AttributeError, TypeError, ValueError):
        stamps = []
    # x and y must have the same length for plotting; otherwise use one
    # step per second.
    if len(stamps) != n_steps:
        stamps = [float(step) for step in range(n_steps)]
    return stamps


@spaces.GPU
def analyze_video(video_path):
    """Run the model on a video and derive per-zone and overall signals.

    Args:
        video_path: Path of the video file, passed to the model's
            ``get_events_dataframe``.

    Returns:
        Tuple ``(timestamps, engagement, zones, duration)``:
        ``timestamps`` is a list with one float per prediction step,
        ``engagement`` is the weighted zone mix rescaled to 0-100,
        ``zones`` maps each zone name to its 0-100 signal, and
        ``duration`` is the last timestamp (the start of the final step).

    Raises:
        gr.Error: If fewer than two prediction steps are produced, or the
            prediction has too few columns to cover a zone.
    """
    m = get_model()
    df = m.get_events_dataframe(video_path=video_path)
    preds, segments = m.predict(events=df)

    n_steps = preds.shape[0]
    if n_steps < 2:
        raise gr.Error("Video too short to analyze.")

    timestamps = _segment_timestamps(segments, n_steps)
    duration = timestamps[-1]

    n_vertices = preds.shape[1]
    zones = {}
    for zone_name, (start, end) in ZONE_VERTEX_RANGES.items():
        stop = min(end, n_vertices)
        # An empty slice would average to NaN and poison every later score.
        if start >= stop:
            raise gr.Error(
                f"Model output has {n_vertices} vertices; "
                f"cannot compute the {zone_name} zone."
            )
        zone_signal = np.mean(preds[:, start:stop], axis=1)
        zones[zone_name] = normalize(zone_signal)

    engagement = normalize(
        0.25 * zones["Frontal"]
        + 0.20 * zones["Action"]
        + 0.25 * zones["Attention"]
        + 0.15 * zones["Speech"]
        + 0.15 * zones["Visual"]
    )

    return timestamps, engagement, zones, duration
|
|
|
|
| |
|
|
def make_engagement_chart(timestamps, engagement):
    """Build the single-video engagement line chart.

    Args:
        timestamps: X values (seconds) for each point.
        engagement: Y values, drawn against a fixed 0-105 axis.

    Returns:
        A Plotly figure with one filled line trace on a transparent dark
        background.
    """
    mono_font = "DM Mono, monospace"
    # Both axes share the same look; only the y-axis pins its range.
    axis_style = dict(
        title="",
        showgrid=True,
        gridcolor="rgba(255,255,255,0.04)",
        zeroline=False,
        color="#666",
        tickfont=dict(size=11, family=mono_font),
    )

    response_trace = go.Scatter(
        x=timestamps,
        y=engagement,
        mode="lines",
        line=dict(color="#D4A017", width=2.5),
        fill="tozeroy",
        fillcolor="rgba(212,160,23,0.06)",
        hovertemplate="<b>%{x:.1f}s</b><br>Response: %{y:.0f}/100<extra></extra>",
    )

    fig = go.Figure()
    fig.add_trace(response_trace)
    fig.update_layout(
        template="plotly_dark",
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        margin=dict(l=45, r=15, t=15, b=40),
        height=260,
        xaxis=dict(axis_style),
        yaxis=dict(axis_style, range=[0, 105]),
        hoverlabel=dict(
            bgcolor="#1a1a1a",
            bordercolor="#D4A017",
            font=dict(color="#fff", family=mono_font, size=12),
        ),
    )
    return fig
|
|
|
|
| |
|
|
def generate_feedback(engagement, zones):
    """Summarise an engagement signal and produce editing recommendations.

    Args:
        engagement: Array-like of per-step engagement scores (0-100).
        zones: Mapping of zone name to that zone's per-step signal. May be
            empty, in which case no zone recommendation is produced.

    Returns:
        Tuple ``(zone_avgs, strength, strength_color, avg, mx, mn,
        hook_avg, end_avg, weak_pct, rec_lines)``. ``zone_avgs`` maps each
        zone to its mean, ``strength`` and ``strength_color`` are the
        verdict label and its hex colour, the numeric items are plain
        ``float`` values, and ``rec_lines`` is a Markdown bullet list.

    Raises:
        ValueError: If ``engagement`` is empty.
    """
    engagement = np.asarray(engagement, dtype=float)
    if engagement.size == 0:
        raise ValueError("engagement signal is empty")

    avg = float(np.mean(engagement))
    mx = float(np.max(engagement))
    mn = float(np.min(engagement))

    # "Dead air": share of steps falling below 65% of the average.
    weak_mask = engagement < (avg * 0.65)
    weak_pct = float(np.sum(weak_mask) / len(engagement) * 100)

    # Hook and ending are each the first/last tenth, at least one step.
    hook_len = max(1, len(engagement) // 10)
    hook_avg = float(np.mean(engagement[:hook_len]))
    end_avg = float(np.mean(engagement[-hook_len:]))

    if avg >= 72:
        strength, strength_color = "STRONG", "#22C55E"
    elif avg >= 50:
        strength, strength_color = "AVERAGE", "#D4A017"
    else:
        strength, strength_color = "WEAK", "#EF4444"

    zone_avgs = {name: float(np.mean(sig)) for name, sig in zones.items()}

    # The separators below are em dashes; the previous text carried a
    # mis-decoded character in their place.
    recs = []
    if hook_avg < 55:
        recs.append("**Weak hook** — first seconds don't grab. Lead with your strongest visual or a surprise.")
    if end_avg < 45:
        recs.append("**Flat ending** — viewers drop off. Add a payoff or cliffhanger in the final seconds.")
    if weak_pct > 30:
        recs.append(f"**{weak_pct:.0f}% dead air** — too many flat stretches. Cut or accelerate the lows.")

    zone_recs = {
        "Action": "More movement, gestures, or physical dynamics.",
        "Attention": "Vary shots — add cuts, zooms, angle changes.",
        "Speech": "More face time or recognizable elements.",
        "Visual": "Flat texture. Try contrast, color grading, composition.",
        "Frontal": "No clear thread. Give viewers something to follow.",
    }
    # min() on an empty mapping raises, so only look for a weakest zone
    # when there is at least one.
    if zone_avgs:
        weakest_zone = min(zone_avgs, key=zone_avgs.get)
        if zone_avgs[weakest_zone] < 55:
            recs.append(f"**{weakest_zone} zone is dragging** — {zone_recs.get(weakest_zone, '')}")

    if not recs:
        recs.append("Solid cut. Minor gains from tightening the middle third.")

    rec_lines = "\n".join(f"- {r}" for r in recs)

    return zone_avgs, strength, strength_color, avg, mx, mn, hook_avg, end_avg, weak_pct, rec_lines
|
|
|
|
| |
|
|
def make_stats_html(strength, strength_color, avg, mx, mn, hook_avg, end_avg, weak_pct):
    """Render the summary stat cards as an HTML fragment.

    Produces one verdict card tinted with ``strength_color`` followed by
    five numeric cards (Average, Peak, Floor, Hook, Dead air), each rounded
    to a whole number. ``end_avg`` is accepted but not shown on any card.
    """
    numeric_cards = (
        ("Average", f"{avg:.0f}"),
        ("Peak", f"{mx:.0f}"),
        ("Floor", f"{mn:.0f}"),
        ("Hook", f"{hook_avg:.0f}"),
        ("Dead air", f"{weak_pct:.0f}%"),
    )

    # The leading and trailing empty entries keep the fragment wrapped in
    # newlines, matching a triple-quoted template.
    lines = [
        "",
        '<div class="stats-grid">',
        f'    <div class="stat-card stat-verdict" style="border-color: {strength_color}40;">',
        '        <div class="stat-label">Verdict</div>',
        f'        <div class="stat-value" style="color: {strength_color};">{strength}</div>',
        "    </div>",
    ]
    for label, shown in numeric_cards:
        lines.extend(
            [
                '    <div class="stat-card">',
                f'        <div class="stat-label">{label}</div>',
                f'        <div class="stat-value">{shown}</div>',
                "    </div>",
            ]
        )
    lines.extend(["</div>", ""])
    return "\n".join(lines)
|
|
|
|
def make_zone_html(zone_avgs):
    """Render one card per entry of ZONE_META as a stacked HTML fragment.

    Args:
        zone_avgs: Mapping of zone name to its average score; zones missing
            from the mapping are shown with a value of 0.

    Returns:
        HTML string: a ``zone-stack`` container holding a card for each
        zone with its index label, name, activity level (HIGH at 70 and
        above, MID at 40 and above, otherwise LOW), description, bar and
        rounded value.
    """
    rendered = []
    for zone_name, meta in ZONE_META.items():
        val = zone_avgs.get(zone_name, 0)
        color, idx, desc = meta["color"], meta["icon"], meta["desc"]
        activity = "HIGH" if val >= 70 else "MID" if val >= 40 else "LOW"

        rendered.append(f"""
    <div class="zone-card">
        <div class="zone-header">
            <span class="zone-idx" style="color: {color};">{idx}</span>
            <span class="zone-name">{zone_name}</span>
            <span class="zone-activity" style="color: {color};">{activity}</span>
        </div>
        <div class="zone-desc">{desc}</div>
        <div class="zone-bar-track">
            <div class="zone-bar-fill" style="width:{val:.0f}%;background:{color};"></div>
        </div>
        <div class="zone-val" style="color: {color};">{val:.0f}</div>
    </div>""")

    return '<div class="zone-stack">' + "".join(rendered) + "</div>"
|
|
|
|
| |
|
|
def process_video(video):
    """Handle the "Run analysis" click: analyze one video and build all outputs.

    Args:
        video: Value of the video input, or ``None`` when nothing is uploaded.

    Returns:
        Tuple ``(chart, stats_html, zone_html, recs)`` in the order the
        click handler's outputs expect.

    Raises:
        gr.Error: If no video was provided.
    """
    if video is None:
        raise gr.Error("Upload a video first.")

    timestamps, engagement, zones, _duration = analyze_video(video)
    chart = make_engagement_chart(timestamps, engagement)

    # generate_feedback returns zone averages first, the recommendations
    # last, and in between exactly the arguments make_stats_html takes.
    zone_avgs, *stat_args, recs = generate_feedback(engagement, zones)
    stats_html = make_stats_html(*stat_args)
    zone_html = make_zone_html(zone_avgs)

    return chart, stats_html, zone_html, recs
|
|
|
|
def _comparison_verdict(avg_a, avg_b):
    """Return the Markdown verdict line for two average engagement scores."""
    diff = abs(avg_a - avg_b)
    # A gap that would print as "+0" is not a win for either video.
    if f"{diff:.0f}" == "0":
        return "**Tie**: both videos have the same average response."
    winner = "A" if avg_a > avg_b else "B"
    return f"**Video {winner}** wins by **+{diff:.0f}** avg response."


def process_comparison(video_a, video_b):
    """Handle the "Compare" click: analyze two videos and compare them.

    Args:
        video_a: First video input value, or ``None``.
        video_b: Second video input value, or ``None``.

    Returns:
        Tuple ``(fig, summary)``: a Plotly figure overlaying both engagement
        curves, and a Markdown table (average, peak, floor, duration)
        followed by a verdict line. Averages that round to the same whole
        number are reported as a tie instead of a win by "+0".

    Raises:
        gr.Error: If either video is missing.
    """
    if video_a is None or video_b is None:
        raise gr.Error("Upload both videos to compare.")

    ts_a, eng_a, zones_a, dur_a = analyze_video(video_a)
    ts_b, eng_b, zones_b, dur_b = analyze_video(video_b)

    fig = go.Figure()
    fig.add_trace(go.Scatter(
        x=ts_a, y=eng_a, mode="lines", name="Video A",
        line=dict(color="#D4A017", width=2.5),
    ))
    fig.add_trace(go.Scatter(
        x=ts_b, y=eng_b, mode="lines", name="Video B",
        line=dict(color="#3B82F6", width=2.5),
    ))
    fig.update_layout(
        template="plotly_dark",
        paper_bgcolor="rgba(0,0,0,0)",
        plot_bgcolor="rgba(0,0,0,0)",
        margin=dict(l=45, r=15, t=15, b=40),
        height=300,
        xaxis=dict(showgrid=True, gridcolor="rgba(255,255,255,0.04)", zeroline=False, color="#666",
                   tickfont=dict(size=11, family="DM Mono, monospace")),
        yaxis=dict(showgrid=True, gridcolor="rgba(255,255,255,0.04)", zeroline=False, range=[0, 105], color="#666",
                   tickfont=dict(size=11, family="DM Mono, monospace")),
        legend=dict(font=dict(color="#999", family="DM Mono, monospace", size=11),
                    bgcolor="rgba(0,0,0,0)", orientation="h", y=1.08),
        hoverlabel=dict(bgcolor="#1a1a1a", bordercolor="#D4A017",
                        font=dict(color="#fff", family="DM Mono, monospace", size=12)),
    )

    avg_a, avg_b = float(np.mean(eng_a)), float(np.mean(eng_b))

    # Table rows are joined explicitly so each starts at column 0, which the
    # Markdown table syntax requires.
    summary = "\n".join([
        "| | Video A | Video B |",
        "|--|---------|---------|",
        f"| **Average** | {avg_a:.0f} | {avg_b:.0f} |",
        f"| **Peak** | {np.max(eng_a):.0f} | {np.max(eng_b):.0f} |",
        f"| **Floor** | {np.min(eng_a):.0f} | {np.min(eng_b):.0f} |",
        f"| **Duration** | {dur_a:.1f}s | {dur_b:.1f}s |",
        "",
        _comparison_verdict(avg_a, avg_b),
        "",
    ])
    return fig, summary
|
|
|
|
| |
|
|
| CSS = """ |
| @import url('https://fonts.googleapis.com/css2?family=DM+Mono:wght@300;400;500&family=Syne:wght@400;600;700;800&display=swap'); |
| |
| /* Global overrides */ |
| .gradio-container { |
| max-width: 900px !important; |
| margin: 0 auto !important; |
| font-family: 'Syne', sans-serif !important; |
| background: #0A0A0B !important; |
| } |
| .dark .gradio-container { background: #0A0A0B !important; } |
| |
| /* Video upload area */ |
| .video-upload { |
| background: rgba(255,255,255,0.03) !important; |
| border: 1px solid rgba(255,255,255,0.08) !important; |
| border-radius: 14px !important; |
| overflow: hidden; |
| } |
| .video-upload .wrap, .video-upload video { |
| border-radius: 14px !important; |
| } |
| |
| /* Header */ |
| .hero-header { |
| text-align: center; |
| padding: 48px 20px 32px; |
| position: relative; |
| } |
| .hero-header::before { |
| content: ''; |
| position: absolute; |
| top: 0; left: 50%; |
| transform: translateX(-50%); |
| width: 400px; height: 400px; |
| background: radial-gradient(circle, rgba(212,160,23,0.08) 0%, transparent 70%); |
| pointer-events: none; |
| } |
| .hero-title { |
| font-family: 'Syne', sans-serif; |
| font-size: 56px; |
| font-weight: 800; |
| letter-spacing: -2px; |
| color: #FAFAFA; |
| margin: 0; |
| line-height: 1; |
| } |
| .hero-title span { color: #D4A017; } |
| .hero-sub { |
| font-family: 'DM Mono', monospace; |
| font-size: 13px; |
| color: #666; |
| margin-top: 12px; |
| letter-spacing: 0.5px; |
| line-height: 1.6; |
| max-width: 600px; |
| margin-left: auto; |
| margin-right: auto; |
| } |
| |
| /* Tabs */ |
| .tab-nav { border: none !important; justify-content: center !important; } |
| .tab-nav button { |
| font-family: 'DM Mono', monospace !important; |
| font-size: 12px !important; |
| letter-spacing: 1px !important; |
| text-transform: uppercase !important; |
| color: #555 !important; |
| border: none !important; |
| background: transparent !important; |
| padding: 10px 20px !important; |
| } |
| .tab-nav button.selected { |
| color: #D4A017 !important; |
| border-bottom: 2px solid #D4A017 !important; |
| background: transparent !important; |
| } |
| |
| /* Stat cards */ |
| .stats-grid { |
| display: grid; |
| grid-template-columns: repeat(6, 1fr); |
| gap: 10px; |
| margin-top: 8px; |
| } |
| .stat-card { |
| background: rgba(255,255,255,0.03); |
| border: 1px solid rgba(255,255,255,0.06); |
| border-radius: 10px; |
| padding: 14px 16px; |
| text-align: center; |
| } |
| .stat-card.stat-verdict { |
| border-width: 2px; |
| } |
| .stat-label { |
| font-family: 'DM Mono', monospace; |
| font-size: 10px; |
| text-transform: uppercase; |
| letter-spacing: 1.5px; |
| color: #555; |
| margin-bottom: 6px; |
| } |
| .stat-value { |
| font-family: 'Syne', sans-serif; |
| font-size: 22px; |
| font-weight: 700; |
| color: #FAFAFA; |
| } |
| |
| /* Zone cards */ |
| .zone-stack { |
| display: flex; |
| flex-direction: column; |
| gap: 8px; |
| } |
| .zone-card { |
| background: rgba(255,255,255,0.02); |
| border: 1px solid rgba(255,255,255,0.06); |
| border-radius: 10px; |
| padding: 14px 16px; |
| } |
| .zone-header { |
| display: flex; |
| align-items: center; |
| gap: 10px; |
| margin-bottom: 4px; |
| } |
| .zone-idx { |
| font-family: 'DM Mono', monospace; |
| font-size: 11px; |
| font-weight: 500; |
| opacity: 0.7; |
| } |
| .zone-name { |
| font-family: 'Syne', sans-serif; |
| font-size: 14px; |
| font-weight: 600; |
| color: #E0E0E0; |
| flex: 1; |
| } |
| .zone-activity { |
| font-family: 'DM Mono', monospace; |
| font-size: 10px; |
| letter-spacing: 1px; |
| font-weight: 500; |
| } |
| .zone-desc { |
| font-family: 'DM Mono', monospace; |
| font-size: 11px; |
| color: #555; |
| margin-bottom: 10px; |
| line-height: 1.4; |
| } |
| .zone-bar-track { |
| height: 4px; |
| background: rgba(255,255,255,0.06); |
| border-radius: 2px; |
| overflow: hidden; |
| margin-bottom: 6px; |
| } |
| .zone-bar-fill { |
| height: 100%; |
| border-radius: 2px; |
| transition: width 0.8s cubic-bezier(0.22, 1, 0.36, 1); |
| } |
| .zone-val { |
| font-family: 'DM Mono', monospace; |
| font-size: 12px; |
| font-weight: 500; |
| text-align: right; |
| } |
| |
| /* Plot container */ |
| .plot-container { |
| border-radius: 12px; |
| overflow: hidden; |
| } |
| |
| /* Blocks overrides */ |
| .block { background: transparent !important; border: none !important; box-shadow: none !important; } |
| .label-wrap { display: none !important; } |
| .prose { color: #999 !important; } |
| .prose h3 { color: #E0E0E0 !important; font-family: 'Syne', sans-serif !important; } |
| .prose strong { color: #FAFAFA !important; } |
| .prose table { border-color: rgba(255,255,255,0.08) !important; } |
| .prose th, .prose td { border-color: rgba(255,255,255,0.08) !important; color: #999 !important; } |
| .prose th { color: #ccc !important; } |
| .prose a { color: #D4A017 !important; } |
| |
| /* Video component */ |
| .video-container { border-radius: 12px; overflow: hidden; } |
| |
| /* Button */ |
| button.primary { |
| background: #D4A017 !important; |
| border: none !important; |
| color: #0A0A0B !important; |
| font-family: 'Syne', sans-serif !important; |
| font-weight: 700 !important; |
| letter-spacing: 0.5px !important; |
| border-radius: 10px !important; |
| } |
| button.primary:hover { |
| background: #E8B12A !important; |
| } |
| |
| /* Markdown feedback section */ |
| .feedback-section .prose { |
| font-family: 'DM Mono', monospace !important; |
| font-size: 13px !important; |
| line-height: 1.7 !important; |
| } |
| |
| /* Responsive */ |
| @media (max-width: 768px) { |
| .stats-grid { grid-template-columns: repeat(3, 1fr); } |
| .hero-title { font-size: 36px; } |
| } |
| """ |
|
|
|
|
| |
|
|
# Theme variable overrides applied on top of the Base theme so the built-in
# components match the dark palette used by the custom stylesheet.
_THEME_OVERRIDES = {
    "body_background_fill": "#0A0A0B",
    "body_text_color": "#999999",
    "block_background_fill": "transparent",
    "block_border_width": "0px",
    "block_shadow": "none",
    "input_background_fill": "#111113",
    "input_border_color": "rgba(255,255,255,0.08)",
    "input_border_width": "1px",
    "button_primary_background_fill": "#D4A017",
    "button_primary_text_color": "#0A0A0B",
}

theme = gr.themes.Base(
    primary_hue=gr.themes.colors.amber,
    neutral_hue=gr.themes.colors.zinc,
).set(**_THEME_OVERRIDES)
|
|
# UI definition: a hero header followed by three tabs (single-video
# analysis, A/B comparison, and a short explanation of the model).
with gr.Blocks(theme=theme, css=CSS, title="TRIBE2") as demo:

    # The subtitle separator is an em dash; the previous text carried a
    # mis-decoded character in its place.
    gr.HTML("""
    <div class="hero-header">
        <h1 class="hero-title">TRIBE<span>2</span></h1>
        <p class="hero-sub">
            Predict where your video loses the viewer. Powered by Meta's brain-encoding
            model — maps cortical response across 20,000 vertices in real time.
        </p>
    </div>
    """)

    with gr.Tab("Analyze"):
        video_input = gr.Video(label="", height=360, elem_classes=["video-upload"])
        analyze_btn = gr.Button("Run analysis", variant="primary", size="lg")
        stats_html = gr.HTML()
        chart = gr.Plot()
        with gr.Row():
            with gr.Column(scale=2, elem_classes=["feedback-section"]):
                feedback = gr.Markdown()
            with gr.Column(scale=1):
                zone_html = gr.HTML()

        # Output order must match the tuple returned by process_video.
        analyze_btn.click(
            fn=process_video,
            inputs=[video_input],
            outputs=[chart, stats_html, zone_html, feedback],
        )

    with gr.Tab("A/B Compare"):
        with gr.Row():
            vid_a = gr.Video(label="Video A", elem_classes=["video-upload"])
            vid_b = gr.Video(label="Video B", elem_classes=["video-upload"])
        compare_btn = gr.Button("Compare", variant="primary", size="lg")
        compare_chart = gr.Plot()
        compare_md = gr.Markdown()

        # Output order must match the tuple returned by process_comparison.
        compare_btn.click(
            fn=process_comparison,
            inputs=[vid_a, vid_b],
            outputs=[compare_chart, compare_md],
        )

    with gr.Tab("How it works"):
        gr.Markdown("""### The model

**TRIBE v2** is Meta's brain-encoding foundation model. It predicts fMRI-level cortical
activation from video, audio, and text using three extractors:

V-JEPA2 (vision) + Wav2Vec-BERT 2.0 (audio) + LLaMA 3.2 (language)

The output is a prediction across ~20,000 cortical vertices on the fsaverage5 surface.
We aggregate those into five zones and derive a composite engagement signal.

Predictions carry a ~5s hemodynamic offset inherent to fMRI.
Treat each timestamp as an approximate editing window.

[Model card](https://huggingface.co/facebook/tribev2)
""")


# Start the app only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
|
|