File size: 13,702 Bytes
54d61b3
 
1e1a3cf
33136cd
1458172
54d61b3
33136cd
1e1a3cf
2050261
54d61b3
105db59
 
 
 
 
 
33136cd
1458172
33136cd
 
1458172
 
 
 
 
 
33136cd
 
1458172
33136cd
1458172
 
 
 
 
 
 
33136cd
1458172
33136cd
1458172
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33136cd
1458172
 
 
 
 
 
33136cd
1458172
 
 
 
 
33136cd
1458172
33136cd
 
 
 
1458172
 
 
 
 
 
33136cd
 
 
 
 
1458172
33136cd
 
 
1458172
 
33136cd
 
 
 
105db59
 
 
 
 
d8d9431
105db59
1458172
105db59
90b9b57
2050261
105db59
33136cd
 
90b9b57
105db59
33136cd
 
 
105db59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d8d9431
105db59
 
 
 
 
 
 
 
54d61b3
105db59
54d61b3
 
 
 
105db59
d8d9431
105db59
d8d9431
105db59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e1a3cf
 
105db59
 
 
 
 
 
 
d8d9431
105db59
d8d9431
105db59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2050261
 
54d61b3
105db59
33136cd
 
 
90b9b57
1458172
 
90b9b57
 
 
 
 
105db59
2050261
105db59
 
 
90b9b57
105db59
90b9b57
33136cd
 
 
 
 
 
 
 
 
 
 
1458172
 
 
 
33136cd
 
90b9b57
42bdf92
90b9b57
42bdf92
 
 
90b9b57
42bdf92
1458172
 
 
33136cd
 
90b9b57
 
105db59
90b9b57
 
105db59
90b9b57
 
d8d9431
 
105db59
90b9b57
d8d9431
105db59
d8d9431
90b9b57
105db59
42bdf92
 
 
 
1458172
 
105db59
33136cd
 
 
2050261
105db59
 
 
42bdf92
90b9b57
105db59
 
90b9b57
42bdf92
90b9b57
42bdf92
90b9b57
 
 
42bdf92
 
90b9b57
 
42bdf92
 
90b9b57
105db59
 
2050261
42bdf92
 
90b9b57
105db59
2050261
 
 
90b9b57
42bdf92
90b9b57
42bdf92
 
90b9b57
42bdf92
 
90b9b57
 
42bdf92
90b9b57
 
 
42bdf92
105db59
 
33136cd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
import sys
import os
import types
import logging
import re

# Shim for the audioop module removed in Python 3.13+: some audio/gradio
# dependencies still import it unconditionally, so register an empty
# stand-in BEFORE they are imported below.
if 'audioop' not in sys.modules:
    sys.modules['audioop'] = types.ModuleType('audioop')

import gradio as gr
import numpy as np
import matplotlib
# Select the headless backend before pyplot is imported — calling use()
# after importing pyplot can emit warnings / be ignored on some versions.
matplotlib.use('Agg')
import matplotlib.pyplot as plt

logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# ---------------------------------------------------------------------------
# Monkey-patch TRIBE's ExtractWordsFromAudio to build word-level events
# WITHOUT calling whisperx (which requires CUDA libs unavailable on CPU).
#
# Instead, we use a simple heuristic: split the transcript text into words
# and distribute them evenly across the audio duration. This gives TRIBE
# enough word-level signal for its text encoder without needing ASR.
# ---------------------------------------------------------------------------
def _patched_get_transcript_from_audio(wav_filename, language="english"):
    """Build a word-level events DataFrame without running ASR.

    Words from the known script text (global _CURRENT_SCRIPT_TEXT, set by
    the analyze function) — or a placeholder when no text is known — are
    spread uniformly across the audio duration and grouped by sentence,
    giving TRIBE's text encoder usable timing on CPU-only machines.
    """
    import pandas as pd
    import soundfile as sf
    from pathlib import Path

    wav_filename = Path(wav_filename)

    # Probe the file for its duration; fall back to 30 s if unreadable.
    try:
        duration = sf.info(str(wav_filename)).duration
    except Exception:
        duration = 30.0

    text = _CURRENT_SCRIPT_TEXT or "audio content placeholder"

    raw_words = text.split()
    if not raw_words:
        return pd.DataFrame(columns=["text", "start", "duration", "sequence_id", "sentence"])

    # Rough sentence segmentation on terminal punctuation (. ! ?).
    sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+', text) if s.strip()]
    if not sentences:
        sentences = [text]

    # Every word gets an equal share of the total duration.
    per_word = duration / len(raw_words)
    rows = []
    position = 0
    for seq_id, sentence in enumerate(sentences):
        clean_sentence = sentence.replace('"', '')
        for token in sentence.split():
            if position >= len(raw_words):
                break
            rows.append({
                "text": token.replace('"', ''),
                "start": position * per_word,
                # 0.9 leaves a small gap between consecutive word events.
                "duration": per_word * 0.9,
                "sequence_id": seq_id,
                "sentence": clean_sentence,
            })
            position += 1

    return pd.DataFrame(rows)


