wangleiofficial committed on
Commit
ee2bd6b
·
verified ·
1 Parent(s): 8b4a5d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +477 -297
app.py CHANGED
@@ -1,24 +1,13 @@
1
- # Improved version of your LocPred-Prok app
2
- # (Tailwind layout, responsive SVG, download PNG, dark/light auto theme)
3
-
4
- # NOTE: This is a placeholder structure.
5
- # I will generate the full, ready-to-run file in the next update.
6
-
7
- # --- START OF FILE ---
8
-
9
- # Improved, complete LocPred-Prok Gradio app
10
- # Features added/fixed:
11
- # - Responsive, centered SVG that supports horizontal/circular layouts and high-res rendering
12
- # - Dark / light automatic color adaptation (via prefers-color-scheme)
13
- # - Client-side SVG -> PNG download buttons (no server libs required)
14
- # - Tailwind CDN used for layout utilities (works inside Gradio HTML panel)
15
- # - Tailwind-like alignment applied to main layout
16
- # - Keeps original model loading and prediction logic
17
-
18
  import os
19
- import json
20
  import re
 
21
  import uuid
 
 
 
 
 
22
  import torch
23
  import torch.nn as nn
24
  import torch.nn.functional as F
@@ -27,26 +16,31 @@ import matplotlib.pyplot as plt
27
  import numpy as np
28
  from transformers import AutoTokenizer, AutoModel
29
 
30
- # ---------- Environment & cache (same as original) ----------
 
 
 
 
 
 
 
31
  plt.switch_backend('Agg')
32
  os.environ["HF_HOME"] = "/tmp/hf_cache"
33
  os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
34
  os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"
35
 
36
- import shutil
37
  for path in ["/tmp/hf_cache", os.path.expanduser("~/.cache/huggingface")]:
38
  shutil.rmtree(path, ignore_errors=True)
39
  os.makedirs(path, exist_ok=True)
40
 
41
- # ---------- Model architecture (unchanged core, with attention output) ----------
42
  class AttentionPooling(nn.Module):
43
  def __init__(self, d_model):
44
  super().__init__()
45
  self.attention_net = nn.Linear(d_model, 1)
46
 
47
  def forward(self, x, mask):
48
- # x: [B, L, D], mask: [B, L]
49
- attn_logits = self.attention_net(x).squeeze(2) # [B, L]
50
  attn_logits = attn_logits.masked_fill(mask == 0, -1e9)
51
  attn_weights = F.softmax(attn_logits, dim=1)
52
  return torch.bmm(attn_weights.unsqueeze(1), x).squeeze(1), attn_weights
@@ -60,7 +54,11 @@ class ProtDualBranchEnhancedClassifier(nn.Module):
60
  self.tok_projector = nn.Linear(d_model, projection_dim)
61
  fused_dim = projection_dim * 2
62
  self.gate = nn.Sequential(nn.Linear(fused_dim, fused_dim), nn.Sigmoid())
63
- self.classifier_head = nn.Sequential(nn.LayerNorm(fused_dim), nn.Linear(fused_dim, fused_dim * 2), nn.ReLU(), nn.Dropout(dropout), nn.Linear(fused_dim * 2, num_classes))
 
 
 
 
64
 
65
  def forward(self, cls_embedding, token_embeddings, mask):
66
  z_cls = self.cls_projector(cls_embedding)
@@ -73,236 +71,384 @@ class ProtDualBranchEnhancedClassifier(nn.Module):
73
  z_fused_gated = z_fused_concat * gate_values
74
  return self.classifier_head(z_fused_gated), pooling_weights
75
 
76
- # ---------- Load PLM + classifier ----------
77
  DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
78
  PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
79
  CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
80
  LABEL_MAP_PATH = "label_map.json"
81
 
82
- if not os.path.exists(LABEL_MAP_PATH):
83
- raise FileNotFoundError(f"Missing {LABEL_MAP_PATH}")
84
- if not os.path.exists(CLASSIFIER_PATH):
85
- raise FileNotFoundError(f"Missing {CLASSIFIER_PATH}")
86
-
87
- with open(LABEL_MAP_PATH, 'r') as f:
88
- label_to_idx = json.load(f)
89
- idx_to_label = {v: k for k, v in label_to_idx.items()}
90
- NUM_CLASSES = len(idx_to_label)
91
- D_MODEL = 640
92
-
93
- print("🔹 Loading models...")
94
- tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME)
95
- plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE).eval()
96
- classifier = ProtDualBranchEnhancedClassifier(D_MODEL, 32, NUM_CLASSES, 0.3, 3).to(DEVICE)
97
- classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
98
- classifier.eval()
99
- print("✅ Models loaded and ready.")
100
-
101
- # ---------- SVG Generator with layout options and responsive wrapper ----------
102
- def generate_bacterial_svg(target_class, layout='circular', high_res=False):
103
- # Normalize target
104
- target = target_class.lower() if target_class else ""
105
-
106
- # Determine active compartments
107
- is_sec = "extracellular" in target or "secreted" in target
108
- is_om = "outer membrane" in target
109
- is_peri = "periplasm" in target
110
- is_cw = "cell wall" in target
111
- is_im = "plasma membrane" in target or "inner membrane" in target
112
- is_cyto = "cytoplasm" in target or "cytosol" in target
113
-
114
- # Color tokens using CSS variables (support dark/light)
115
- # Colors are referenced as var(--color-...)
116
- c = {
117
- "hl_stroke": "var(--hl-stroke)", "hl_fill": "var(--hl-fill)", "hl_text": "var(--hl-text)", "hl_dot": "var(--hl-dot)",
118
- "bg_stroke": "var(--bg-stroke)", "bg_fill_om": "var(--bg-fill-om)", "bg_fill_im": "var(--bg-fill-im)",
119
- "bg_text": "var(--bg-text)", "bg_line": "var(--bg-line)", "bg_dot": "var(--bg-dot)"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  }
121
 
122
- def style(active, base_fill, base_stroke, w_act="4", w_norm="2"):
123
- if active:
124
- return c['hl_fill'], c['hl_stroke'], w_act
125
- return base_fill, base_stroke, w_norm
126
-
127
- om_f, om_s, om_w = style(is_peri, c['bg_fill_om'], c['hl_stroke'] if is_om else c['bg_stroke'])
128
- cw_s = c['hl_stroke'] if is_cw else "var(--muted)"
129
- cw_w, cw_d = ("3", "0") if is_cw else ("1.5", "6,4")
130
- im_f, im_s, im_w = style(is_cyto, c['bg_fill_im'], c['hl_stroke'] if is_im else c['bg_stroke'])
131
-
132
- def label_style(active):
133
- if active: return c['hl_text'], 'bold', c['hl_stroke'], '2.5', c['hl_dot'], '5'
134
- return c['bg_text'], 'normal', c['bg_line'], '1.5', c['bg_dot'], '3'
135
-
136
- l_sec = label_style(is_sec)
137
- l_om = label_style(is_om)
138
- l_peri = label_style(is_peri)
139
- l_cw = label_style(is_cw)
140
- l_im = label_style(is_im)
141
- l_cyto = label_style(is_cyto)
142
-
143
- # Size and viewBox (increase resolution if high_res)
144
- base_w, base_h = (1200, 600) if high_res else (800, 420)
145
- viewbox = f"0 0 {base_w} {base_h}"
146
-
147
- # Choose layout: circular or horizontal
148
- if layout == 'horizontal':
149
- # Place cell on left, labels on right in a row
150
- bx, by = int(base_w * 0.35), int(base_h * 0.5)
151
- tx = int(base_w * 0.75)
152
  else:
