ceh-vedant commited on
Commit
d8d9431
Β·
verified Β·
1 Parent(s): 418dad6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -38
app.py CHANGED
@@ -2,7 +2,7 @@ import sys
2
  import os
3
  import types
4
 
5
- # ── Python 3.13 fix: stub out audioop ───────────────────────────────────────
6
  if 'audioop' not in sys.modules:
7
  sys.modules['audioop'] = types.ModuleType('audioop')
8
 
@@ -12,19 +12,15 @@ import matplotlib.pyplot as plt
12
  import matplotlib
13
  matplotlib.use('Agg')
14
 
15
- # ── Model loading ────────────────────────────────────────────────────────────
16
  model = None
17
 
18
  def load_model():
19
  global model
20
  if model is not None:
21
- return "βœ… Model already loaded!"
22
  try:
23
  from tribev2 import TribeModel
24
- model = TribeModel.from_pretrained(
25
- "facebook/tribev2",
26
- cache_folder="./tribe_cache"
27
- )
28
  return "βœ… Model loaded!"
29
  except Exception as e:
30
  return f"❌ Error: {str(e)}"
@@ -46,8 +42,7 @@ def score_predictions(preds):
46
  for name, s, e, _ in REGIONS:
47
  start, end = int(half * s), int(half * e)
48
  scores[name] = round(float(np.mean(avg[start:end]) / global_max * 100), 1)
49
- overall = round(sum(scores.values()) / len(scores), 1)
50
- return scores, overall
51
 
52
  def make_brain_plot(preds):
53
  try:
@@ -63,10 +58,9 @@ def make_brain_plot(preds):
63
  plotting.plot_surf_stat_map(fsaverage.infl_right, avg_norm[half:], hemi="right",
64
  view="lateral", colorbar=True, cmap="hot", title="Right hemisphere", axes=axes[1], figure=fig)
65
  plt.tight_layout()
66
- path = "./brain_map.png"
67
- plt.savefig(path, dpi=130, bbox_inches="tight", facecolor="#111111")
68
  plt.close()
69
- return path
70
  except Exception as e:
71
  print(f"Brain plot error: {e}")
72
  return None
@@ -91,10 +85,9 @@ def make_score_chart(scores, overall):
91
  ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height() / 2,
92
  f"{val}", va="center", color="white", fontsize=10, fontweight="bold")
93
  plt.tight_layout()
94
- path = "./score_chart.png"
95
- plt.savefig(path, dpi=130, bbox_inches="tight", facecolor="#1a1a1a")
96
  plt.close()
97
- return path
98
 
99
  def generate_suggestions(scores, overall):
100
  tips = []
@@ -127,20 +120,19 @@ def analyze_script(script_text, progress=gr.Progress()):
127
  from gtts import gTTS
128
  progress(0.2, desc="Converting script to speech...")
129
  tts = gTTS(text=script_text.strip(), lang="en", slow=False)
130
- audio_path = "./script_audio.mp3"
131
- tts.save(audio_path)
132
  progress(0.4, desc="Running TRIBE v2 prediction (1-3 mins)...")
133
- df = model.get_events_dataframe(audio_path=audio_path)
134
  preds, segments = model.predict(events=df)
135
  progress(0.7, desc="Scoring regions...")
136
  scores, overall = score_predictions(preds)
137
  progress(0.8, desc="Rendering maps...")
138
- brain_img = make_brain_plot(preds)
139
- score_img = make_score_chart(scores, overall)
140
  suggestions = generate_suggestions(scores, overall)
141
- np.save("./brain_predictions.npy", preds)
142
  progress(1.0, desc="Done!")
143
- return brain_img, score_img, suggestions, "./brain_predictions.npy"
144
  except Exception as e:
145
  return None, None, f"❌ Error:\n{str(e)}", None
146
 
@@ -148,14 +140,11 @@ css = "#title{text-align:center} #subtitle{text-align:center;color:#888;font-siz
148
 
149
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo"), css=css) as demo:
150
  gr.Markdown("# 🧠 Script Brain Optimizer", elem_id="title")
151
- gr.Markdown("Paste your script β†’ get real fMRI predictions via **TRIBE v2** β†’ iterate", elem_id="subtitle")
152
  with gr.Row():
153
  with gr.Column(scale=1):
154
- script_input = gr.Textbox(
155
- label="Your script",
156
- placeholder="Paste your content script here...",
157
- lines=12, max_lines=20
158
- )
159
  with gr.Row():
160
  clear_btn = gr.Button("Clear", variant="secondary", scale=1)
161
  analyze_btn = gr.Button("🧠 Analyze", variant="primary", scale=3)
@@ -165,16 +154,11 @@ with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo"), css=css) as demo:
165
  brain_img_out = gr.Image(label="Brain activation map", height=320)
166
  score_img_out = gr.Image(label="Region scores", height=280)
167
 
168
- analyze_btn.click(
169
- fn=analyze_script,
170
- inputs=[script_input],
171
- outputs=[brain_img_out, score_img_out, suggestions_out, download_out]
172
- )
173
- clear_btn.click(
174
- fn=lambda: ("", None, None, "*Paste a script and click Analyze...*", None),
175
- outputs=[script_input, brain_img_out, score_img_out, suggestions_out, download_out]
176
- )
177
  gr.Markdown("---\n*Powered by [TRIBE v2](https://github.com/facebookresearch/tribev2) by Meta FAIR*")