# Module-level channel: analyze() stores the known script text here so the
# monkey-patched transcriber can reuse it instead of running ASR; it is
# reset to None when no ground-truth text is available.
_CURRENT_SCRIPT_TEXT = None


def apply_patches():
    """Redirect TRIBE's word extraction to the CPU-safe stand-in above.

    Best-effort: failures (e.g. tribev2 not importable yet) are logged as
    warnings rather than raised, so module import never crashes.
    """
    try:
        from tribev2.eventstransforms import ExtractWordsFromAudio

        # Swap the whisperx-backed transcriber for the heuristic one.
        patched = staticmethod(_patched_get_transcript_from_audio)
        ExtractWordsFromAudio._get_transcript_from_audio = patched
        logger.info("Patched ExtractWordsFromAudio (CPU-safe, no whisperx)")
    except Exception as e:
        logger.warning(f"Could not patch ExtractWordsFromAudio: {e}")

# Apply patches at import time so any later tribev2 import already sees the
# CPU-safe transcriber (load_model re-applies as a belt-and-braces measure).
apply_patches()

# ---------------------------------------------------------------------------
# Model loading
# ---------------------------------------------------------------------------
# Lazily-loaded singleton; populated by load_model() on first analysis.
model = None

def load_model():
    """Load the TRIBE v2 model once; return a human-readable status string.

    Idempotent: repeated calls after a successful load are no-ops. On
    failure the traceback is printed and an error string is returned
    (callers check for the substring "Error").
    """
    global model
    if model is not None:
        return "✅ Already loaded!"
    try:
        # Re-apply in case import order matters.
        apply_patches()
        from tribev2 import TribeModel
        model = TribeModel.from_pretrained(
            "facebook/tribev2", cache_folder="/tmp/tribe_cache"
        )
    except Exception as e:
        import traceback
        traceback.print_exc()
        return f"❌ Error loading model: {str(e)}"
    return "✅ Model loaded!"

# ---------------------------------------------------------------------------
# Brain region definitions (approximate vertex ranges on fsaverage5)
# ---------------------------------------------------------------------------
# Each tuple: (display name, start fraction, end fraction, bar color).
# The fractions index into one hemisphere's vertex array in score_predictions.
# NOTE(review): these are coarse functional bands, not an atlas lookup —
# confirm the fraction-to-region mapping against the fsaverage5 layout.
REGIONS = [
    ("Visual cortex",           0.00, 0.15, "#378ADD"),
    ("Auditory cortex",         0.15, 0.30, "#D85A30"),
    ("Language (Broca's area)", 0.30, 0.45, "#7F77DD"),
    ("Prefrontal (attention)",  0.45, 0.62, "#1D9E75"),
    ("Temporal (memory)",       0.62, 0.78, "#BA7517"),
    ("Emotion (limbic)",        0.78, 1.00, "#D4537E"),
]

def score_predictions(preds, regions=None):
    """Score mean absolute predicted activation per brain region.

    Args:
        preds: 2-D array-like of predictions, shape (time, vertices);
            vertices are assumed left hemisphere first, then right.
        regions: optional list of (name, start_frac, end_frac, color)
            tuples; defaults to the module-level REGIONS table.
            (Parameterized so alternate region tables can be scored.)

    Returns:
        (scores, overall): scores maps region name -> 0-100 activation
        score normalised by the peak vertex activation; overall is the
        mean of the region scores. Both rounded to 1 decimal place.
    """
    if regions is None:
        regions = REGIONS
    avg = np.mean(np.abs(preds), axis=0)
    global_max = avg.max() + 1e-8  # epsilon guards against division by zero
    # Region fractions address only the first (left) hemisphere's vertices.
    half = len(avg) // 2
    scores = {}
    for name, s, e, _ in regions:
        start, end = int(half * s), int(half * e)
        scores[name] = round(float(np.mean(avg[start:end]) / global_max * 100), 1)
    return scores, round(sum(scores.values()) / len(scores), 1)

