Ninjani commited on
Commit
2da4f4c
·
1 Parent(s): 40c07c2
Files changed (1) hide show
  1. app.py +26 -47
app.py CHANGED
@@ -3,16 +3,11 @@ import time
3
 
4
  import torch
5
  import gradio as gr
6
- import numpy as np
7
  from loguru import logger
8
 
9
  from stoic.model import Stoic
10
 
11
- CHAIN_COLORS = [
12
- "#636EFA", "#EF553B", "#00CC96", "#AB63FA", "#FFA15A",
13
- "#19D3F3", "#FF6692", "#B6E880", "#FF97FF", "#FECB52",
14
- ]
15
-
16
 
17
  @functools.lru_cache(maxsize=1)
18
  def get_model():
@@ -46,13 +41,15 @@ def predict(sequences_text: str, top_n: int, return_weights: bool):
46
 
47
  chain_labels = [chr(ord("A") + i) for i in range(len(sequences))]
48
 
49
- header = "| Rank | " + " | ".join(f"Chain {l}" for l in chain_labels) + " | Stoichiometry |"
50
- separator = "|------|" + "|".join("-----" for _ in chain_labels) + "|---------------|"
51
  rows = []
52
  for rank, candidate in enumerate(results, 1):
53
  copies = [candidate.get(seq, 0) for seq in sequences]
54
  stoich = "".join(f"{l}<sub>{c}</sub>" for l, c in zip(chain_labels, copies))
55
- row = f"| {rank} | " + " | ".join(str(c) for c in copies) + f" | {stoich} |"
 
 
56
  rows.append(row)
57
 
58
  table = "\n".join([header, separator] + rows)
@@ -65,54 +62,28 @@ def predict(sequences_text: str, top_n: int, return_weights: bool):
65
  stoich_md = table + "\n".join(legend_lines)
66
 
67
  if return_weights:
68
- fig = build_weights_plot(residue_predictions, chain_labels)
69
- return stoich_md, gr.update(value=fig, visible=True), f"{elapsed:.2f}s"
70
 
71
  return stoich_md, gr.update(value=None, visible=False), f"{elapsed:.2f}s"
72
 
73
 
74
- def build_weights_plot(residue_predictions, chain_labels):
75
- import plotly.graph_objects as go
76
- from plotly.subplots import make_subplots
77
-
78
  pred_residues = residue_predictions["pred_residues"]
79
  attention_mask = residue_predictions["attention_mask"]
80
  seqs = residue_predictions["sequences"]
81
- n_chains = len(seqs)
82
-
83
- fig = make_subplots(
84
- rows=n_chains, cols=1,
85
- subplot_titles=[f"Chain {chain_labels[i]}" for i in range(n_chains)],
86
- vertical_spacing=0.12 / max(n_chains, 1),
87
- )
88
 
 
89
  for i, seq in enumerate(seqs):
90
  mask = attention_mask[i].astype(bool)
91
  weights = pred_residues[i][mask]
92
- residues = list(seq[:len(weights)])
93
- positions = list(range(1, len(weights) + 1))
94
- color = CHAIN_COLORS[i % len(CHAIN_COLORS)]
95
-
96
- fig.add_trace(
97
- go.Bar(
98
- x=positions,
99
- y=weights,
100
- marker_color=color,
101
- name=f"Chain {chain_labels[i]}",
102
- hovertemplate="Pos %{x}: %{customdata}<br>Weight: %{y:.4f}<extra></extra>",
103
- customdata=residues,
104
- ),
105
- row=i + 1, col=1,
106
- )
107
- fig.update_xaxes(title_text="Residue position", row=i + 1, col=1)
108
- fig.update_yaxes(title_text="Weight", row=i + 1, col=1)
109
-
110
- fig.update_layout(
111
- title="Residue-level Interface Prediction Weights",
112
- height=max(300 * n_chains, 400),
113
- showlegend=False,
114
- )
115
- return fig
116
 
117
 
118
  with gr.Blocks(title="Stoic - Protein Stoichiometry Prediction") as app:
@@ -144,7 +115,15 @@ with gr.Blocks(title="Stoic - Protein Stoichiometry Prediction") as app:
144
  results_output = gr.Markdown(value="Results will appear here.")
145
  run_time = gr.Textbox(label="Runtime")
146
 
147
- weights_plot = gr.Plot(label="Residue-level Interface Weights", visible=False)
 
 
 
 
 
 
 
 
148
 
149
  btn.click(
150
  predict,
 
3
 
4
  import torch
5
  import gradio as gr
6
+ import pandas as pd
7
  from loguru import logger
8
 
9
  from stoic.model import Stoic
10
 
 
 
 
 
 
11
 
12
  @functools.lru_cache(maxsize=1)
13
  def get_model():
 
41
 
42
  chain_labels = [chr(ord("A") + i) for i in range(len(sequences))]
43
 
44
+ header = "| Rank | " + " | ".join(f"Chain {l}" for l in chain_labels) + " | Stoichiometry | Score | Probability |"
45
+ separator = "|------|" + "|".join("-----" for _ in chain_labels) + "|---------------|-------|-------------|"
46
  rows = []
47
  for rank, candidate in enumerate(results, 1):
48
  copies = [candidate.get(seq, 0) for seq in sequences]
49
  stoich = "".join(f"{l}<sub>{c}</sub>" for l, c in zip(chain_labels, copies))
50
+ score = candidate.get("score", 0)
51
+ prob = candidate.get("probability", 0)
52
+ row = f"| {rank} | " + " | ".join(str(c) for c in copies) + f" | {stoich} | {score:.2f} | {prob:.2e} |"
53
  rows.append(row)
54
 
55
  table = "\n".join([header, separator] + rows)
 
62
  stoich_md = table + "\n".join(legend_lines)
63
 
64
  if return_weights:
65
+ df = build_weights_df(residue_predictions, chain_labels)
66
+ return stoich_md, gr.update(value=df, visible=True), f"{elapsed:.2f}s"
67
 
68
  return stoich_md, gr.update(value=None, visible=False), f"{elapsed:.2f}s"
69
 
70
 
71
+ def build_weights_df(residue_predictions, chain_labels):
 
 
 
72
  pred_residues = residue_predictions["pred_residues"]
73
  attention_mask = residue_predictions["attention_mask"]
74
  seqs = residue_predictions["sequences"]
 
 
 
 
 
 
 
75
 
76
+ records = []
77
  for i, seq in enumerate(seqs):
78
  mask = attention_mask[i].astype(bool)
79
  weights = pred_residues[i][mask]
80
+ for pos, w in enumerate(weights, 1):
81
+ records.append({
82
+ "Position": pos,
83
+ "Weight": float(w),
84
+ "Chain": f"Chain {chain_labels[i]}",
85
+ })
86
+ return pd.DataFrame(records)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
 
88
 
89
  with gr.Blocks(title="Stoic - Protein Stoichiometry Prediction") as app:
 
115
  results_output = gr.Markdown(value="Results will appear here.")
116
  run_time = gr.Textbox(label="Runtime")
117
 
118
+ weights_plot = gr.LinePlot(
119
+ x="Position",
120
+ y="Weight",
121
+ color="Chain",
122
+ x_title="Residue Position",
123
+ y_title="Weight",
124
+ label="Residue-level Interface Weights",
125
+ visible=False,
126
+ )
127
 
128
  btn.click(
129
  predict,