178
 
179
  if __name__ == "__main__":
180
- demo.launch()
 
2
  import os
3
  import types
4
 
5
+ # Python 3.13 audioop stub (not needed on 3.11 but harmless)
6
  if 'audioop' not in sys.modules:
7
  sys.modules['audioop'] = types.ModuleType('audioop')
8
 
 
12
  import matplotlib
13
  matplotlib.use('Agg')
14
 
 
15
  model = None
16
 
17
  def load_model():
18
  global model
19
  if model is not None:
20
+ return "βœ… Already loaded!"
21
  try:
22
  from tribev2 import TribeModel
23
+ model = TribeModel.from_pretrained("facebook/tribev2", cache_folder="./tribe_cache")
 
 
 
24
  return "βœ… Model loaded!"
25
  except Exception as e:
26
  return f"❌ Error: {str(e)}"
 
42
  for name, s, e, _ in REGIONS:
43
  start, end = int(half * s), int(half * e)
44
  scores[name] = round(float(np.mean(avg[start:end]) / global_max * 100), 1)
45
+ return scores, round(sum(scores.values()) / len(scores), 1)
 
46
 
47
  def make_brain_plot(preds):
48
  try:
 
58
  plotting.plot_surf_stat_map(fsaverage.infl_right, avg_norm[half:], hemi="right",
59
  view="lateral", colorbar=True, cmap="hot", title="Right hemisphere", axes=axes[1], figure=fig)
60
  plt.tight_layout()
61
+ plt.savefig("/tmp/brain_map.png", dpi=130, bbox_inches="tight", facecolor="#111111")
 
62
  plt.close()
63
+ return "/tmp/brain_map.png"
64
  except Exception as e:
65
  print(f"Brain plot error: {e}")
66
  return None
 
85
  ax.text(bar.get_width() + 1, bar.get_y() + bar.get_height() / 2,
86
  f"{val}", va="center", color="white", fontsize=10, fontweight="bold")
87
  plt.tight_layout()
88
+ plt.savefig("/tmp/score_chart.png", dpi=130, bbox_inches="tight", facecolor="#1a1a1a")
 
89
  plt.close()
90
+ return "/tmp/score_chart.png"
91
 
92
  def generate_suggestions(scores, overall):
93
  tips = []
 
120
  from gtts import gTTS
121
  progress(0.2, desc="Converting script to speech...")
122
  tts = gTTS(text=script_text.strip(), lang="en", slow=False)
123
+ tts.save("/tmp/script_audio.mp3")
 
124
  progress(0.4, desc="Running TRIBE v2 prediction (1-3 mins)...")
125
+ df = model.get_events_dataframe(audio_path="/tmp/script_audio.mp3")
126
  preds, segments = model.predict(events=df)
127
  progress(0.7, desc="Scoring regions...")
128
  scores, overall = score_predictions(preds)
129
  progress(0.8, desc="Rendering maps...")
130
+ brain_img = make_brain_plot(preds)
131
+ score_img = make_score_chart(scores, overall)
132
  suggestions = generate_suggestions(scores, overall)
133
+ np.save("/tmp/brain_predictions.npy", preds)
134
  progress(1.0, desc="Done!")
135
+ return brain_img, score_img, suggestions, "/tmp/brain_predictions.npy"
136
  except Exception as e:
137
  return None, None, f"❌ Error:\n{str(e)}", None
138
 
 
140
 
141
  with gr.Blocks(theme=gr.themes.Soft(primary_hue="indigo"), css=css) as demo:
142
  gr.Markdown("# 🧠 Script Brain Optimizer", elem_id="title")
143
+ gr.Markdown("Paste your script β†’ real fMRI predictions via **TRIBE v2** β†’ iterate", elem_id="subtitle")
144
  with gr.Row():
145
  with gr.Column(scale=1):
146
+ script_input = gr.Textbox(label="Your script",
147
+ placeholder="Paste your content script here...", lines=12, max_lines=20)
 
 
 
148
  with gr.Row():
149
  clear_btn = gr.Button("Clear", variant="secondary", scale=1)
150
  analyze_btn = gr.Button("🧠 Analyze", variant="primary", scale=3)
 
154
  brain_img_out = gr.Image(label="Brain activation map", height=320)
155
  score_img_out = gr.Image(label="Region scores", height=280)
156
 
157
+ analyze_btn.click(fn=analyze_script, inputs=[script_input],
158
+ outputs=[brain_img_out, score_img_out, suggestions_out, download_out])
159
+ clear_btn.click(fn=lambda: ("", None, None, "*Paste a script and click Analyze...*", None),
160
+ outputs=[script_input, brain_img_out, score_img_out, suggestions_out, download_out])
 
 
 
 
 
161
  gr.Markdown("---\n*Powered by [TRIBE v2](https://github.com/facebookresearch/tribev2) by Meta FAIR*")
162
 
163
  if __name__ == "__main__":
164
+ demo.launch(server_name="0.0.0.0", server_port=7860)