153
- # circular: center cell and labels on right
154
- bx, by = int(base_w * 0.35), int(base_h * 0.5)
155
- tx = int(base_w * 0.75)
156
-
157
- # Anchor points relative to center
158
- targets = {
159
- 'sec': (bx, by - 180),
160
- 'om': (bx + 140, by - 120),
161
- 'peri': (bx + 120, by - 90),
162
- 'cw': (bx + 100, by - 70),
163
- 'im': (bx + 70, by - 50),
164
- 'cyto': (bx, by)
 
 
 
165
  }
166
 
167
- # label Y positions
168
- text_y = {
169
- 'sec': int(base_h*0.12), 'om': int(base_h*0.22), 'peri': int(base_h*0.32),
170
- 'cw': int(base_h*0.42), 'im': int(base_h*0.62), 'cyto': int(base_h*0.78)
 
 
 
171
  }
172
 
173
- def draw_connector(key, style_tuple, label_text):
174
- txt_col, weight, line_col, width, dot_col, r = style_tuple
175
- tx_pos, ty_pos = tx, text_y[key]
176
- ex, ey = targets[key]
177
- c1x, c1y = tx_pos - int(base_w*0.08), ty_pos
178
- c2x, c2y = ex + int(base_w*0.06), ey
179
- path = f"M {tx_pos - 10} {ty_pos - 5} C {c1x} {c1y}, {c2x} {c2y}, {ex} {ey}"
180
- return f'''
181
- <g>
182
- <text x="{tx_pos}" y="{ty_pos}" fill="{txt_col}" font-weight="{weight}" font-size="14" font-family="Inter, Arial">{label_text}</text>
183
- <path d="{path}" fill="none" stroke="{line_col}" stroke-width="{width}" stroke-linecap="round" stroke-linejoin="round" />
184
- <circle cx="{ex}" cy="{ey}" r="{r}" fill="{dot_col}" stroke="white" stroke-width="1" />
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  </g>
186
  '''
187
-
188
- # Draw cell shapes: use rounded rects for membranes (keeping original geometry scaled)
189
- svg_shapes = f'''
190
- <g transform="translate({bx}, {by})">
191
- <rect x="{-150}" y="{-150}" width="300" height="300" rx="150" ry="150" fill="{om_f}" stroke="{om_s}" stroke-width="{om_w}" />
192
- <rect x="{-110}" y="{-110}" width="220" height="220" rx="110" ry="110" fill="none" stroke="{cw_s}" stroke-width="{cw_w}" stroke-dasharray="{cw_d}" />
193
- <rect x="{-70}" y="{-70}" width="140" height="140" rx="70" ry="70" fill="{im_f}" stroke="{im_s}" stroke-width="{im_w}" />
194
- <g opacity="0.45">
195
- <path d="M -30 -20 Q 0 -60 30 -20 T 60 -10" fill="none" stroke="var(--muted)" stroke-width="3" />
196
- <circle cx="-40" cy="30" r="3" fill="var(--muted)" /> <circle cx="20" cy="40" r="3" fill="var(--muted)" />
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
  </g>
 
 
 
 
 
 
 
 
198
  </g>
199
  '''
200
 
201
- # Compose connectors
202
- connectors = "".join([
203
- draw_connector('sec', l_sec, 'Extracellular / Secreted'),
204
- draw_connector('om', l_om, 'Outer Membrane'),
205
- draw_connector('peri', l_peri, 'Periplasm'),
206
- draw_connector('cw', l_cw, 'Cell Wall'),
207
- draw_connector('im', l_im, 'Inner Membrane'),
208
- draw_connector('cyto', l_cyto, 'Cytoplasm')
209
- ])
210
-
211
- svg_core = f'''<svg id="svg_core" width="100%" height="auto" viewBox="{viewbox}" xmlns="http://www.w3.org/2000/svg" role="img" aria-label="Bacterial localization diagram">
 
 
 
 
 
 
212
  <defs>
213
  <style><![CDATA[
214
  text {{ font-family: Inter, Arial, sans-serif; }}
215
  ]]></style>
216
  </defs>
217
- {svg_shapes}
218
  {connectors}
219
- </svg>'''
220
-
221
- # Create unique wrapper id so multiple calls don't clash
222
- uid = str(uuid.uuid4()).replace('-', '')
223
- wrapper = f"loc_svg_{uid}"
224
-
225
- # Build responsive HTML with inline JS to enable SVG/PNG download
226
- html = f'''
227
- <div id="{wrapper}" style="width:100%; text-align:center;">
228
- <div style="display:inline-block; max-width:100%; width:900px;">
229
- {svg_core}
230
- <div style="margin-top:8px; display:flex; gap:8px; justify-content:center; align-items:center;">
231
- <button id="btn_svg_{uid}" class="download-btn">Download SVG</button>
232
- <button id="btn_png_{uid}" class="download-btn">Download PNG</button>
233
- <div style="font-size:12px; color:var(--bg-text); align-self:center;">Layout: {layout.title()} {'· High-res' if high_res else ''}</div>
234
- </div>
235
- </div>
236
  </div>
 
 
237
  <script>
238
  (function(){{
239
- const wrapper = document.getElementById('{wrapper}');
240
- const svgEl = wrapper.querySelector('svg');
241
- const btnSvg = document.getElementById('btn_svg_{uid}');
242
- const btnPng = document.getElementById('btn_png_{uid}');
243
-
244
- // Helper: download file
245
- function download(filename, blob){{
246
- const url = URL.createObjectURL(blob);
247
- const a = document.createElement('a');
248
- a.href = url; a.download = filename; document.body.appendChild(a); a.click();
249
- setTimeout(()=>{{ URL.revokeObjectURL(url); a.remove(); }}, 100);
 
 
 
 
 
 
250
  }}
251
-
252
- btnSvg.addEventListener('click', ()=>{{
253
- const serializer = new XMLSerializer();
254
- let source = serializer.serializeToString(svgEl);
255
- // Add name spaces.
256
- if(!source.match(/^<svg[^>]+xmlns="http\:\/\/www\.w3\.org\/2000\/svg"/)){{
257
- source = source.replace(/^<svg/, '<svg xmlns="http://www.w3.org/2000/svg"');
258
- }}
259
- if(!source.match(/^<svg[^>]+xmlns:xlink="http\:\/\/www\.w3\.org\/1999\/xlink"/)){{
260
- source = source.replace(/^<svg/, '<svg xmlns:xlink="http://www.w3.org/1999/xlink"');
261
- }}
262
- const blob = new Blob([source], {{type: 'image/svg+xml;charset=utf-8'}});
263
- download('locpred_diagram.svg', blob);
264
- }});
265
-
266
- btnPng.addEventListener('click', ()=>{{
267
- const serializer = new XMLSerializer();
268
- let source = serializer.serializeToString(svgEl);
269
- if(!source.match(/^<svg[^>]+xmlns="http\:\/\/www\.w3\.org\/2000\/svg"/)){{
270
- source = source.replace(/^<svg/, '<svg xmlns="http://www.w3.org/2000/svg"');
271
- }}
272
- const image = new Image();
273
- const svgBlob = new Blob([source], {{type: 'image/svg+xml;charset=utf-8'}});
274
- const url = URL.createObjectURL(svgBlob);
275
- image.onload = function(){{
276
- const canvas = document.createElement('canvas');
277
- // scale canvas to natural image size (use 2x for better quality)
278
- canvas.width = image.width * 2;
279
- canvas.height = image.height * 2;
280
- const ctx = canvas.getContext('2d');
281
- // set background transparent-friendly
282
- ctx.fillStyle = 'white';
283
- ctx.fillRect(0,0,canvas.width, canvas.height);
284
- ctx.drawImage(image, 0, 0, canvas.width, canvas.height);
285
- canvas.toBlob(function(blob){{
286
- download('locpred_diagram.png', blob);
287
- }}, 'image/png');
288
- URL.revokeObjectURL(url);
289
- }};
290
- // In some environments the SVG does not have width/height set; set reasonable defaults
291
- image.src = url;
292
- }});
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
293
  }})();
294
  </script>
295
- '''
 
