1-1-3-8 committed on
Commit
86dfbbc
Β·
verified Β·
1 Parent(s): c94bb03

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +106 -62
app.py CHANGED
@@ -9,8 +9,10 @@ from transformers import (
9
  LogitsProcessorList,
10
  )
11
 
 
12
  MODEL_ID = os.getenv("MODEL_ID", "llm-rna-api-rmit/rna-structure-model")
13
 
 
14
  @lru_cache(maxsize=1)
15
  def _load_model_and_tokenizer():
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -21,58 +23,63 @@ def _load_model_and_tokenizer():
21
  device_map="auto" if device == "cuda" else None,
22
  )
23
  model.eval()
 
 
24
  return tokenizer, model, device
25
 
26
- # --- Utility helpers ---
27
  def _char_token_id(tokenizer, ch: str) -> int:
 
28
  ids = tokenizer.encode(ch, add_special_tokens=False)
29
  for tid in ids:
30
  if tokenizer.decode([tid]) == ch:
31
  return tid
 
32
  for tid in range(len(tokenizer)):
33
  if tokenizer.decode([tid]) == ch:
34
  return tid
35
- raise ValueError(f"Could not find token id for {ch}")
36
 
37
  def _can_pair(a, b, allow_gu=True):
38
- if (a,b) in [("A","U"),("U","A"),("G","C"),("C","G")]:
39
  return True
40
- if allow_gu and (a,b) in [("G","U"),("U","G")]:
41
  return True
42
  return False
43
 
44
  def _precompute_can_open(seq, min_loop=3, allow_gu=True):
45
- n=len(seq)
46
- can=[False]*n
47
  for i in range(n):
48
- for j in range(i+min_loop+1,n):
49
- if _can_pair(seq[i],seq[j],allow_gu):
50
- can[i]=True
51
  break
52
  return can
53
 
54
- # --- constrained processor ---
55
- # --- constrained processor ---
56
  class BalancedParenProcessor(LogitsProcessor):
 
 
 
 
57
  def __init__(self, lp_id, rp_id, dot_id, total_len, can_open,
58
  dot_bias=0.0, paren_penalty=0.0, window=5):
59
  self.lp_id, self.rp_id, self.dot_id = lp_id, rp_id, dot_id
60
  self.total_len = total_len
61
  self.step = 0
62
  self.depth = 0
63
- self.history=[]
64
  self.can_open = can_open
65
- self.dot_bias=dot_bias
66
- self.paren_penalty=paren_penalty
67
- self.window=window
68
 
69
  def __call__(self, input_ids, scores):
70
- # restrict to only three tokens
71
  mask = torch.full_like(scores, float("-inf"))
72
  remaining = self.total_len - self.step
73
  allowed = []
74
-
75
- # If we must close to avoid running out of room, force )
76
  must_close = (remaining == self.depth and self.depth > 0)
77
  pos = self.step
78
 
@@ -81,22 +88,18 @@ class BalancedParenProcessor(LogitsProcessor):
81
  else:
82
  if self.depth > 0:
83
  allowed.append(self.rp_id)
84
-
85
- # allow opening if there will still be room to close later
86
- # (be a bit less strict than remaining-1 > depth to encourage stems)
87
  if remaining - 2 >= self.depth and pos < len(self.can_open) and self.can_open[pos]:
88
  allowed.append(self.lp_id)
89
-
90
  allowed.append(self.dot_id)
91
 
92
  mask[:, allowed] = 0.0
93
  scores = scores + mask
94
 
95
- # (no dot boost by default)
96
  if self.dot_bias != 0.0:
97
  scores[:, self.dot_id] += self.dot_bias
98
 
99
- # optional mild anti-run regularizer
100
  if self.paren_penalty and len(self.history) >= self.window and all(
101
  t in (self.lp_id, self.rp_id) for t in self.history[-self.window:]
102
  ):
@@ -105,19 +108,49 @@ class BalancedParenProcessor(LogitsProcessor):
105
 
106
  return scores
107
 
108
- # --- generator ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
109
  def _generate_db(seq):
110
  tok, model, device = _load_model_and_tokenizer()
111
  n = len(seq)
112
  prompt = f"RNA: {seq}\nDot-bracket (exactly {n} characters using only '(' ')' '.'):\n"
113
- lp = _char_token_id(tok, "("); rp = _char_token_id(tok, ")"); dot = _char_token_id(tok, ".")
114
- can = _precompute_can_open(seq, min_loop=3) # try 2 if you still get few stems
 
 
 
 
 
 
 
