1-1-3-8 committed on
Commit
9e399de
·
verified ·
1 Parent(s): 0526b13

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +119 -267
app.py CHANGED
@@ -1,10 +1,7 @@
1
  import os
2
- import re
3
- import math
4
  import torch
5
  import gradio as gr
6
  from functools import lru_cache
7
- from collections import deque
8
  from transformers import (
9
  AutoTokenizer,
10
  AutoModelForCausalLM,
@@ -12,22 +9,8 @@ from transformers import (
12
  LogitsProcessorList,
13
  )
14
 
15
- # --------------------------- Config ---------------------------------
16
  MODEL_ID = os.getenv("MODEL_ID", "llm-rna-api-rmit/rna-structure-model")
17
 
18
- # sampling / decoding knobs (tune as you like)
19
- TEMPERATURE = float(os.getenv("TEMP", "0.8"))
20
- TOP_P = float(os.getenv("TOP_P", "0.9"))
21
- DOT_BIAS = float(os.getenv("DOT_BIAS", "0.8")) # +logit added to '.'
22
- PAREN_RUN_PENALTY = float(os.getenv("PAREN_PEN", "0.5")) # -logit if long run
23
- PAREN_RUN_WINDOW = int(os.getenv("PAREN_WIN", "5")) # lookback tokens
24
-
25
- # RNA constraints
26
- ALLOW_GU = os.getenv("ALLOW_GU", "1") != "0"
27
- MIN_LOOP_LEN = int(os.getenv("MIN_LOOP", "3")) # minimum unpaired bases in loop
28
-
29
- # --------------------------------------------------------------------
30
-
31
  @lru_cache(maxsize=1)
32
  def _load_model_and_tokenizer():
33
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -40,25 +23,7 @@ def _load_model_and_tokenizer():
40
  model.eval()
41
  return tokenizer, model, device
42
 
43
- FEWSHOT = """Return ONLY the dot-bracket structure, one line, same length as RNA.
44
- RNA: GGGAAAUCCCU
45
- Dot-bracket: (((...)))).
46
- RNA: AUAUAUAU
47
- Dot-bracket: ........
48
- RNA: GGGAAACCC
49
- Dot-bracket: (((...)))
50
- RNA: GAAACUU
51
- Dot-bracket: (..())
52
- """
53
-
54
- def _make_prompt(seq: str) -> str:
55
- n = len(seq)
56
- return (
57
- FEWSHOT
58
- + f"\nRNA: {seq}\n"
59
- + f"Dot-bracket (exactly {n} characters using only '(' ')' '.'):\n"
60
- )
61
-
62
  def _char_token_id(tokenizer, ch: str) -> int:
63
  ids = tokenizer.encode(ch, add_special_tokens=False)
64
  for tid in ids:
@@ -67,261 +32,148 @@ def _char_token_id(tokenizer, ch: str) -> int:
67
  for tid in range(len(tokenizer)):
68
  if tokenizer.decode([tid]) == ch:
69
  return tid
70
- raise ValueError(f"Could not find token id for char {ch!r}")
71
 
72
- # ---------- RNA pairing helpers ----------
73
- def _can_pair(a: str, b: str, allow_gu=True) -> bool:
74
- if a == "A" and b == "U": return True
75
- if a == "U" and b == "A": return True
76
- if a == "G" and b == "C": return True
77
- if a == "C" and b == "G": return True
78
- if allow_gu and ((a == "G" and b == "U") or (a == "U" and b == "G")):
79
  return True
80
  return False
81
 
82
- def _precompute_can_open(seq: str, min_loop_len: int, allow_gu: bool):
83
- """
84
- can_open[i] = there exists j >= i + min_loop_len + 1 with pair(seq[i], seq[j])
85
- """
86
- n = len(seq)
87
- can_open = [False] * n
88
- # For speed, pre-index future positions by base
89
- pos_by_base = {"A": [], "U": [], "G": [], "C": []}
90
- for idx in range(n-1, -1, -1):
91
- base = seq[idx]
92
- # update future lists (right side of idx)
93
- if idx + 1 < n:
94
- pos_by_base[seq[idx+1]].append(idx+1)
95
- # Check if any future partner exists with min loop
96
- min_j = idx + min_loop_len + 1
97
- ok = False
98
- for b, lst in pos_by_base.items():
99
- if any(j >= min_j and _can_pair(base, b, allow_gu) for j in lst):
100
- ok = True
101
  break
102
- can_open[idx] = ok
103
- return can_open
104
-
105
- # --------- Top-p sampling ----------
106
- def _top_p_sample_from_logits(logits: torch.Tensor, top_p: float, temperature: float) -> torch.Tensor:
107
- if temperature <= 0:
108
- # fall back to greedy
109
- return torch.argmax(logits, dim=-1)
110
- logits = logits / temperature
111
- probs = torch.softmax(logits, dim=-1)
112
-
113
- # sort
114
- sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
115
- cumulative = torch.cumsum(sorted_probs, dim=-1)
116
-
117
- # mask tokens beyond top_p
118
- cutoff = (cumulative > top_p).float()
119
- # ensure at least one token kept
120
- cutoff[..., 0] = 0.0
121
- keep = 1.0 - cutoff
122
-
123
- filtered_probs = sorted_probs * keep
124
- # renormalize
125
- filtered_probs = filtered_probs / filtered_probs.sum(dim=-1, keepdim=True)
126
 
127
- # sample
128
- next_sorted = torch.multinomial(filtered_probs, num_samples=1)
129
- next_ids = sorted_idx.gather(-1, next_sorted)
130
- return next_ids.squeeze(-1)
131
-
132
- # --- Finite-State logits processor with RNA constraints ---
133
  class BalancedParenProcessor(LogitsProcessor):
134
- """
135
- Keeps output a valid dot-bracket string length N and adds biases:
136
- - Balance constraints (no ')' at depth 0; enough room to close)
137
- - Force close-all when remaining == depth
138
- - Forbid '(' if no feasible partner ahead (RNA pairing + min loop)
139
- - Dot bias (+bias to dot logit)
140
- - Penalize long runs of parentheses
141
- """
142
- def __init__(self, lp_id: int, rp_id: int, dot_id: int, total_len: int,
143
- can_open, paren_run_window=5, paren_run_penalty=0.5, dot_bias=0.8):
144
- self.lp_id = int(lp_id)
145
- self.rp_id = int(rp_id)
146
- self.dot_id = int(dot_id)
147
- self.total_len = int(total_len)
148
-
149
  self.step = 0
150
  self.depth = 0
151
- self.recent = deque(maxlen=paren_run_window)
152
- self.paren_run_penalty = float(paren_run_penalty)
153
- self.dot_bias = float(dot_bias)
154
- self.can_open = can_open # list[bool] length N
155
-
156
- def update_with_chosen(self, token_id: int):
157
- # called after each step by the generator
158
- self.recent.append(token_id)
159
- if token_id == self.lp_id:
160
- self.depth += 1
161
- elif token_id == self.rp_id:
162
- self.depth = max(0, self.depth - 1)
163
- self.step += 1
164
-
165
- def _recent_is_paren_run(self):
166
- if not self.recent:
167
- return False
168
- return all(t in (self.lp_id, self.rp_id) for t in self.recent)
169
 
170
  def __call__(self, input_ids, scores):
171
- # scores: (1, vocab)
172
- mask = torch.full_like(scores, float("-inf"))
173
-
174
- remaining = self.total_len - self.step
175
- allowed = []
176
-
177
- must_close_all = (remaining == self.depth and self.depth > 0)
178
- pos = self.step
179
-
180
- if must_close_all:
181
- allowed = [self.rp_id]
182
  else:
183
- # Allow ')' only if inside a stem
184
- if self.depth > 0:
185
  allowed.append(self.rp_id)
186
-
187
- # Allow '(' only if (a) room to close by end and (b) feasible partner ahead
188
- if remaining - 1 > self.depth and pos < self.total_len and self.can_open[pos]:
189
  allowed.append(self.lp_id)
190
-
191
- # '.' generally allowed
192
  allowed.append(self.dot_id)
193
-
194
- mask[:, allowed] = 0.0
195
- scores = scores + mask
196
-
197
- # add dot bias
198
- scores[..., self.dot_id] = scores[..., self.dot_id] + self.dot_bias
199
-
200
- # penalize long paren run
201
- if self._recent_is_paren_run():
202
- scores[..., self.lp_id] = scores[..., self.lp_id] - self.paren_run_penalty
203
- scores[..., self.rp_id] = scores[..., self.rp_id] - self.paren_run_penalty
204
-
205
  return scores
206
 
207
- # --- Generate exactly n chars using constrained decoding + sampling ---
208
- def _generate_db(seq: str) -> str:
209
- tokenizer, model, device = _load_model_and_tokenizer()
210
- n = len(seq)
211
- prompt = _make_prompt(seq)
212
-
213
- lp_id = _char_token_id(tokenizer, "(")
214
- rp_id = _char_token_id(tokenizer, ")")
215
- dot_id = _char_token_id(tokenizer, ".")
216
-
217
- can_open = _precompute_can_open(seq, MIN_LOOP_LEN, ALLOW_GU)
218
-
219
- processor = BalancedParenProcessor(
220
- lp_id, rp_id, dot_id, n,
221
- can_open=can_open,
222
- paren_run_window=PAREN_RUN_WINDOW,
223
- paren_run_penalty=PAREN_RUN_PENALTY,
224
- dot_bias=DOT_BIAS,
225
- )
226
- processors = LogitsProcessorList([processor])
227
-
228
- with torch.inference_mode():
229
- inputs = tokenizer(prompt, return_tensors="pt").to(device)
230
-
231
- generated = []
232
- cur_input = inputs["input_ids"]
233
- for _ in range(n):
234
- outputs = model(cur_input)
235
- logits = outputs.logits[:, -1, :] # (1, vocab)
236
- for p in processors:
237
- logits = p(cur_input, logits)
238
-
239
- next_id = _top_p_sample_from_logits(logits, TOP_P, TEMPERATURE).unsqueeze(0)
240
- token_id = next_id.item()
241
- generated.append(token_id)
242
-
243
- processor.update_with_chosen(token_id)
244
-
245
- cur_input = torch.cat([cur_input, next_id.unsqueeze(0)], dim=1)
246
-
247
- text = tokenizer.decode(generated, skip_special_tokens=True)
248
- db = "".join(c for c in text if c in "().")[:n]
249
- if len(db) != n:
250
- db = (db + "." * n)[:n]
 
 
 
251
  return db
252
 
253
- # --- Dot-bracket -> structural elements ---
254
- def dotbracket_to_structural(dot_str: str) -> str:
255
- if not dot_str:
256
- return "<start><external_loop><end>"
257
- res = ["<start>"]
258
- depth = 0
259
- i, n = 0, len(dot_str)
260
-
261
- def append_once(tag):
262
- if res[-1] != tag:
263
- res.append(tag)
264
-
265
- while i < n:
266
- c = dot_str[i]
267
- if c == '.':
268
- j = i
269
- while j < n and dot_str[j] == '.':
270
- j += 1
271
- next_char = dot_str[j] if j < n else None
272
- label = "<external_loop>" if depth == 0 else ("<hairpin>" if next_char == ')' else "<internal_loop>")
273
- append_once(label)
274
- i = j
275
- continue
276
- if c == '(':
277
- append_once("<stem>")
278
- depth += 1
279
- else: # ')'
280
- append_once("<stem>")
281
- depth = max(depth - 1, 0)
282
- i += 1
283
-
284
  res.append("<end>")
285
  return "".join(res)
286
 
287
- # --- Gradio handler ---
288
- def predict(seq: str):
289
- seq = (seq or "").strip().upper()
290
- if not seq or not set(seq) <= {"A","U","C","G"}:
291
  return "Please enter an RNA sequence (A/U/C/G)."
292
-
293
- db = _generate_db(seq)
294
  return dotbracket_to_structural(db)
295
 
296
- # UI with a few knobs exposed
297
- with gr.Blocks(title="RNA Structure Predictor") as demo:
298
- gr.Markdown("### RNA Structure Predictor\nOutputs structural-element notation: `<start>`, `<stem>`, `<hairpin>`, `<internal_loop>`, `<external_loop>`, `<end>`.")
299
- with gr.Row():
300
- seq_in = gr.Textbox(lines=4, label="RNA Sequence (A/U/C/G)", value="GGGAAUCC")
301
- out = gr.Textbox(lines=6, label="Predicted Structural Elements")
302
- with gr.Row():
303
- t = gr.Slider(0.1, 1.5, value=TEMPERATURE, step=0.05, label="Temperature")
304
- p = gr.Slider(0.5, 1.0, value=TOP_P, step=0.01, label="Top-p")
305
- dbias = gr.Slider(0.0, 2.0, value=DOT_BIAS, step=0.05, label="Dot bias (+logit)")
306
- looplen = gr.Slider(0, 5, value=MIN_LOOP_LEN, step=1, label="Min loop length")
307
- wobble = gr.Checkbox(value=ALLOW_GU, label="Allow GU wobble")
308
- btn = gr.Button("Submit", variant="primary")
309
-
310
- def _predict_with_knobs(seq, temperature, topp, dot_bias, min_loop, allow_gu):
311
- global TEMPERATURE, TOP_P, DOT_BIAS, MIN_LOOP_LEN, ALLOW_GU
312
- TEMPERATURE = float(temperature)
313
- TOP_P = float(topp)
314
- DOT_BIAS = float(dot_bias)
315
- MIN_LOOP_LEN = int(min_loop)
316
- # this affects precompute_can_open on next call
317
- ALLOW_GU = bool(allow_gu)
318
- return predict(seq)
319
-
320
- btn.click(
321
- _predict_with_knobs,
322
- inputs=[seq_in, t, p, dbias, looplen, wobble],
323
- outputs=[out],
324
- )
325
 
326
- if __name__ == "__main__":
327
  demo.launch()
 
1
  import os
 
 
2
  import torch
3
  import gradio as gr
4
  from functools import lru_cache
 
5
  from transformers import (
6
  AutoTokenizer,
7
  AutoModelForCausalLM,
 
9
  LogitsProcessorList,
10
  )
11
 
 
12
  MODEL_ID = os.getenv("MODEL_ID", "llm-rna-api-rmit/rna-structure-model")
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  @lru_cache(maxsize=1)
15
  def _load_model_and_tokenizer():
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
23
  model.eval()
24
  return tokenizer, model, device
25
 
26
+ # --- Utility helpers ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  def _char_token_id(tokenizer, ch: str) -> int:
28
  ids = tokenizer.encode(ch, add_special_tokens=False)
29
  for tid in ids:
 
32
  for tid in range(len(tokenizer)):
33
  if tokenizer.decode([tid]) == ch:
34
  return tid
35
+ raise ValueError(f"Could not find token id for {ch}")
36
 
37
+ def _can_pair(a, b, allow_gu=True):
38
+ if (a,b) in [("A","U"),("U","A"),("G","C"),("C","G")]:
39
+ return True
40
+ if allow_gu and (a,b) in [("G","U"),("U","G")]:
 
 
 
41
  return True
42
  return False
43
 
44
+ def _precompute_can_open(seq, min_loop=3, allow_gu=True):
45
+ n=len(seq)
46
+ can=[False]*n
47
+ for i in range(n):
48
+ for j in range(i+min_loop+1,n):
49
+ if _can_pair(seq[i],seq[j],allow_gu):
50
+ can[i]=True
 
 
 
 
 
 
 
 
 
 
 
 
51
  break
52
+ return can
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
+ # --- constrained processor ---
 
 
 
 
 
55
  class BalancedParenProcessor(LogitsProcessor):
56
+ def __init__(self, lp_id, rp_id, dot_id, total_len, can_open,
57
+ dot_bias=0.8, paren_penalty=0.5, window=5):
58
+ self.lp_id, self.rp_id, self.dot_id = lp_id, rp_id, dot_id
59
+ self.total_len = total_len
 
 
 
 
 
 
 
 
 
 
 
60
  self.step = 0
61
  self.depth = 0
62
+ self.history=[]
63
+ self.can_open = can_open
64
+ self.dot_bias=dot_bias
65
+ self.paren_penalty=paren_penalty
66
+ self.window=window
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  def __call__(self, input_ids, scores):
69
+ mask=torch.full_like(scores,float("-inf"))
70
+ remaining=self.total_len-self.step
71
+ allowed=[]
72
+ must_close=(remaining==self.depth and self.depth>0)
73
+ pos=self.step
74
+ if must_close:
75
+ allowed=[self.rp_id]
 
 
 
 
76
  else:
77
+ if self.depth>0:
 
78
  allowed.append(self.rp_id)
79
+ if remaining-1>self.depth and pos<len(self.can_open) and self.can_open[pos]:
 
 
80
  allowed.append(self.lp_id)
 
 
81
  allowed.append(self.dot_id)
82
+ mask[:,allowed]=0.0
83
+ scores=scores+mask
84
+ scores[:,self.dot_id]+=self.dot_bias
85
+ if len(self.history)>=self.window and all(t in (self.lp_id,self.rp_id) for t in self.history[-self.window:]):
86
+ scores[:,self.lp_id]-=self.paren_penalty
87
+ scores[:,self.rp_id]-=self.paren_penalty
 
 
 
 
 
 
88
  return scores
89
 
90
+ def update(self, tok):
91
+ if tok==self.lp_id:
92
+ self.depth+=1
93
+ elif tok==self.rp_id:
94
+ self.depth=max(0,self.depth-1)
95
+ self.history.append(tok)
96
+ self.step+=1
97
+
98
+ def _top_p_sample(logits, top_p=0.9, temperature=0.8):
99
+ logits=logits/temperature
100
+ probs=torch.softmax(logits,dim=-1)
101
+ sorted_probs,sorted_idx=torch.sort(probs,descending=True)
102
+ cum=torch.cumsum(sorted_probs,dim=-1)
103
+ mask=cum>top_p
104
+ mask[...,0]=False
105
+ sorted_probs[mask]=0
106
+ sorted_probs/=sorted_probs.sum(dim=-1,keepdim=True)
107
+ idx=torch.multinomial(sorted_probs,1)
108
+ return sorted_idx.gather(-1,idx).squeeze(-1)
109
+
110
+ # --- generator ---
111
+ def _generate_db(seq):
112
+ tok,model,device=_load_model_and_tokenizer()
113
+ n=len(seq)
114
+ prompt=f"RNA: {seq}\nDot-bracket (exactly {n} characters using only '(' ')' '.'):\n"
115
+ lp=_char_token_id(tok,"("); rp=_char_token_id(tok,")"); dot=_char_token_id(tok,".")
116
+ can=_precompute_can_open(seq)
117
+ proc=BalancedParenProcessor(lp,rp,dot,n,can)
118
+ procs=LogitsProcessorList([proc])
119
+ inputs=tok(prompt,return_tensors="pt").to(device)
120
+ cur=inputs["input_ids"]
121
+ generated=[]
122
+ for _ in range(n):
123
+ out=model(cur)
124
+ logits=out.logits[:,-1,:]
125
+ for p in procs:
126
+ logits=p(cur,logits)
127
+ next_id=_top_p_sample(logits,0.9,0.8)
128
+ next_id=next_id.to(device)
129
+ tokid=next_id.item()
130
+ generated.append(tokid)
131
+ proc.update(tokid)
132
+ cur=torch.cat([cur,next_id.view(1,1)],dim=1)
133
+ text=tok.decode(generated,skip_special_tokens=True)
134
+ db="".join(c for c in text if c in "().")[:n]
135
+ if len(db)!=n:
136
+ db=(db+"."*n)[:n]
137
  return db
138
 
139
+ # --- structural element translation ---
140
+ def dotbracket_to_structural(dot_str):
141
+ if not dot_str: return "<start><external_loop><end>"
142
+ res=["<start>"];depth=0;i=0;n=len(dot_str)
143
+ def add(tag):
144
+ if res[-1]!=tag:res.append(tag)
145
+ while i<n:
146
+ c=dot_str[i]
147
+ if c==".":
148
+ j=i
149
+ while j<n and dot_str[j]==".":
150
+ j+=1
151
+ nextc=dot_str[j] if j<n else None
152
+ tag="<external_loop>" if depth==0 else ("<hairpin>" if nextc==")" else "<internal_loop>")
153
+ add(tag);i=j;continue
154
+ if c=="(":
155
+ add("<stem>");depth+=1
156
+ else:
157
+ add("<stem>");depth=max(0,depth-1)
158
+ i+=1
 
 
 
 
 
 
 
 
 
 
 
159
  res.append("<end>")
160
  return "".join(res)
161
 
162
+ # --- Gradio wrapper ---
163
+ def predict(seq):
164
+ seq=(seq or "").strip().upper()
165
+ if not seq or not set(seq)<={"A","U","C","G"}:
166
  return "Please enter an RNA sequence (A/U/C/G)."
167
+ db=_generate_db(seq)
 
168
  return dotbracket_to_structural(db)
169
 
170
+ demo=gr.Interface(
171
+ fn=predict,
172
+ inputs=gr.Textbox(lines=4,label="RNA Sequence (A/U/C/G)",value="GGGAAUCC"),
173
+ outputs=gr.Textbox(lines=6,label="Predicted Structural Elements"),
174
+ title="RNA Structure Predictor",
175
+ description="Outputs <start>, <stem>, <hairpin>, <internal_loop>, <external_loop>, <end>."
176
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
+ if __name__=="__main__":
179
  demo.launch()