296
  return html
297
 
298
- # ---------- Attention heatmap (unchanged) ----------
299
  def draw_attention_heatmap_strip(weights, sequence):
300
  if weights.max() > 0:
301
  weights = (weights - weights.min()) / (weights.max() - weights.min())
302
  data = weights.reshape(1, -1)
303
  fig, ax = plt.subplots(figsize=(8, 1.5), dpi=150)
304
  im = ax.imshow(data, cmap='Reds', aspect='auto', vmin=0, vmax=1)
305
- ax.set_title('Sequence Attention Heatmap (High Color = Key Feature)', fontsize=10, fontweight='bold', color='#37474F', pad=10)
306
  ax.set_xlabel('Residue Position', fontsize=9)
307
  ax.set_yticks([])
308
  cbar = plt.colorbar(im, ax=ax, orientation='vertical', fraction=0.02, pad=0.02)
@@ -313,105 +459,139 @@ def draw_attention_heatmap_strip(weights, sequence):
313
  plt.tight_layout()
314
  return fig
315
 
316
- # ---------- Prediction logic (exposes layout + high_res options) ----------
317
- def predict(sequence_input, layout_choice, high_res_flag):
 
 
 
 
 
 
318
  if not sequence_input or sequence_input.isspace():
319
- raise gr.Error('Empty Input')
 
320
  seq = "".join(sequence_input.split('\n')[1:]) if sequence_input.startswith('>') else sequence_input
321
  seq = re.sub(r'[^A-Z]', '', seq.upper())[:1024]
322
  if not seq:
323
- raise gr.Error('Invalid Sequence')
324
-
325
- with torch.no_grad():
326
- inputs = tokenizer(seq, return_tensors='pt', truncation=True, max_length=1024).to(DEVICE)
327
- outputs = plm_model(**inputs)
328
- hidden_states = outputs.last_hidden_state
329
- cls_embedding = hidden_states[:, 0, :]
330
- token_embeddings = hidden_states[:, 1:-1, :]
331
- token_mask = inputs['attention_mask'][:, 1:-1]
332
- logits, pooling_weights = classifier(cls_embedding, token_embeddings, token_mask)
333
- probs = F.softmax(logits, dim=1)[0]
334
-
335
- top_label = idx_to_label[torch.max(probs, dim=0)[1].item()]
336
- confidences = {idx_to_label[i]: float(p) for i, p in enumerate(probs)}
337
-
338
- svg = generate_bacterial_svg(top_label, layout=layout_choice, high_res=high_res_flag)
339
- w_np = pooling_weights[0].cpu().numpy()
340
- heatmap_plot = draw_attention_heatmap_strip(w_np, seq)
341
-
342
- return confidences, svg, heatmap_plot
343
-
344
- # ---------- UI (Tailwind CDN + Auto dark mode via CSS variables) ----------
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
345
  layout_css = """
346
- /* Minimal overrides and CSS variables for dark/light */
347
- :root{
348
- --bg-fill-om: #F5F5F5; --bg-fill-im: #FAFAFA; --bg-stroke: #90A4AE; --muted: #B0BEC5;
349
- --hl-stroke: #D32F2F; --hl-fill: #FFEBEE; --hl-text: #B71C1C; --hl-dot: #D32F2F;
350
- --bg-text: #37474F; --bg-line: #CFD8DC; --bg-dot: #B0BEC5;
351
- }
352
- @media (prefers-color-scheme: dark) {
353
- :root{
354
- --bg-fill-om: #263238; --bg-fill-im: #1E2930; --bg-stroke: #455A64; --muted: #37474F;
355
- --hl-stroke: #FF8A80; --hl-fill: #3E2723; --hl-text: #FFCDD2; --hl-dot: #FF8A80;
356
- --bg-text: #ECEFF1; --bg-line: #37474F; --bg-dot: #546E7A;
357
- }
358
- }
359
-
360
- .download-btn{ padding:8px 12px; border-radius:6px; border:1px solid var(--bg-line); background:transparent; cursor:pointer; }
361
- .download-btn:hover{ box-shadow:0 2px 8px rgba(0,0,0,0.08); }
362
-
363
- /* Keep Gradio panels tidy */
364
- .gradio-container{ max-width:1100px; margin:0 auto; }
365
  """
366
 
367
- # Use Gradio theme but also inject Tailwind CDN for utility classes in HTML
368
  theme = gr.themes.Soft(primary_hue="sky").set(body_background_fill="white", block_background_fill="white", block_border_width="0px")
369
 
370
- gr_tailwind = """
371
- <link href="https://cdn.jsdelivr.net/npm/tailwindcss@2.2.19/dist/tailwind.min.css" rel="stylesheet">
372
- """
373
-
374
- with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok") as app:
375
- # Inject Tailwind (works in Gradio HTML scope)
376
- gr.HTML(gr_tailwind)
377
-
378
- gr.HTML("""
379
- <div class="w-full p-4 rounded-lg" style="background:linear-gradient(to right,#E0F7FA,#E1F5FE); border:1px solid #B3E5FC; text-align:center;">
380
- <h1 style="font-family:Inter, Arial; font-size:28px; margin:0; color:#0288D1;">LocPred-Prok</h1>
381
- <div style="color:#0277BD; margin-top:6px;">Deep Learning Framework for Prokaryotic Subcellular Localization</div>
382
- </div>
383
- """)
384
-
385
  with gr.Row():
386
- with gr.Column(elem_classes="panel-card", scale=6):
387
- gr.Markdown("<div class='panel-header'><span class='panel-label'>A</span>Sequence Input</div>")
388
- sequence_input = gr.Textbox(lines=8, show_label=False, placeholder=">Sequence... (single-letter AA)")
389
  with gr.Row():
390
  clear_btn = gr.ClearButton(sequence_input, value="Clear")
391
  submit_btn = gr.Button("Predict Analysis", variant="primary")
392
  with gr.Row():
393
- layout_choice = gr.Radio(['circular', 'horizontal'], value='circular', label='Diagram Layout', info='Choose circular (default) or horizontal layout for the cell diagram')
394
- high_res_flag = gr.Checkbox(label='High resolution render (larger SVG)', value=False)
395
- gr.Examples([[">Outer Membrane\nAPKNTWYTGAKLGWSQYHDTGFINNNGPTHENQLGAGAF..." ]], inputs=sequence_input, label=None)
396
-
397
- with gr.Column(elem_classes="panel-card", scale=6):
398
- gr.Markdown("<div class='panel-header'><span class='panel-label'>B</span>Localization Visualization</div>")
 
 
399
  output_svg = gr.HTML(label="Visual", show_label=False)