115
  proc = BalancedParenProcessor(lp, rp, dot, n, can, dot_bias=0.0, paren_penalty=0.0)
116
  procs = LogitsProcessorList([proc])
117
- inputs = tok(prompt, return_tensors="pt").to(device)
 
 
 
118
  cur = inputs["input_ids"]
119
- generated = []
120
 
 
121
  with torch.no_grad():
122
  for _ in range(n):
123
  out = model(cur)
@@ -125,11 +158,11 @@ def _generate_db(seq):
125
  for p in procs:
126
  logits = p(cur, logits)
127
  next_id = _top_p_sample(logits, top_p=0.9, temperature=0.8)
128
- next_id = next_id.to(device)
129
  tokid = next_id.item()
130
  generated.append(tokid)
131
  proc.update(tokid)
132
- cur = torch.cat([cur, next_id.view(1, 1)], dim=1)
 
133
 
134
  text = tok.decode(generated, skip_special_tokens=True)
135
  db = "".join(c for c in text if c in "().")[:n]
@@ -137,45 +170,56 @@ def _generate_db(seq):
137
  db = (db + "." * n)[:n]
138
  return db
139
 
140
-
141
- # --- structural element translation ---
142
  def dotbracket_to_structural(dot_str):
143
- if not dot_str: return "<start><external_loop><end>"
144
- res=["<start>"];depth=0;i=0;n=len(dot_str)
145
- def add(tag):
146
- if res[-1]!=tag:res.append(tag)
147
- while i<n:
148
- c=dot_str[i]
149
- if c==".":
150
- j=i
151
- while j<n and dot_str[j]==".":
152
- j+=1
153
- nextc=dot_str[j] if j<n else None
154
- tag="<external_loop>" if depth==0 else ("<hairpin>" if nextc==")" else "<internal_loop>")
155
- add(tag);i=j;continue
156
- if c=="(":
157
- add("<stem>");depth+=1
158
- else:
159
- add("<stem>");depth=max(0,depth-1)
160
- i+=1
 
 
 
 
 
161
  res.append("<end>")
162
  return "".join(res)
163
 
164
- # --- Gradio wrapper ---
165
- def predict(seq):
166
- seq=(seq or "").strip().upper()
167
- if not seq or not set(seq)<={"A","U","C","G"}:
168
- return "Please enter an RNA sequence (A/U/C/G)."
169
- db=_generate_db(seq)
170
- return dotbracket_to_structural(db)
171
 
172
- demo=gr.Interface(
 
 
 
 
 
 
 
 
 
 
 
 
173
  fn=predict,
174
- inputs=gr.Textbox(lines=4,label="RNA Sequence (A/U/C/G)",value="GGGAAUCC"),
175
- outputs=gr.Textbox(lines=6,label="Predicted Structural Elements"),
176
  title="RNA Structure Predictor",
177
  description="Outputs <start>, <stem>, <hairpin>, <internal_loop>, <external_loop>, <end>."
178
  )
179
 
180
- if __name__=="__main__":
181
  demo.launch()
 
9
  LogitsProcessorList,
10
  )
11
 
12
# ── Config ──────────────────────────────────────────────────────────────────────
# Hugging Face model repo id; override with the MODEL_ID environment variable.
MODEL_ID = os.getenv("MODEL_ID", "llm-rna-api-rmit/rna-structure-model")
14
 
15
+ # ── Model loading ───────────────────────────────────────────────────────────────
16
  @lru_cache(maxsize=1)
17
  def _load_model_and_tokenizer():
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
23
  device_map="auto" if device == "cuda" else None,
24
  )
25
  model.eval()
26
+ if device != "cuda":
27
+ model.to(device)
28
  return tokenizer, model, device
29
 
30
+ # ── Utility helpers ─────────────────────────────────────────────────────────────
31
  def _char_token_id(tokenizer, ch: str) -> int:
32
+ # Prefer an exact single-char token if it exists
33
  ids = tokenizer.encode(ch, add_special_tokens=False)
34
  for tid in ids:
35
  if tokenizer.decode([tid]) == ch:
36
  return tid
37
+ # Fallback: scan vocab for an exact decode match
38
  for tid in range(len(tokenizer)):
39
  if tokenizer.decode([tid]) == ch:
40
  return tid
41
+ raise ValueError(f"Could not find token id for {ch!r}")
42
 
43
  def _can_pair(a, b, allow_gu=True):