def make_brain_plot(preds):
    """Render left/right fsaverage5 surface maps of mean |prediction|.

    Returns the saved PNG path, or None when plotting fails (e.g. nilearn
    unavailable or surface data cannot be fetched).
    """
    try:
        from nilearn import plotting, datasets

        # Min-max normalise the per-vertex mean absolute activation.
        magnitudes = np.mean(np.abs(preds), axis=0)
        span = magnitudes.max() - magnitudes.min() + 1e-8
        normed = (magnitudes - magnitudes.min()) / span
        half = len(normed) // 2

        fsaverage = datasets.fetch_surf_fsaverage("fsaverage5")
        fig, axes = plt.subplots(1, 2, figsize=(14, 5), subplot_kw={"projection": "3d"})
        fig.patch.set_facecolor("#111111")
        plotting.plot_surf_stat_map(
            fsaverage.infl_left, normed[:half], hemi="left", view="lateral",
            colorbar=True, cmap="hot", title="Left hemisphere",
            axes=axes[0], figure=fig,
        )
        plotting.plot_surf_stat_map(
            fsaverage.infl_right, normed[half:], hemi="right", view="lateral",
            colorbar=True, cmap="hot", title="Right hemisphere",
            axes=axes[1], figure=fig,
        )
        plt.tight_layout()
        plt.savefig("/tmp/brain_map.png", dpi=130, bbox_inches="tight", facecolor="#111111")
        plt.close()
        return "/tmp/brain_map.png"
    except Exception as e:
        print(f"Brain plot error: {e}")
        return None

def make_score_chart(scores, overall):
    """Draw a horizontal bar chart of per-region scores; return the PNG path."""
    fig, ax = plt.subplots(figsize=(9, 4))
    fig.patch.set_facecolor("#1a1a1a")
    ax.set_facecolor("#1a1a1a")
    names, colors = zip(*[(region[0], region[3]) for region in REGIONS])
    vals = [scores.get(name, 0) for name in names]
    bars = ax.barh(names, vals, color=colors, height=0.55)
    ax.set_xlim(0, 100)
    # Dashed guide at the 70-point threshold used by generate_suggestions.
    ax.axvline(70, color="#888", linestyle="--", linewidth=1, alpha=0.6)
    ax.set_xlabel("Activation score", color="#ccc", fontsize=11)
    ax.set_title(f"Brain region activation  |  Overall: {overall}/100",
                 color="white", fontsize=13, fontweight="bold", pad=12)
    ax.tick_params(colors="#ccc")
    for spine in ax.spines.values():
        spine.set_edgecolor("#333")
    # Numeric labels just past the end of each bar.
    for bar, val in zip(bars, vals):
        ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height() / 2,
                f"{val}", va="center", color="white", fontsize=10, fontweight="bold")
    plt.tight_layout()
    plt.savefig("/tmp/score_chart.png", dpi=130, bbox_inches="tight", facecolor="#1a1a1a")
    plt.close()
    return "/tmp/score_chart.png"

def generate_suggestions(scores, overall):
    """Turn region scores into actionable writing tips plus an overall verdict.

    A region triggers its tip only when it scores below the 70-point bar;
    regions absent from `scores` default to 100 (no tip). When everything
    passes, a single congratulatory tip is emitted instead.
    """
    advice = [
        ("Prefrontal (attention)", "→ Open with a bold question or surprising fact to boost attention"),
        ("Emotion (limbic)", "→ Add emotional language — 'imagine', 'feel', personal stories"),
        ("Temporal (memory)", "→ Include specific numbers or data points to improve memorability"),
        ("Visual cortex", "→ Use more visual language — describe what viewers will 'see'"),
        ("Language (Broca's area)", "→ Break long sentences into shorter, punchier ones"),
        ("Auditory cortex", "→ Add rhythm and repetition — the brain responds to sound patterns"),
    ]
    tips = [tip for region, tip in advice if scores.get(region, 100) < 70]
    if not tips:
        tips = ["→ Excellent! Consider adding a strong call-to-action at the end"]
    if overall >= 75:
        status = "🟢 Strong"
    elif overall >= 55:
        status = "🟡 Good, needs polish"
    else:
        status = "🔴 Needs work"
    return f"**Overall: {overall}/100 — {status}**\n\n" + "\n".join(tips)