400
 
401
  with gr.Row():
402
- with gr.Column(elem_classes="panel-card", scale=6):
403
- gr.Markdown("<div class='panel-header'><span class='panel-label'>C</span>Prediction Confidence</div>")
404
  output_label = gr.Label(num_top_classes=NUM_CLASSES, show_label=False)
405
- with gr.Column(elem_classes="panel-card", scale=6):
406
- gr.Markdown("<div class='panel-header'><span class='panel-label'>D</span>Learned Attention Heatmap</div>")
407
- output_plot = gr.Plot(label="Attention", show_label=False)
408
 
409
- submit_btn.click(fn=predict, inputs=[sequence_input, layout_choice, high_res_flag], outputs=[output_label, output_svg, output_plot])
410
  clear_btn.click(lambda: [None, None, None], outputs=[output_label, output_svg, output_plot])
411
 
412
- # ---------- Launch ----------
413
- if __name__ == '__main__':
414
- app.launch()
415
-
416
-
417
- # --- END OF FILE ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app_locpred_prok.py
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  import os
 
3
  import re
4
+ import json
5
  import uuid
6
+ import io
7
+ import shutil
8
+ import base64
9
+ from typing import Tuple
10
+
11
  import torch
12
  import torch.nn as nn
13
  import torch.nn.functional as F
 
16
  import numpy as np
17
  from transformers import AutoTokenizer, AutoModel
18
 
19
+ # Optional server-side PDF export dependency
20
+ try:
21
+ import cairosvg
22
+ CAIROSVG_AVAILABLE = True
23
+ except Exception:
24
+ CAIROSVG_AVAILABLE = False
25
+
26
+ # ========== Environment (same cache handling as before) ==========
# Select the non-interactive Agg backend so figures render without a display.
plt.switch_backend('Agg')

# Scratch directory used for both Hugging Face cache variables below.
HF_CACHE_DIR = "/tmp/hf_cache"

# NOTE(review): these variables are assigned after `transformers` has already
# been imported at the top of the file. Confirm the library still honours them
# at this point; if not, the assignments must move ahead of that import.
os.environ["HF_HOME"] = HF_CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = HF_CACHE_DIR
os.environ["HF_HUB_DISABLE_SYMLINKS_WARNING"] = "1"

# Start from empty cache directories on every launch: each location is removed
# (errors ignored) and then recreated. This also wipes the user's default
# ~/.cache/huggingface tree -- presumably to avoid stale or corrupt downloads.
for path in [HF_CACHE_DIR, os.path.expanduser("~/.cache/huggingface")]:
    shutil.rmtree(path, ignore_errors=True)
    os.makedirs(path, exist_ok=True)
35
 
36
+ # ========== Model architecture (same as you had) ==========
37
  class AttentionPooling(nn.Module):
38
  def __init__(self, d_model):
39
  super().__init__()
40
  self.attention_net = nn.Linear(d_model, 1)
41
 
42
  def forward(self, x, mask):
43
+ attn_logits = self.attention_net(x).squeeze(2)
 
44
  attn_logits = attn_logits.masked_fill(mask == 0, -1e9)
45
  attn_weights = F.softmax(attn_logits, dim=1)
46
  return torch.bmm(attn_weights.unsqueeze(1), x).squeeze(1), attn_weights
 
54
  self.tok_projector = nn.Linear(d_model, projection_dim)
55
  fused_dim = projection_dim * 2
56
  self.gate = nn.Sequential(nn.Linear(fused_dim, fused_dim), nn.Sigmoid())
57
+ self.classifier_head = nn.Sequential(nn.LayerNorm(fused_dim),
58
+ nn.Linear(fused_dim, fused_dim * 2),
59
+ nn.ReLU(),
60
+ nn.Dropout(dropout),
61
+ nn.Linear(fused_dim * 2, num_classes))
62
 
63
  def forward(self, cls_embedding, token_embeddings, mask):
64
  z_cls = self.cls_projector(cls_embedding)
 
71
  z_fused_gated = z_fused_concat * gate_values
72
  return self.classifier_head(z_fused_gated), pooling_weights
73
 
74
+ # ========== Load models (keep same variable names so your logic works) ==========
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
PLM_MODEL_NAME = "facebook/esm2_t30_150M_UR50D"
CLASSIFIER_PATH = "best_model_esm2_t30_150M_UR50D.pth"
LABEL_MAP_PATH = "label_map.json"

# First constructor argument of the classifier below; presumably the embedding
# width of the PLM checkpoint named above -- confirm before changing models.
# Defined outside the branch so it exists in mock mode as well.
D_MODEL = 640

# If you want to test UI without model files, set MOCK_MODE = True
MOCK_MODE = False

if not MOCK_MODE:
    # Fail early, with a clear message, when a required artefact is absent.
    for required_path in (LABEL_MAP_PATH, CLASSIFIER_PATH):
        if not os.path.exists(required_path):
            raise FileNotFoundError(f"Missing {required_path}")

    with open(LABEL_MAP_PATH, 'r', encoding="utf-8") as f:
        label_to_idx = json.load(f)
    # Invert the mapping so a predicted class index can be turned into a name.
    idx_to_label = {v: k for k, v in label_to_idx.items()}
    NUM_CLASSES = len(idx_to_label)

    print("🔹 Loading models...")
    tokenizer = AutoTokenizer.from_pretrained(PLM_MODEL_NAME)
    plm_model = AutoModel.from_pretrained(PLM_MODEL_NAME).to(DEVICE).eval()
    classifier = ProtDualBranchEnhancedClassifier(D_MODEL, 32, NUM_CLASSES, 0.3, 3).to(DEVICE)
    # NOTE(review): torch.load unpickles the checkpoint; only load files from a
    # trusted source (consider weights_only=True where the installed PyTorch
    # version supports it).
    classifier.load_state_dict(torch.load(CLASSIFIER_PATH, map_location=DEVICE))
    classifier.eval()
    print("✅ Models loaded.")
else:
    # Mock objects for UI development
    idx_to_label = {0: "Cytoplasm", 1: "Inner Membrane", 2: "Periplasm", 3: "Outer Membrane", 4: "Cell Wall", 5: "Extracellular"}
    # Keep the same module-level names bound as in the real branch.
    label_to_idx = {label: idx for idx, label in idx_to_label.items()}
    NUM_CLASSES = len(idx_to_label)
    tokenizer = None
    plm_model = None
    classifier = None