44
+ if (a, b) in [("A","U"),("U","A"),("G","C"),("C","G")]:
45
  return True
46
+ if allow_gu and (a, b) in [("G","U"),("U","G")]:
47
  return True
48
  return False
49
 
50
def _precompute_can_open(seq, min_loop=3, allow_gu=True):
    """Return a per-position list of bools: can[i] is True when some later
    base j (with j >= i + min_loop + 1, enforcing the minimum hairpin loop)
    can pair with seq[i] — i.e. a '(' opened at i could ever be closed.
    """
    n = len(seq)
    return [
        any(_can_pair(seq[i], seq[j], allow_gu) for j in range(i + min_loop + 1, n))
        for i in range(n)
    ]
59
 
60
+ # ── Constrained processor ───────────────────────────────────────────────────────
 
61
  class BalancedParenProcessor(LogitsProcessor):
62
+ """
63
+ Restricts next token to one of: '(', ')' or '.', while maintaining balance
64
+ and leaving room to close opened stems. No dot bias by default.
65
+ """
66
    def __init__(self, lp_id, rp_id, dot_id, total_len, can_open,
                 dot_bias=0.0, paren_penalty=0.0, window=5):
        # lp_id / rp_id / dot_id: token ids for '(', ')' and '.'.
        # total_len: exact number of structure characters to generate.
        # can_open: per-position bools — True where an opened '(' could later close.
        # dot_bias: optional additive logit boost for '.' (0.0 disables it).
        # paren_penalty / window: optional penalty applied after `window`
        #   consecutive paren tokens (0.0 disables it).
        self.lp_id, self.rp_id, self.dot_id = lp_id, rp_id, dot_id
        self.total_len = total_len
        self.step = 0    # characters emitted so far
        self.depth = 0   # currently unmatched '(' count
        self.history = []  # token ids emitted so far
        self.can_open = can_open
        self.dot_bias = dot_bias
        self.paren_penalty = paren_penalty
        self.window = window
77
 
78
  def __call__(self, input_ids, scores):
79
+ # Mask everything except allowed tokens
80
  mask = torch.full_like(scores, float("-inf"))
81
  remaining = self.total_len - self.step
82
  allowed = []
 
 
83
  must_close = (remaining == self.depth and self.depth > 0)
84
  pos = self.step
85
 
 
88
  else:
89
  if self.depth > 0:
90
  allowed.append(self.rp_id)
91
+ # Allow opening if there will still be room to close later
 
 
92
  if remaining - 2 >= self.depth and pos < len(self.can_open) and self.can_open[pos]:
93
  allowed.append(self.lp_id)
 
94
  allowed.append(self.dot_id)
95
 
96
  mask[:, allowed] = 0.0
97
  scores = scores + mask
98
 
 
99
  if self.dot_bias != 0.0:
100
  scores[:, self.dot_id] += self.dot_bias
101
 
102
+ # Optional mild anti-run for long paren streaks
103
  if self.paren_penalty and len(self.history) >= self.window and all(
104
  t in (self.lp_id, self.rp_id) for t in self.history[-self.window:]
105
  ):
 
108
 
109
  return scores
110
 
111
    def update(self, tok):
        # Advance processor state after `tok` has been committed to the output
        # (called once per sampled token by the generation loop).
        if tok == self.lp_id:
            self.depth += 1
        elif tok == self.rp_id:
            # Clamp so a stray ')' can never drive the depth negative.
            self.depth = max(0, self.depth - 1)
        self.history.append(tok)
        self.step += 1
118
+
119
+ def _top_p_sample(logits, top_p=0.9, temperature=0.8):
120
+ logits = logits / temperature
121
+ probs = torch.softmax(logits, dim=-1)
122
+ sorted_probs, sorted_idx = torch.sort(probs, descending=True)
123
+ cumsum = torch.cumsum(sorted_probs, dim=-1)
124
+ mask = cumsum > top_p
125
+ mask[..., 0] = False
126
+ sorted_probs[mask] = 0
127
+ sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
128
+ idx = torch.multinomial(sorted_probs, 1)
129
+ return sorted_idx.gather(-1, idx).squeeze(-1)
130
+
131
+ # ── Generator ───────────────────────────────────────────────────────────────────
132
  def _generate_db(seq):
133
  tok, model, device = _load_model_and_tokenizer()
134
  n = len(seq)
135
  prompt = f"RNA: {seq}\nDot-bracket (exactly {n} characters using only '(' ')' '.'):\n"