# ---------------------------------------------------------------------------
# Main analysis function
# ---------------------------------------------------------------------------
def analyze(input_mode, script_text, audio_file, progress=gr.Progress()):
    """Run the full pipeline: input -> audio -> TRIBE events -> brain scores.

    Args:
        input_mode: "Text" or "Audio" (value of the UI radio button).
        script_text: script contents; used when input_mode == "Text".
        audio_file: uploaded file path; used when input_mode == "Audio".
        progress: Gradio progress reporter (default instance is Gradio's
            documented injection pattern, not an accidental mutable default).

    Returns:
        Tuple of (brain image path, score chart path, suggestions markdown,
        predictions .npy path). Image/file slots are None on error, with the
        error message in the markdown slot.
    """
    global _CURRENT_SCRIPT_TEXT

    # Validate the selected input before doing any expensive work.
    if input_mode == "Text" and (not script_text or not script_text.strip()):
        return None, None, "⚠️ Please paste your script text first.", None
    if input_mode == "Audio" and audio_file is None:
        return None, None, "⚠️ Please upload an audio file first.", None

    # Lazy-load the model on first use; load_model reports failures as a
    # string containing "Error" rather than raising.
    if model is None:
        progress(0.1, desc="Loading TRIBE v2 model (first time ~5 mins)...")
        msg = load_model()
        if "Error" in msg:
            return None, None, msg, None

    try:
        if input_mode == "Text":
            progress(0.2, desc="Converting text to speech...")

            from gtts import gTTS
            from langdetect import detect

            # Synthesise speech so TRIBE has real audio to analyse.
            text = script_text.strip()
            lang = detect(text)
            audio_path = "/tmp/script_audio.mp3"
            tts = gTTS(text=text, lang=lang)
            tts.save(audio_path)

            # Store text so the monkey-patched transcriber can use it
            # instead of running ASR on the audio we just synthesised.
            _CURRENT_SCRIPT_TEXT = text

            progress(0.4, desc="Running TRIBE v2 on generated audio...")
            df = model.get_events_dataframe(audio_path=audio_path)

        else:
            import shutil
            progress(0.2, desc="Loading audio file...")
            # Copy to /tmp with a sensible extension so downstream readers
            # can sniff the format from the filename.
            ext = os.path.splitext(audio_file)[1] or ".mp3"
            audio_path = f"/tmp/input_audio{ext}"
            shutil.copy(audio_file, audio_path)

            # No known text for uploaded audio
            _CURRENT_SCRIPT_TEXT = None

            progress(0.4, desc="Running TRIBE v2 on audio...")
            df = model.get_events_dataframe(audio_path=audio_path)

        progress(0.6, desc="Predicting brain response...")
        # segments is unused here; only the raw predictions matter.
        preds, segments = model.predict(events=df)

        progress(0.75, desc="Scoring regions...")
        scores, overall = score_predictions(preds)

        progress(0.85, desc="Rendering maps...")
        brain_img   = make_brain_plot(preds)
        score_img   = make_score_chart(scores, overall)
        suggestions = generate_suggestions(scores, overall)

        # Persist raw predictions so the user can download them.
        np.save("/tmp/brain_predictions.npy", preds)
        progress(1.0, desc="Done!")
        return brain_img, score_img, suggestions, "/tmp/brain_predictions.npy"

    except Exception as e:
        import traceback
        full_error = traceback.format_exc()
        print(full_error)
        return None, None, f"❌ Error:\n{str(e)}\n\nFull traceback:\n{full_error}", None
    finally:
        # Always clear the text channel so a later Audio run can't pick up
        # stale script text from a previous Text run.
        _CURRENT_SCRIPT_TEXT = None

# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
css = "#title{text-align:center} #subtitle{text-align:center;color:#888;font-size:14px}"

with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo"), css=css) as demo:
    gr.Markdown("# 🧠 Script Brain Optimizer", elem_id="title")
    gr.Markdown("Analyze your script or audio β†’ real fMRI predictions via **TRIBE v2** β†’ iterate", elem_id="subtitle")

    with gr.Row():
        with gr.Column(scale=1):
            input_mode = gr.Radio(
                choices=["Text", "Audio"], value="Text",
                label="Input type",
                info="Text: paste your script | Audio: upload MP3/WAV"
            )
            script_input = gr.Textbox(
                label="Your script",
                placeholder="Paste your content script here...",
                lines=10, max_lines=20, visible=True
            )
            audio_input = gr.Audio(
                label="Upload audio file (MP3, WAV, M4A, FLAC)",
                type="filepath", sources=["upload"], visible=False
            )
            with gr.Row():
                clear_btn   = gr.Button("Clear", variant="secondary", scale=1)
                analyze_btn = gr.Button("🧠 Analyze", variant="primary", scale=3)
            suggestions_out = gr.Markdown(value="*Add your content and click Analyze...*")
            download_out    = gr.File(label="Download predictions (.npy)")

        with gr.Column(scale=2):
            brain_img_out = gr.Image(label="Brain activation map", height=320)
            score_img_out = gr.Image(label="Region scores", height=280)

    def toggle_mode(mode):
        return gr.update(visible=mode=="Text"), gr.update(visible=mode=="Audio")

    input_mode.change(fn=toggle_mode, inputs=[input_mode],
                      outputs=[script_input, audio_input])

    analyze_btn.click(fn=analyze, inputs=[input_mode, script_input, audio_input],
                      outputs=[brain_img_out, score_img_out, suggestions_out, download_out])

    clear_btn.click(
        fn=lambda: ("", None, None, None, "*Add your content and click Analyze...*", None),
        outputs=[script_input, audio_input, brain_img_out, score_img_out, suggestions_out, download_out]
    )

    gr.Markdown("---\n*Powered by [TRIBE v2](https://github.com/facebookresearch/tribev2) by Meta FAIR*")

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)