109
+
110
+ # ========== SVG generator (UniProt-like) ==========
111
+ def generate_uniprot_style_svg(pred_label: str,
112
+ gram: str = "negative",
113
+ theme: str = "uniprot-blue",
114
+ layout: str = "circular",
115
+ high_res: bool = False,
116
+ uid: str = None) -> str:
117
+ """
118
+ Create a UniProt-like bacterial localization diagram:
119
+ - pred_label: predicted top class name (used to highlight)
120
+ - gram: 'negative' or 'positive'
121
+ - theme: 'uniprot-blue', 'red-highlight', 'auto' (auto uses prefers-color-scheme)
122
+ - layout: 'circular' (capsule) or 'horizontal' (cell left, labels right)
123
+ - high_res: True -> larger viewBox for higher-quality PNG/PDF
124
+ Returns: HTML string containing responsive SVG and download buttons.
125
+ """
126
+
127
+ target = pred_label.lower() if pred_label else ""
128
+ is_active = {
129
+ "sec": ("extracellular" in target) or ("secreted" in target),
130
+ "om": ("outer membrane" in target),
131
+ "peri": ("periplasm" in target),
132
+ "cw": ("cell wall" in target),
133
+ "im": ("inner membrane" in target) or ("plasma membrane" in target),
134
+ "cyto": ("cytoplasm" in target) or ("cytosol" in target)
135
  }
136
 
137
+ # If gram-positive, there is no outer membrane and cell wall is thicker
138
+ have_outer_membrane = (gram == "negative")
139
+
140
+ # color themes using CSS variables (supports prefers-color-scheme)
141
+ css_vars = {
142
+ "uniprot-blue": {
143
+ "--om-fill": "#F5F7FA", "--im-fill": "#FFFFFF", "--stroke": "#607D8B",
144
+ "--muted": "#B0BEC5", "--text": "#263238", "--highlight": "#0288D1"
145
+ },
146
+ "red-highlight": {
147
+ "--om-fill": "#FFEBEE", "--im-fill": "#FFFFFF", "--stroke": "#607D8B",
148
+ "--muted": "#B0BEC5", "--text": "#263238", "--highlight": "#D32F2F"
149
+ }
150
+ }
151
+
152
+ selected = css_vars.get(theme if theme in css_vars else "uniprot-blue")
153
+
154
+ # Unique IDs for DOM elements to allow multiple diagrams on page
155
+ uid = uid or str(uuid.uuid4()).replace("-", "")[:10]
156
+ svg_id = f"loc_svg_{uid}"
157
+
158
+ # Sizes
159
+ if high_res:
160
+ W, H = 1600, 800
 
 
 
 
 
 
161
  else:
162
+ W, H = 800, 420
163
+
164
+ # anchor coords and label positions (tuned for viewBox)
165
+ center_x = int(W * 0.38)
166
+ center_y = int(H * 0.5)
167
+ label_x = int(W * 0.76)
168
+ label_x_left = int(W * 0.58)
169
+
170
+ label_y_map = {
171
+ "sec": int(H * 0.12),
172
+ "om": int(H * 0.22),
173
+ "peri": int(H * 0.32),
174
+ "cw": int(H * 0.42),
175
+ "im": int(H * 0.62),
176
+ "cyto": int(H * 0.78)
177
  }
178
 
179
+ anchors = {
180
+ "sec": (center_x + 160, center_y - 160),
181
+ "om": (center_x + 160, center_y - 100),
182
+ "peri":(center_x + 140, center_y - 40),
183
+ "cw": (center_x + 120, center_y + 10),
184
+ "im": (center_x + 80, center_y + 60),
185
+ "cyto":(center_x + 20, center_y + 70)
186
  }
187
 
188
+ # Build helper for connector group
189
+ def connector_svg(key, text):
190
+ ex, ey = anchors[key]
191
+ tx, ty = label_x, label_y_map[key]
192
+ # styling depends on active
193
+ active = is_active.get(key, False)
194
+ stroke_color = selected["--highlight"] if active else selected["--muted"]
195
+ stroke_w = "2.6" if active else "1.4"
196
+ fontw = "700" if active else "400"
197
+ dot_r = "6" if active else "4"
198
+ path = f"M {tx-20} {ty-6} C {tx-100} {ty-6}, {ex+60} {ey+10}, {ex} {ey}"
199
+ return f"""
200
+ <g class="connector connector-{key}">
201
+ <text x="{tx}" y="{ty}" fill="{selected['--text']}" font-weight="{fontw}" font-size="{14 if not high_res else 22}" font-family="Inter, Arial">{text}</text>
202
+ <path d="{path}" fill="none" stroke="{stroke_color}" stroke-width="{stroke_w}" stroke-linecap="round" stroke-linejoin="round" />
203
+ <circle cx="{ex}" cy="{ey}" r="{dot_r}" fill="{stroke_color}" stroke="white" stroke-width="{1 if not high_res else 1.6}" />
204
+ </g>
205
+ """
206
+
207
+ # Build cell shapes: capsule-like curves - simpler parametric shapes
208
+ # Outer membrane (or single outer layer for gram-positive), cell wall, inner membrane
209
+ # Different shapes when gram-positive: thicker cell wall, no outer membrane ring.
210
+ # We'll draw with bezier-like path strings tuned to look UniProt-ish.
211
+ if have_outer_membrane:
212
+ om_fill = "var(--om-fill)"
213
+ om_stroke = "var(--stroke)"
214
+ cw_fill = "none"
215
+ cw_stroke = "var(--muted)"
216
+ im_fill = "var(--im-fill)"
217
+ im_stroke = "var(--stroke)"
218
+ else:
219
+ # gram-positive: cell wall thicker and outer membrane absent
220
+ om_fill = "none"
221
+ om_stroke = "none"
222
+ cw_fill = "var(--om-fill)"
223
+ cw_stroke = "var(--stroke)"
224
+ im_fill = "var(--im-fill)"
225
+ im_stroke = "var(--stroke)"
226
+
227
+ # layer highlight override if active
228
+ def stroke_override(key, base):
229
+ if is_active.get(key, False):
230
+ return selected["--highlight"]
231
+ return base
232
+
233
+ # inline CSS for animations and hover effects
234
+ svg_style = f"""
235
+ <style>
236
+ /* theme vars */
237
+ :root {{
238
+ --om-fill: {selected['--om-fill']};
239
+ --im-fill: {selected['--im-fill']};
240
+ --stroke: {selected['--stroke']};
241
+ --muted: {selected['--muted']};
242
+ --text: {selected['--text']};
243
+ --highlight: {selected['--highlight']};
244
+ }}
245
+ @media (prefers-color-scheme: dark) {{
246
+ :root {{
247
+ --om-fill: #28343a;
248
+ --im-fill: #1f2b30;
249
+ --stroke: #90a4ae;
250
+ --muted: #546e7a;
251
+ --text: #e0f2f1;
252
+ }}
253
+ }}
254
+
255
+ /* connector hover: slightly thicken the path and enlarge dot */
256
+ .connector path {{ transition: stroke-width 180ms ease, stroke 180ms ease; opacity:0.95; }}
257
+ .connector circle {{ transition: r 160ms ease, transform 160ms ease; transform-origin: center; }}
258
+ .connector text {{ transition: fill 160ms ease; }}
259
+
260
+ /* on hover of group, emphasize */
261
+ .connector:hover path {{ stroke-width: calc(var(--hover-w, 3)); opacity:1; filter: drop-shadow(0 2px 2px rgba(0,0,0,0.06)); }}
262
+ .connector:hover circle {{ transform: scale(1.25); }}
263
+ /* subtle floating animation for lines */
264
+ .connector path {{}}
265
+ @keyframes floatx {{ 0% {{ transform: translateX(0px); }} 50% {{ transform: translateX(1px); }} 100% {{ transform: translateX(0px); }} }}
266
+ .connector path {{ animation: floatx 4s ease-in-out infinite; animation-delay: calc(var(--i, 0) * 0.12s); opacity:0.95; }}
267
+
268
+ /* make the whole svg responsive */
269
+ svg {{ max-width: 100%; height: auto; display:block; }}
270
+
271
+ /* layer highlight when active: add glow */
272
+ .layer-active {{ filter: drop-shadow(0 4px 8px rgba(0,0,0,0.08)); }}
273
+ </style>
274
+ """
275
+
276
+ # Compose SVG core shapes (simplified, but tuned coordinates)
277
+ # We use path shapes with translated center for convenience.
278
+ cell_shapes = ""
279
+ # Outer membrane / envelope
280
+ if have_outer_membrane:
281
+ cell_shapes += f'''
282
+ <g id="outer_membrane" class="layer {'layer-active' if is_active['om'] else ''}">
283
+ <ellipse cx="{center_x}" cy="{center_y}" rx="{220 if not high_res else 440}" ry="{170 if not high_res else 340}"
284
+ fill="var(--om-fill)" stroke="{stroke_override('om', 'var(--stroke)')}" stroke-width="{3 if is_active['om'] else 2}"/>
285
  </g>
286
  '''