136
+
137
+ lp = _char_token_id(tok, "(")
138
+ rp = _char_token_id(tok, ")")
139
+ dot = _char_token_id(tok, ".")
140
+
141
+ # Helpful to verify once in logs
142
+ print("Token IDs:", {"(": lp, ")": rp, ".": dot})
143
+
144
+ can = _precompute_can_open(seq, min_loop=3, allow_gu=True)
145
  proc = BalancedParenProcessor(lp, rp, dot, n, can, dot_bias=0.0, paren_penalty=0.0)
146
  procs = LogitsProcessorList([proc])
147
+
148
+ inputs = tok(prompt, return_tensors="pt")
149
+ # Keep tensors on the same device as the model
150
+ inputs = {k: v.to(model.device) for k, v in inputs.items()}
151
  cur = inputs["input_ids"]
 
152
 
153
+ generated = []
154
  with torch.no_grad():
155
  for _ in range(n):
156
  out = model(cur)
 
158
  for p in procs:
159
  logits = p(cur, logits)
160
  next_id = _top_p_sample(logits, top_p=0.9, temperature=0.8)
 
161
  tokid = next_id.item()
162
  generated.append(tokid)
163
  proc.update(tokid)
164
+ # Make sure we append on the SAME device as cur/model
165
+ cur = torch.cat([cur, next_id.view(1, 1).to(cur.device)], dim=1)
166
 
167
  text = tok.decode(generated, skip_special_tokens=True)
168
  db = "".join(c for c in text if c in "().")[:n]
 
170
  db = (db + "." * n)[:n]
171
  return db
172
 
173
+ # ── Structural element translation ──────────────────────────────────────────────
 
174
def dotbracket_to_structural(dot_str):
    """Translate a dot-bracket string into a compact tag string bracketed by
    <start>/<end>.

    Both '(' and ')' emit <stem>. A run of '.' emits <external_loop> at
    depth 0, <hairpin> when the run is followed by ')', and <internal_loop>
    otherwise. Consecutive identical tags are collapsed into one. Empty or
    non-string input yields "<start><external_loop><end>".
    """
    if not dot_str or not isinstance(dot_str, str):
        return "<start><external_loop><end>"

    tags = ["<start>"]
    depth = 0
    pos = 0
    length = len(dot_str)

    def emit(tag):
        # Collapse runs of the same tag.
        if tags[-1] != tag:
            tags.append(tag)

    while pos < length:
        ch = dot_str[pos]
        if ch == ".":
            # Consume the whole run of dots, then classify it by context.
            run_end = pos
            while run_end < length and dot_str[run_end] == ".":
                run_end += 1
            follower = dot_str[run_end] if run_end < length else None
            if depth == 0:
                emit("<external_loop>")
            elif follower == ")":
                emit("<hairpin>")
            else:
                emit("<internal_loop>")
            pos = run_end
        elif ch == "(":
            emit("<stem>")
            depth += 1
            pos += 1
        else:  # ')'
            emit("<stem>")
            depth = max(0, depth - 1)
            pos += 1

    tags.append("<end>")
    return "".join(tags)
200
 
201
+ # ── Gradio wrapper ──────────────────────────────────────────────────────────────
202
+ import traceback
 
 
 
 
 
203
 
204
def predict(seq):
    """Gradio handler: validate the RNA sequence, fold it, and translate the
    dot-bracket result into structural-element tags.

    Never raises — any internal failure is logged with a full traceback and
    reported to the UI as a concise one-line error string.
    """
    try:
        cleaned = (seq or "").strip().upper()
        if not cleaned or not set(cleaned) <= {"A", "U", "C", "G"}:
            return "Please enter an RNA sequence (A/U/C/G)."
        return dotbracket_to_structural(_generate_db(cleaned))
    except Exception as e:
        # Full traceback goes to the Space logs; the UI gets a short message.
        traceback.print_exc()
        return f"Error: {type(e).__name__}: {e}"
215
+
216
# Gradio UI: one textbox in (RNA sequence), one textbox out (structural tags).
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=4, label="RNA Sequence (A/U/C/G)", value="GGGAAUCC"),
    outputs=gr.Textbox(lines=6, label="Predicted Structural Elements"),
    title="RNA Structure Predictor",
    description="Outputs <start>, <stem>, <hairpin>, <internal_loop>, <external_loop>, <end>."
)

# Launch the app only when run as a script (Spaces imports and launches itself).
if __name__ == "__main__":
    demo.launch()