Update app.py
app.py CHANGED
@@ -9,8 +9,10 @@ from transformers import (
     LogitsProcessorList,
 )
 
+# ── Config ───────────────────────────────────────────────────────────────────
 MODEL_ID = os.getenv("MODEL_ID", "llm-rna-api-rmit/rna-structure-model")
 
+# ── Model loading ────────────────────────────────────────────────────────────
 @lru_cache(maxsize=1)
 def _load_model_and_tokenizer():
     device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -21,58 +23,63 @@ def _load_model_and_tokenizer():
         device_map="auto" if device == "cuda" else None,
     )
     model.eval()
+    if device != "cuda":
+        model.to(device)
     return tokenizer, model, device
 
-#
+# ── Utility helpers ──────────────────────────────────────────────────────────
 def _char_token_id(tokenizer, ch: str) -> int:
+    # Prefer an exact single-char token if it exists
     ids = tokenizer.encode(ch, add_special_tokens=False)
     for tid in ids:
         if tokenizer.decode([tid]) == ch:
             return tid
+    # Fallback: scan vocab for an exact decode match
     for tid in range(len(tokenizer)):
         if tokenizer.decode([tid]) == ch:
            return tid
-    raise ValueError(f"Could not find token id for {ch}")
+    raise ValueError(f"Could not find token id for {ch!r}")
 
 def _can_pair(a, b, allow_gu=True):
-    if (a,b) in [("A","U"),("U","A"),("G","C"),("C","G")]:
+    if (a, b) in [("A","U"),("U","A"),("G","C"),("C","G")]:
         return True
-    if allow_gu and (a,b) in [("G","U"),("U","G")]:
+    if allow_gu and (a, b) in [("G","U"),("U","G")]:
         return True
     return False
 
 def _precompute_can_open(seq, min_loop=3, allow_gu=True):
-    n=len(seq)
-    can=[False]*n
+    n = len(seq)
+    can = [False] * n
     for i in range(n):
-        for j in range(i+min_loop+1,n):
-            if _can_pair(seq[i],seq[j],allow_gu):
-                can[i]=True
+        for j in range(i + min_loop + 1, n):
+            if _can_pair(seq[i], seq[j], allow_gu):
+                can[i] = True
                 break
     return can
 
-#
-# --- constrained processor ---
+# ── Constrained processor ────────────────────────────────────────────────────
 class BalancedParenProcessor(LogitsProcessor):
+    """
+    Restricts next token to one of: '(', ')' or '.', while maintaining balance
+    and leaving room to close opened stems. No dot bias by default.
+    """
     def __init__(self, lp_id, rp_id, dot_id, total_len, can_open,
                  dot_bias=0.0, paren_penalty=0.0, window=5):
         self.lp_id, self.rp_id, self.dot_id = lp_id, rp_id, dot_id
         self.total_len = total_len
         self.step = 0
         self.depth = 0
-        self.history=[]
+        self.history = []
         self.can_open = can_open
-        self.dot_bias=dot_bias
-        self.paren_penalty=paren_penalty
-        self.window=window
+        self.dot_bias = dot_bias
+        self.paren_penalty = paren_penalty
+        self.window = window
 
     def __call__(self, input_ids, scores):
-        #
+        # Mask everything except allowed tokens
        mask = torch.full_like(scores, float("-inf"))
         remaining = self.total_len - self.step
         allowed = []
-
-        # If we must close to avoid running out of room, force )
         must_close = (remaining == self.depth and self.depth > 0)
         pos = self.step
 
@@ -81,22 +88,18 @@ class BalancedParenProcessor(LogitsProcessor):
         else:
             if self.depth > 0:
                 allowed.append(self.rp_id)
-
-            # allow opening if there will still be room to close later
-            # (be a bit less strict than remaining-1 > depth to encourage stems)
+            # Allow opening if there will still be room to close later
             if remaining - 2 >= self.depth and pos < len(self.can_open) and self.can_open[pos]:
                 allowed.append(self.lp_id)
-
             allowed.append(self.dot_id)
 
         mask[:, allowed] = 0.0
         scores = scores + mask
 
-        # (no dot boost by default)
         if self.dot_bias != 0.0:
             scores[:, self.dot_id] += self.dot_bias
 
-        #
+        # Optional mild anti-run for long paren streaks
         if self.paren_penalty and len(self.history) >= self.window and all(
             t in (self.lp_id, self.rp_id) for t in self.history[-self.window:]
         ):
@@ -105,19 +108,49 @@ class BalancedParenProcessor(LogitsProcessor):
 
         return scores
 
-
+    def update(self, tok):
+        if tok == self.lp_id:
+            self.depth += 1
+        elif tok == self.rp_id:
+            self.depth = max(0, self.depth - 1)
+        self.history.append(tok)
+        self.step += 1
+
+def _top_p_sample(logits, top_p=0.9, temperature=0.8):
+    logits = logits / temperature
+    probs = torch.softmax(logits, dim=-1)
+    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
+    cumsum = torch.cumsum(sorted_probs, dim=-1)
+    mask = cumsum > top_p
+    mask[..., 0] = False
+    sorted_probs[mask] = 0
+    sorted_probs = sorted_probs / sorted_probs.sum(dim=-1, keepdim=True)
+    idx = torch.multinomial(sorted_probs, 1)
+    return sorted_idx.gather(-1, idx).squeeze(-1)
+
+# ── Generator ────────────────────────────────────────────────────────────────
 def _generate_db(seq):
     tok, model, device = _load_model_and_tokenizer()
     n = len(seq)
     prompt = f"RNA: {seq}\nDot-bracket (exactly {n} characters using only '(' ')' '.'):\n"
-
-
+
+    lp = _char_token_id(tok, "(")
+    rp = _char_token_id(tok, ")")
+    dot = _char_token_id(tok, ".")
+
+    # Helpful to verify once in logs
+    print("Token IDs:", {"(": lp, ")": rp, ".": dot})
+
+    can = _precompute_can_open(seq, min_loop=3, allow_gu=True)
     proc = BalancedParenProcessor(lp, rp, dot, n, can, dot_bias=0.0, paren_penalty=0.0)
     procs = LogitsProcessorList([proc])
-
+
+    inputs = tok(prompt, return_tensors="pt")
+    # Keep tensors on the same device as the model
+    inputs = {k: v.to(model.device) for k, v in inputs.items()}
     cur = inputs["input_ids"]
-    generated = []
 
+    generated = []
     with torch.no_grad():
         for _ in range(n):
             out = model(cur)
@@ -125,11 +158,11 @@ def _generate_db(seq):
             for p in procs:
                 logits = p(cur, logits)
             next_id = _top_p_sample(logits, top_p=0.9, temperature=0.8)
-            next_id = next_id.to(device)
             tokid = next_id.item()
             generated.append(tokid)
             proc.update(tokid)
-
+            # Make sure we append on the SAME device as cur/model
+            cur = torch.cat([cur, next_id.view(1, 1).to(cur.device)], dim=1)
 
     text = tok.decode(generated, skip_special_tokens=True)
     db = "".join(c for c in text if c in "().")[:n]
@@ -137,45 +170,56 @@ def _generate_db(seq):
     db = (db + "." * n)[:n]
     return db
 
-
-# --- structural element translation ---
+# ── Structural element translation ───────────────────────────────────────────
 def dotbracket_to_structural(dot_str):
-    if not dot_str
-    …
-            add(
-
+    if not dot_str or not isinstance(dot_str, str):
+        return "<start><external_loop><end>"
+    res = ["<start>"]; depth = 0; i = 0; n = len(dot_str)
+
+    def add(tag):
+        if res[-1] != tag:
+            res.append(tag)
+
+    while i < n:
+        c = dot_str[i]
+        if c == ".":
+            j = i
+            while j < n and dot_str[j] == ".":
+                j += 1
+            nextc = dot_str[j] if j < n else None
+            tag = "<external_loop>" if depth == 0 else ("<hairpin>" if nextc == ")" else "<internal_loop>")
+            add(tag); i = j; continue
+        if c == "(":
+            add("<stem>"); depth += 1
+        else:  # ')'
+            add("<stem>"); depth = max(0, depth - 1)
+        i += 1
+
     res.append("<end>")
     return "".join(res)
 
-#
-
-    seq=(seq or "").strip().upper()
-    if not seq or not set(seq)<={"A","U","C","G"}:
-        return "Please enter an RNA sequence (A/U/C/G)."
-    db=_generate_db(seq)
-    return dotbracket_to_structural(db)
+# ── Gradio wrapper ───────────────────────────────────────────────────────────
+import traceback
+
+def predict(seq):
+    try:
+        seq = (seq or "").strip().upper()
+        if not seq or not set(seq) <= {"A", "U", "C", "G"}:
+            return "Please enter an RNA sequence (A/U/C/G)."
+        db = _generate_db(seq)
+        return dotbracket_to_structural(db)
+    except Exception as e:
+        # Print full traceback to Space logs and show a concise error in UI
+        traceback.print_exc()
+        return f"Error: {type(e).__name__}: {e}"
 
-
+demo = gr.Interface(
     fn=predict,
-    inputs=gr.Textbox(lines=4,label="RNA Sequence (A/U/C/G)",value="GGGAAUCC"),
-    outputs=gr.Textbox(lines=6,label="Predicted Structural Elements"),
+    inputs=gr.Textbox(lines=4, label="RNA Sequence (A/U/C/G)", value="GGGAAUCC"),
+    outputs=gr.Textbox(lines=6, label="Predicted Structural Elements"),
     title="RNA Structure Predictor",
     description="Outputs <start>, <stem>, <hairpin>, <internal_loop>, <external_loop>, <end>."
 )
 
-if __name__=="__main__":
+if __name__ == "__main__":
     demo.launch()
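
For reviewers who want to sanity-check the pure-Python pieces of this change without downloading the model, here is a minimal smoke test. It is a hypothetical helper, not part of the commit: it assumes the Space's dependencies (torch, transformers, gradio) are installed and that the file above is importable as `app`. Importing it builds the Gradio interface but neither launches it nor fetches weights, since model loading is deferred inside `_load_model_and_tokenizer()`.

# smoke_test.py — hypothetical helper, not part of this commit.
import torch
from app import (
    BalancedParenProcessor,
    _precompute_can_open,
    _top_p_sample,
    dotbracket_to_structural,
)

# Translator: a 4-bp stem closed over a 4-nt hairpin loop.
print(dotbracket_to_structural("((((....))))"))
# -> <start><stem><hairpin><stem><end>

# can_open flags positions with a legal partner at least min_loop+1 nt downstream.
print(_precompute_can_open("GGGAAUCC"))
# -> [True, True, True, False, False, False, False, False]

# Constrained decoding on a toy 3-token vocabulary (0='(', 1=')', 2='.'),
# with uniform logits standing in for the model: whatever gets sampled, the
# processor keeps the output exactly len(seq) characters with balanced parens.
seq = "GGGAAUCC"
proc = BalancedParenProcessor(0, 1, 2, len(seq), _precompute_can_open(seq))
out = []
for _ in range(len(seq)):
    scores = proc(None, torch.zeros(1, 3))
    tok = _top_p_sample(scores, top_p=0.9, temperature=0.8).item()
    out.append("()."[tok])
    proc.update(tok)
print("".join(out))  # e.g. "(....).." — always balanced, always length 8

The toy loop mirrors the manual decode loop in `_generate_db` but swaps the model for uniform logits, which is enough to exercise the length, balance, and must-close invariants.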