287
+ # Cell wall (dashed)
288
+ cell_shapes += f'''
289
+ <g id="cell_wall">
290
+ <ellipse cx="{center_x}" cy="{center_y}" rx="{190 if not high_res else 380}" ry="{150 if not high_res else 300}"
291
+ fill="none" stroke="{stroke_override('cw','var(--muted)')}" stroke-width="{4 if is_active['cw'] else 2}" stroke-dasharray="10 6"/>
292
+ </g>
293
+ '''
294
+ # inner membrane
295
+ cell_shapes += f'''
296
+ <g id="inner_membrane" class="layer {'layer-active' if is_active['im'] else ''}">
297
+ <ellipse cx="{center_x}" cy="{center_y}" rx="{140 if not high_res else 280}" ry="{100 if not high_res else 200}"
298
+ fill="var(--im-fill)" stroke="{stroke_override('im','var(--stroke)')}" stroke-width="{3 if is_active['im'] else 1.8}"/>
299
+ </g>
300
+ '''
301
+ else:
302
+ # Gram positive: thick cell wall as filled ellipse + inner membrane
303
+ cell_shapes += f'''
304
+ <g id="cell_wall_gp" class="layer {'layer-active' if is_active['cw'] else ''}">
305
+ <ellipse cx="{center_x}" cy="{center_y}" rx="{230 if not high_res else 460}" ry="{180 if not high_res else 360}"
306
+ fill="{selected['--om-fill']}" stroke="{stroke_override('cw','var(--stroke)')}" stroke-width="{3 if is_active['cw'] else 2}"/>
307
+ </g>
308
+ <g id="inner_membrane" class="layer {'layer-active' if is_active['im'] else ''}">
309
+ <ellipse cx="{center_x}" cy="{center_y}" rx="{150 if not high_res else 300}" ry="{110 if not high_res else 220}"
310
+ fill="var(--im-fill)" stroke="{stroke_override('im','var(--stroke)')}" stroke-width="{2 if is_active['im'] else 1.4}"/>
311
  </g>
312
+ '''
313
+
314
+ # cytoplasm ornament
315
+ cell_shapes += f'''
316
+ <g id="cytoplasm_wiggles" opacity="0.65">
317
+ <path d="M {center_x-60} {center_y+10} q 30 -50 70 0 q 30 50 70 0" stroke="var(--muted)" stroke-width="6" fill="none" stroke-linecap="round"/>
318
+ <circle cx="{center_x-40}" cy="{center_y+40}" r="{3 if not high_res else 6}" fill="var(--muted)"/>
319
+ <circle cx="{center_x+20}" cy="{center_y+50}" r="{3 if not high_res else 6}" fill="var(--muted)"/>
320
  </g>
321
  '''
322
 
323
+ # connectors
324
+ connectors = ""
325
+ connectors += connector_svg("sec", "Extracellular / Secreted")
326
+ # outer membrane only if present
327
+ if have_outer_membrane:
328
+ connectors += connector_svg("om", "Outer Membrane")
329
+ connectors += connector_svg("peri", "Periplasm")
330
+ connectors += connector_svg("cw", "Cell Wall")
331
+ connectors += connector_svg("im", "Inner Membrane")
332
+ connectors += connector_svg("cyto", "Cytoplasm")
333
+
334
+ # Build download buttons and client-side JS to download SVG and PNG
335
+ # PDF export will call a Gradio server endpoint (provided below)
336
+ html = f"""
337
+ <div style="width:100%; text-align:center;">
338
+ {svg_style}
339
+ <svg id="{svg_id}" viewBox="0 0 {W} {H}" xmlns="http://www.w3.org/2000/svg" role="img" aria-label="Bacterial localization diagram">
340
  <defs>
341
  <style><![CDATA[
342
  text {{ font-family: Inter, Arial, sans-serif; }}
343
  ]]></style>
344
  </defs>
345
+ {cell_shapes}
346
  {connectors}
347
+ </svg>
348
+
349
+ <div style="margin-top:8px; display:flex; gap:8px; justify-content:center; align-items:center;">
350
+ <button id="download_svg_{uid}" class="download-btn">Download SVG</button>
351
+ <button id="download_png_{uid}" class="download-btn">Download PNG</button>
352
+ <button id="download_pdf_{uid}" class="download-btn">Download PDF</button>
353
+ <div style="font-size:12px; color:var(--text); align-self:center;">{gram.title()} · {layout.title()} {'· High-res' if high_res else ''}</div>
 
 
 
 
 
 
 
 
 
 
354
  </div>
355
+ </div>
356
+
357
  <script>
358
  (function(){{
359
+ const svgEl = document.getElementById("{svg_id}");
360
+ const btnSvg = document.getElementById("download_svg_{uid}");
361
+ const btnPng = document.getElementById("download_png_{uid}");
362
+ const btnPdf = document.getElementById("download_pdf_{uid}");
363
+
364
+ function downloadFile(filename, blob) {{
365
+ const url = URL.createObjectURL(blob);
366
+ const a = document.createElement('a');
367
+ a.href = url; a.download = filename; document.body.appendChild(a); a.click();
368
+ setTimeout(()=>{{ URL.revokeObjectURL(url); a.remove(); }}, 200);
369
+ }}
370
+
371
+ btnSvg.addEventListener('click', ()=>{{
372
+ const serializer = new XMLSerializer();
373
+ let source = serializer.serializeToString(svgEl);
374
+ if(!source.match(/^<svg[^>]+xmlns="http:\\/\\/www.w3.org\\/2000\\/svg"/)) {{
375
+ source = source.replace(/^<svg/, '<svg xmlns="http://www.w3.org/2000/svg"');
376
  }}
377
+ const blob = new Blob([source], {{type: 'image/svg+xml;charset=utf-8'}});
378
+ downloadFile('locpred_diagram.svg', blob);
379
+ }});
380
+
381
+ btnPng.addEventListener('click', ()=>{{
382
+ const serializer = new XMLSerializer();
383
+ let source = serializer.serializeToString(svgEl);
384
+ if(!source.match(/^<svg[^>]+xmlns="http:\\/\\/www.w3.org\\/2000\\/svg"/)) {{
385
+ source = source.replace(/^<svg/, '<svg xmlns="http://www.w3.org/2000/svg"');
386
+ }}
387
+ const svgBlob = new Blob([source], {{type: 'image/svg+xml;charset=utf-8'}});
388
+ const url = URL.createObjectURL(svgBlob);
389
+ const img = new Image();
390
+ img.onload = function() {{
391
+ const canvas = document.createElement('canvas');
392
+ // scale 2x for higher quality
393
+ const scale = 2;
394
+ canvas.width = img.width * scale;
395
+ canvas.height = img.height * scale;
396
+ const ctx = canvas.getContext('2d');
397
+ // optional white background
398
+ ctx.fillStyle = "white";
399
+ ctx.fillRect(0,0,canvas.width,canvas.height);
400
+ ctx.drawImage(img, 0, 0, canvas.width, canvas.height);
401
+ canvas.toBlob(function(blob) {{
402
+ downloadFile('locpred_diagram.png', blob);
403
+ }}, 'image/png');
404
+ URL.revokeObjectURL(url);
405
+ }};
406
+ img.onerror = function(e) {{
407
+ alert('Failed to render PNG in your browser.');
408
+ URL.revokeObjectURL(url);
409
+ }};
410
+ img.src = url;
411
+ }});
412
+
413
+ btnPdf.addEventListener('click', async ()=>{{
414
+ // send the SVG string to the server /gradio route for PDF conversion
415
+ const serializer = new XMLSerializer();
416
+ let source = serializer.serializeToString(svgEl);
417
+ if(!source.match(/^<svg[^>]+xmlns="http:\\/\\/www.w3.org\\/2000\\/svg"/)) {{
418
+ source = source.replace(/^<svg/, '<svg xmlns="http://www.w3.org/2000/svg"');
419
+ }}
420
+ // call Gradio server function via fetch to /convert_svg_to_pdf (provided below)
421
+ try {{
422
+ const resp = await fetch('/convert_svg_to_pdf', {{
423
+ method: 'POST',
424
+ headers: {{ 'Content-Type': 'application/json' }},
425
+ body: JSON.stringify({{ svg: source }})
426
+ }});
427
+ if(!resp.ok) {{
428
+ const txt = await resp.text();
429
+ alert('PDF conversion failed: ' + txt);
430
+ return;
431
+ }}
432
+ const blob = await resp.blob();
433
+ downloadFile('locpred_diagram.pdf', blob);
434
+ }} catch (err) {{
435
+ alert('PDF conversion failed: ' + err);
436
+ }}
437
+ }});
438
  }})();
439
  </script>
440
+ """
441
+
442
  return html
