1-1-3-8 committed on
Commit
0a649c0
·
verified ·
1 Parent(s): ec37350

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +109 -60
app.py CHANGED
@@ -3,7 +3,12 @@ import re
3
  import torch
4
  import gradio as gr
5
  from functools import lru_cache
6
- from transformers import AutoTokenizer, AutoModelForCausalLM, StoppingCriteria, StoppingCriteriaList
 
 
 
 
 
7
 
8
  MODEL_ID = os.getenv("MODEL_ID", "llm-rna-api-rmit/rna-structure-model")
9
 
@@ -19,70 +24,123 @@ def _load_model_and_tokenizer():
19
  model.eval()
20
  return tokenizer, model, device
21
 
 
 
 
 
 
 
 
 
 
 
22
  def _make_prompt(seq: str) -> str:
23
  n = len(seq)
24
  return (
25
- f"RNA: {seq}\n"
26
- f"Output ONLY the RNA secondary structure in dot-bracket notation, exactly {n} characters long, "
27
- f"using only '(' ')' and '.'.\n"
28
- f"Dot-bracket:"
29
  )
30
 
31
def _generate(prompt: str, max_new_tokens: int = 256):
    """Greedy-decode a completion for *prompt*, stopping at the first newline token.

    Returns the decoded generated text with the prompt tokens stripped.
    Runs under torch.inference_mode(); decoding is greedy, so the output is
    deterministic for a given model and prompt.
    """
    tokenizer, model, device = _load_model_and_tokenizer()

    class StopOnNewline(StoppingCriteria):
        """Stop generation once the most recent token is the newline token."""

        def __init__(self, newline_id: int):
            self.newline_id = newline_id

        def __call__(self, input_ids, scores, **kwargs):
            return input_ids[0, -1].item() == self.newline_id

    nl_id = tokenizer.encode("\n", add_special_tokens=False)[0]
    with torch.inference_mode():
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            # FIX: dropped temperature=0.0 — recent transformers versions
            # reject a non-positive temperature even when do_sample=False,
            # and greedy decoding does not use temperature at all.
            do_sample=False,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            stopping_criteria=StoppingCriteriaList([StopOnNewline(nl_id)]),
        )
    gen_tokens = outputs[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(gen_tokens, skip_special_tokens=True)
53
-
54
- def _extract_dotbracket_strict(text: str, length: int):
55
- text = text.strip()
56
- candidates = []
57
-
58
- # Filter all dot-bracket-like substrings
59
- for line in text.splitlines():
60
- line = line.strip()
61
- cand = "".join(c for c in line if c in "().")
62
- if cand:
63
- candidates.append(cand)
64
-
65
- # Choose first one with exact or closest match
66
- for cand in candidates:
67
- if len(cand) == length:
68
- return cand
69
- if candidates:
70
- # fallback: pick longest valid segment if none matches perfectly
71
- return max(candidates, key=len)
72
- return None
73
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
  def dotbracket_to_structural(dot_str: str) -> str:
75
  if not dot_str:
76
  return "<start><external_loop><end>"
77
-
78
- n = len(dot_str)
79
  res = ["<start>"]
80
  depth = 0
81
- i = 0
82
 
83
- def append_once(tag: str):
84
- if res[-1] != tag:
85
- res.append(tag)
86
 
87
  while i < n:
88
  c = dot_str[i]
@@ -91,17 +149,14 @@ def dotbracket_to_structural(dot_str: str) -> str:
91
  while j < n and dot_str[j] == '.':
92
  j += 1
93
  next_char = dot_str[j] if j < n else None
94
- if depth == 0:
95
- label = "<external_loop>"
96
- else:
97
- label = "<hairpin>" if next_char == ')' else "<internal_loop>"
98
  append_once(label)
99
  i = j
100
  continue
101
  if c == '(':
102
  append_once("<stem>")
103
  depth += 1
104
- elif c == ')':
105
  append_once("<stem>")
106
  depth = max(depth - 1, 0)
107
  i += 1
@@ -109,23 +164,18 @@ def dotbracket_to_structural(dot_str: str) -> str:
109
  res.append("<end>")
110
  return "".join(res)
111
 
 
112
def predict(seq: str):
    """Gradio handler: RNA sequence → structural-token string."""
    seq = (seq or "").strip().upper()
    allowed_bases = {"A", "U", "C", "G"}
    if not seq or not set(seq) <= allowed_bases:
        return "Please enter an RNA sequence (A/U/C/G)."

    n = len(seq)
    raw = _generate(_make_prompt(seq), max_new_tokens=n + 8)
    db = _extract_dotbracket_strict(raw, n)

    # Model produced nothing usable: report a bare external loop.
    if not db:
        return "<start><external_loop><end>"

    # Normalize to exactly n characters (truncate, then pad with dots).
    db = (db[:n] + "." * n)[:n]
    return dotbracket_to_structural(db)
130
 
131
  demo = gr.Interface(
@@ -138,4 +188,3 @@ demo = gr.Interface(
138
 
139
  if __name__ == "__main__":
140
  demo.launch()
141
-
 
3
  import torch
4
  import gradio as gr
5
  from functools import lru_cache
6
+ from transformers import (
7
+ AutoTokenizer,
8
+ AutoModelForCausalLM,
9
+ LogitsProcessor,
10
+ LogitsProcessorList,
11
+ )
12
 
13
  MODEL_ID = os.getenv("MODEL_ID", "llm-rna-api-rmit/rna-structure-model")
14
 
 
24
  model.eval()
25
  return tokenizer, model, device
26
 
27
+ # ---------- Prompt (few-shot to reduce "all dots") ----------
28
+ FEWSHOT = """Return ONLY the dot-bracket structure as one line with the same length as RNA.
29
+
30
+ RNA: GCGCGAAAACGCGC
31
+ Dot-bracket: (((((....)))))
32
+
33
+ RNA: GGGAAAUCCCU
34
+ Dot-bracket: (((...)))
35
+ """
36
+
37
  def _make_prompt(seq: str) -> str:
38
  n = len(seq)
39
  return (
40
+ FEWSHOT
41
+ + f"\nRNA: {seq}\n"
42
+ + f"Dot-bracket (exactly {n} characters using only '(' ')' '.'):"
 
43
  )
44
 
45
+ # ---------- Robust char→token id ----------
46
+ def _char_token_id(tokenizer, ch: str) -> int:
47
+ # Try simple path
48
+ ids = tokenizer.encode(ch, add_special_tokens=False)
49
+ if ids:
50
+ # prefer single-token mapping when it decodes back to the same char
51
+ for tid in ids:
52
+ if tokenizer.decode([tid]) == ch:
53
+ return tid
54
+ return ids[-1]
55
+ # Fallback: scan vocab for a token that decodes to ch
56
+ for tid in range(len(tokenizer)):
57
+ if tokenizer.decode([tid]) == ch:
58
+ return tid
59
+ raise ValueError(f"Could not find token id for char {ch!r}")
60
+
61
# ---------- Constrained + biased generation ----------
class AllowOnlyAndBias(LogitsProcessor):
    """Restrict generation to *allowed_ids* and add per-token logit biases.

    Disallowed ids are masked to -inf; the model's own logits for the allowed
    ids are PRESERVED. (The previous version overwrote ALL scores with -inf
    and reset the allowed ids to 0.0, which discarded the model's predictions
    entirely — greedy decoding then always picked the most-biased token
    regardless of the input sequence.)
    """

    def __init__(self, allowed_ids, bias_map):
        self.allowed = torch.tensor(allowed_ids, dtype=torch.long)
        self.bias_map = {int(k): float(v) for k, v in bias_map.items()}

    def __call__(self, input_ids, scores):
        # Mask everything outside the allowed set, keeping allowed logits intact.
        mask = torch.full_like(scores, float("-inf"))
        mask[:, self.allowed] = 0.0
        scores = scores + mask
        # Steer away from '.' and toward parentheses via additive biases.
        for tid, bias in self.bias_map.items():
            scores[:, tid] += bias
        return scores
74
+
75
def _generate_db(seq: str) -> str:
    """Generate a dot-bracket string of exactly len(seq) characters for *seq*.

    Decoding is constrained to the three tokens '(' ')' '.' (with a bias
    toward parentheses to avoid the degenerate all-dots output), decoded
    greedily, then filtered and padded/truncated so the result is always
    exactly len(seq) characters.
    """
    tokenizer, model, device = _load_model_and_tokenizer()
    n = len(seq)
    prompt = _make_prompt(seq)

    # Robust token ids for the three structure characters.
    lp_id = _char_token_id(tokenizer, "(")
    rp_id = _char_token_id(tokenizer, ")")
    dot_id = _char_token_id(tokenizer, ".")

    processors = LogitsProcessorList([
        # Encourage parentheses, discourage all-dots.
        AllowOnlyAndBias(
            allowed_ids=[lp_id, rp_id, dot_id],
            bias_map={lp_id: +1.2, rp_id: +1.2, dot_id: -0.8},
        )
    ])

    with torch.inference_mode():
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        out = model.generate(
            **inputs,
            # Exactly n structure tokens: each allowed token is one character.
            max_new_tokens=n,
            min_new_tokens=n,
            do_sample=False,
            # FIX: dropped temperature=0.0 — recent transformers versions
            # reject a non-positive temperature even when do_sample=False,
            # and greedy decoding does not use temperature at all.
            logits_processor=processors,
            pad_token_id=tokenizer.eos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    gen = out[0][inputs["input_ids"].shape[-1]:]
    text = tokenizer.decode(gen, skip_special_tokens=True)
    # Keep only structure characters, then force the length to exactly n.
    db = "".join(c for c in text if c in "().")[:n]
    if len(db) != n:
        db = (db + "." * n)[:n]
    return db
112
+
113
+ # ---------- Simple heuristic fallback if model gives all dots ----------
114
+ def _is_complement(a, b):
115
+ return (a == "G" and b == "C") or (a == "C" and b == "G") or (a == "A" and b == "U") or (a == "U" and b == "A")
116
+
117
+ def naive_hairpin(seq: str, min_loop: int = 3) -> str:
118
+ n = len(seq)
119
+ db = ["." for _ in range(n)]
120
+ i, j = 0, n - 1
121
+ while i < j - min_loop:
122
+ if _is_complement(seq[i], seq[j]):
123
+ db[i], db[j] = "(", ")"
124
+ i += 1
125
+ j -= 1
126
+ else:
127
+ # move the weaker side inward to try to find a match
128
+ if seq[i] in "AU" and seq[j] in "GC":
129
+ i += 1
130
+ else:
131
+ j -= 1
132
+ return "".join(db)
133
+
134
+ # ---------- Dot-bracket → structural ----------
135
  def dotbracket_to_structural(dot_str: str) -> str:
136
  if not dot_str:
137
  return "<start><external_loop><end>"
 
 
138
  res = ["<start>"]
139
  depth = 0
140
+ i, n = 0, len(dot_str)
141
 
142
+ def append_once(tag):
143
+ if res[-1] != tag: res.append(tag)
 
144
 
145
  while i < n:
146
  c = dot_str[i]
 
149
  while j < n and dot_str[j] == '.':
150
  j += 1
151
  next_char = dot_str[j] if j < n else None
152
+ label = "<external_loop>" if depth == 0 else ("<hairpin>" if next_char == ')' else "<internal_loop>")
 
 
 
153
  append_once(label)
154
  i = j
155
  continue
156
  if c == '(':
157
  append_once("<stem>")
158
  depth += 1
159
+ else: # ')'
160
  append_once("<stem>")
161
  depth = max(depth - 1, 0)
162
  i += 1
 
164
  res.append("<end>")
165
  return "".join(res)
166
 
167
# ---------- UI fn ----------
def predict(seq: str):
    """Gradio handler: validate the RNA sequence and return structural tokens."""
    seq = (seq or "").strip().upper()
    valid_bases = {"A", "U", "C", "G"}
    if not seq or not set(seq) <= valid_bases:
        return "Please enter an RNA sequence (A/U/C/G)."

    db = _generate_db(seq)

    # Constrained decoding can still yield all dots; fall back to the naive
    # hairpin heuristic so the output contains at least some stems.
    if db.count("(") + db.count(")") == 0:
        db = naive_hairpin(seq)

    return dotbracket_to_structural(db)
180
 
181
  demo = gr.Interface(
 
188
 
189
  if __name__ == "__main__":
190
  demo.launch()