443
 
444
+ # ========== Heatmap (same as before) ==========
445
def draw_attention_heatmap_strip(weights, sequence):
    """Render per-residue attention weights as a single-row heatmap strip.

    Args:
        weights: 1-D array-like of attention weights, one value per position.
        sequence: The residue string the weights belong to. It is not read by
            the code visible here.

    Returns:
        The matplotlib ``Figure`` containing the heatmap and its colorbar.
    """
    weights = np.asarray(weights, dtype=float)
    lowest, highest = weights.min(), weights.max()
    span = highest - lowest
    # Min-max scale to [0, 1] only when that is well defined. A constant
    # vector (e.g. a one-residue sequence) has span == 0; dividing by it
    # would fill the strip with NaN, so such input is plotted unscaled.
    if highest > 0 and span > 0:
        weights = (weights - lowest) / span

    data = weights.reshape(1, -1)  # imshow needs 2-D input: one row, N columns
    fig, ax = plt.subplots(figsize=(8, 1.5), dpi=150)
    im = ax.imshow(data, cmap='Reds', aspect='auto', vmin=0, vmax=1)
    ax.set_title('Sequence Attention Heatmap (High Color = Key Feature)', fontsize=10, fontweight='bold', color='#37474F', pad=6)
    ax.set_xlabel('Residue Position', fontsize=9)
    ax.set_yticks([])
    cbar = plt.colorbar(im, ax=ax, orientation='vertical', fraction=0.02, pad=0.02)
    # NOTE(review): the diff this was reviewed from collapses four unchanged
    # lines at this point (between the colorbar call and tight_layout());
    # keep them here when applying this change.
    plt.tight_layout()
    return fig
461
 
462
+ # ========== Prediction function ==========
463
def predict(sequence_input: str, gram_choice: str, theme_choice: str, layout_choice: str, high_res_flag: bool):
    """Predict the localization of one protein sequence and build the outputs.

    Args:
        sequence_input: Raw residues or FASTA text. For FASTA input only the
            first record is used.
        gram_choice: Passed to ``generate_uniprot_style_svg`` as ``gram``.
        theme_choice: Passed to ``generate_uniprot_style_svg`` as ``theme``.
        layout_choice: Passed to ``generate_uniprot_style_svg`` as ``layout``.
        high_res_flag: Passed to ``generate_uniprot_style_svg`` as ``high_res``.

    Returns:
        A tuple ``(confidences, svg_html, heatmap_fig)``:
        - confidences dict mapping label -> probability (for Label)
        - svg html string (for HTML output)
        - attention heatmap figure (for Plot)

    Raises:
        gr.Error: If the input is empty or contains no letters A-Z.
    """
    text = sequence_input.strip() if sequence_input else ""
    if not text:
        raise gr.Error("Empty Input")

    # Strip before testing for a FASTA header: with leading whitespace or
    # blank lines the header used to be missed and its letters were read as
    # residues.
    if text.startswith('>'):
        record_lines = []
        for line in text.splitlines()[1:]:
            # A second header starts another record; stop rather than mix
            # its header text and residues into this sequence.
            if line.lstrip().startswith('>'):
                break
            record_lines.append(line)
        text = "".join(record_lines)

    # Keep letters only and cap the length at 1024 characters.
    seq = re.sub(r'[^A-Z]', '', text.upper())[:1024]
    if not seq:
        raise gr.Error("Invalid Sequence")

    if MOCK_MODE:
        # mock probabilities and mock per-residue weights (random values)
        probs = torch.softmax(torch.randn(NUM_CLASSES), dim=0)
        pooling_weights = np.abs(np.random.randn(len(seq))).astype(float)
        pooling_weights = pooling_weights / pooling_weights.sum()
        top_label = idx_to_label[int(torch.argmax(probs).item())]
    else:
        with torch.no_grad():
            inputs = tokenizer(seq, return_tensors="pt", truncation=True, max_length=1024).to(DEVICE)
            outputs = plm_model(**inputs)
            hidden_states = outputs.last_hidden_state
            cls_embedding = hidden_states[:, 0, :]
            # Drop the first and last positions (presumably the tokenizer's
            # special tokens; confirm against the tokenizer in use).
            token_embeddings = hidden_states[:, 1:-1, :]
            token_mask = inputs['attention_mask'][:, 1:-1]
            logits, pooling_weights = classifier(cls_embedding, token_embeddings, token_mask)
            probs = F.softmax(logits, dim=1)[0]
            top_label = idx_to_label[torch.argmax(probs).item()]
            pooling_weights = pooling_weights[0].cpu().numpy()

    confidences = {idx_to_label[i]: float(p) for i, p in enumerate(probs)}
    svg_html = generate_uniprot_style_svg(top_label, gram=gram_choice, theme=theme_choice, layout=layout_choice, high_res=high_res_flag)
    heatmap_fig = draw_attention_heatmap_strip(np.array(pooling_weights), seq)

    return confidences, svg_html, heatmap_fig
503
+
504
+ # ========== Server-side PDF conversion endpoint for Gradio ==========
505
# NOTE: this helper is not registered as a route itself; the POST /convert_svg_to_pdf
# handler is defined in the __main__ block at the bottom of the file.
506
def convert_svg_to_pdf_endpoint(svg_str: str):
    """Render an SVG document to PDF with cairosvg.

    Args:
        svg_str: The SVG markup as text; it is UTF-8 encoded before rendering.

    Returns:
        A ``(filename, pdf_bytes)`` tuple, where ``filename`` is always
        ``"locpred_diagram.pdf"`` and ``pdf_bytes`` is the value returned by
        ``cairosvg.svg2pdf``.

    Raises:
        RuntimeError: If ``CAIROSVG_AVAILABLE`` is false.
    """
    if CAIROSVG_AVAILABLE:
        svg_bytes = svg_str.encode('utf-8')
        rendered_pdf = cairosvg.svg2pdf(bytestring=svg_bytes)
        return ("locpred_diagram.pdf", rendered_pdf)

    raise RuntimeError("Server-side PDF conversion requires 'cairosvg' package. Install with: pip install cairosvg")
517
+
518
+ # ========== UI layout (Gradio) ==========
519
# Page-level CSS passed to gr.Blocks(css=...).
# - .download-btn is the class on the export buttons in the generated diagram HTML.
# - .panel-header is the class on the lettered section headings in the UI.
layout_css = """
/* small customizations and CSS variables fallback */
:root { --panel-bg: white; }
.gradio-container { max-width: 1200px; margin: 0 auto; }
.download-btn { padding:8px 12px; border-radius:6px; border:1px solid #D1D5DB; background:transparent; cursor:pointer; }
.download-btn:hover { box-shadow: 0 4px 14px rgba(16,24,40,0.08); }
.panel-card { border:1px solid #e6eef5; border-radius:8px; padding:12px; background:var(--panel-bg); }
.panel-header { font-weight:700; color:#475569; border-bottom:2px solid #f1f5f9; padding-bottom:8px; margin-bottom:10px; }
"""
528
 
 
529
theme = gr.themes.Soft(primary_hue="sky").set(
    body_background_fill="white",
    block_background_fill="white",
    block_border_width="0px",
)


def _panel_header(badge, title):
    """Return the HTML heading for one UI panel: a lettered badge plus a title."""
    return (
        "<div class='panel-header'>"
        "<span style='background:#E0F7FA;color:#0277BD;padding:3px 6px;border-radius:4px;font-weight:800;margin-right:8px;'>"
        f"{badge}</span>{title}</div>"
    )


with gr.Blocks(theme=theme, css=layout_css, title="LocPred-Prok (UniProt-style)") as app:
    gr.Markdown("<div style='font-size:22px; font-weight:800; color:#0288D1;'>LocPred-Prok — UniProt-style visualization</div>")

    # Top row: sequence entry and rendering options (A), diagram (B).
    with gr.Row():
        with gr.Column(scale=6):
            gr.Markdown(_panel_header("A", "Sequence Input"))
            sequence_input = gr.Textbox(
                lines=8,
                placeholder=">Sequence (single-letter amino acids) or paste raw sequence",
                show_label=False,
            )
            with gr.Row():
                clear_btn = gr.ClearButton(sequence_input, value="Clear")
                submit_btn = gr.Button("Predict Analysis", variant="primary")
            with gr.Row():
                gram_choice = gr.Radio(choices=["negative", "positive"], value="negative", label="Gram type")
                theme_choice = gr.Radio(choices=["uniprot-blue", "red-highlight"], value="uniprot-blue", label="Color Theme")
                layout_choice = gr.Radio(choices=["circular", "horizontal"], value="circular", label="Diagram Layout")
                high_res_flag = gr.Checkbox(value=False, label="High resolution (bigger SVG/PDF)")
            gr.Examples([[">Outer Membrane\nAPKNTWYTGAKLGWSQYHDTGFINNNGPTHENQLGAGAF..."]], inputs=sequence_input)

        with gr.Column(scale=6):
            gr.Markdown(_panel_header("B", "Localization Visualization"))
            output_svg = gr.HTML(label="Visual", show_label=False)

    # Bottom row: class probabilities (C) and attention heatmap (D).
    with gr.Row():
        with gr.Column(scale=6):
            gr.Markdown(_panel_header("C", "Prediction Confidence"))
            output_label = gr.Label(num_top_classes=NUM_CLASSES, show_label=False)
        with gr.Column(scale=6):
            gr.Markdown(_panel_header("D", "Attention Heatmap"))
            output_plot = gr.Plot(show_label=False)

    result_components = [output_label, output_svg, output_plot]
    submit_btn.click(
        fn=predict,
        inputs=[sequence_input, gram_choice, theme_choice, layout_choice, high_res_flag],
        outputs=result_components,
    )
    # Clearing the textbox also blanks the three result panels.
    clear_btn.click(lambda: [None] * 3, outputs=result_components)
561
 
562
# The PDF conversion route is not part of the Blocks app itself: the __main__ block
# below creates a FastAPI app, registers POST /convert_svg_to_pdf on it, and mounts
# this Gradio app at "/".
564
+
565
+ # ========== Run server with custom route for PDF conversion ==========
566
if __name__ == "__main__":
    from fastapi import FastAPI, Request, Response
    from fastapi.concurrency import run_in_threadpool
    import uvicorn

    demo = app

    # A plain FastAPI app hosts the PDF route; the Gradio UI is mounted onto
    # it so both are served by the same uvicorn process.
    fast_app = FastAPI()

    @fast_app.post("/convert_svg_to_pdf")
    async def convert_svg_to_pdf_api(request: Request):
        """Convert the SVG in the JSON body (``{"svg": "<svg ...>"}``) to PDF.

        Responds 400 for a malformed or empty request, 501 when cairosvg is
        not installed, 500 when rendering fails, and otherwise returns the
        PDF bytes as ``application/pdf``.
        """
        try:
            payload = await request.json()
        except ValueError:
            # json.JSONDecodeError and UnicodeDecodeError are both ValueErrors;
            # a bad body is a client error, not a server crash.
            return Response(content="Request body must be valid JSON", status_code=400)

        svg = payload.get("svg") if isinstance(payload, dict) else None
        if not svg or not isinstance(svg, str):
            return Response(content="No svg provided", status_code=400)
        if not CAIROSVG_AVAILABLE:
            return Response(content="Server-side PDF conversion unavailable: install 'cairosvg' in the server environment.", status_code=501)

        # NOTE(review): the SVG is client-supplied; confirm that the installed
        # cairosvg version blocks external resource access by default.
        try:
            # Rendering is CPU-bound; run it off the event loop so other
            # requests (including the Gradio UI) are not stalled.
            pdf_bytes = await run_in_threadpool(cairosvg.svg2pdf, bytestring=svg.encode('utf-8'))
        except Exception as e:
            # Route boundary: report any rendering failure to the client.
            return Response(content=f"PDF conversion error: {e}", status_code=500)
        return Response(content=pdf_bytes, media_type="application/pdf")

    # Mount the Gradio interface at root
    gr.mount_gradio_app(fast_app, demo, path="/")

    # Launch uvicorn with the FastAPI app
    uvicorn.run(fast_app, host="0.0.0.0", port=7860, log_level="info")