Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,10 +1,7 @@
|
|
| 1 |
import os
|
| 2 |
-
import re
|
| 3 |
-
import math
|
| 4 |
import torch
|
| 5 |
import gradio as gr
|
| 6 |
from functools import lru_cache
|
| 7 |
-
from collections import deque
|
| 8 |
from transformers import (
|
| 9 |
AutoTokenizer,
|
| 10 |
AutoModelForCausalLM,
|
|
@@ -12,22 +9,8 @@ from transformers import (
|
|
| 12 |
LogitsProcessorList,
|
| 13 |
)
|
| 14 |
|
| 15 |
-
# --------------------------- Config ---------------------------------
|
| 16 |
MODEL_ID = os.getenv("MODEL_ID", "llm-rna-api-rmit/rna-structure-model")
|
| 17 |
|
| 18 |
-
# sampling / decoding knobs (tune as you like)
|
| 19 |
-
TEMPERATURE = float(os.getenv("TEMP", "0.8"))
|
| 20 |
-
TOP_P = float(os.getenv("TOP_P", "0.9"))
|
| 21 |
-
DOT_BIAS = float(os.getenv("DOT_BIAS", "0.8")) # +logit added to '.'
|
| 22 |
-
PAREN_RUN_PENALTY = float(os.getenv("PAREN_PEN", "0.5")) # -logit if long run
|
| 23 |
-
PAREN_RUN_WINDOW = int(os.getenv("PAREN_WIN", "5")) # lookback tokens
|
| 24 |
-
|
| 25 |
-
# RNA constraints
|
| 26 |
-
ALLOW_GU = os.getenv("ALLOW_GU", "1") != "0"
|
| 27 |
-
MIN_LOOP_LEN = int(os.getenv("MIN_LOOP", "3")) # minimum unpaired bases in loop
|
| 28 |
-
|
| 29 |
-
# --------------------------------------------------------------------
|
| 30 |
-
|
| 31 |
@lru_cache(maxsize=1)
|
| 32 |
def _load_model_and_tokenizer():
|
| 33 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
@@ -40,25 +23,7 @@ def _load_model_and_tokenizer():
|
|
| 40 |
model.eval()
|
| 41 |
return tokenizer, model, device
|
| 42 |
|
| 43 |
-
|
| 44 |
-
RNA: GGGAAAUCCCU
|
| 45 |
-
Dot-bracket: (((...)))).
|
| 46 |
-
RNA: AUAUAUAU
|
| 47 |
-
Dot-bracket: ........
|
| 48 |
-
RNA: GGGAAACCC
|
| 49 |
-
Dot-bracket: (((...)))
|
| 50 |
-
RNA: GAAACUU
|
| 51 |
-
Dot-bracket: (..())
|
| 52 |
-
"""
|
| 53 |
-
|
| 54 |
-
def _make_prompt(seq: str) -> str:
|
| 55 |
-
n = len(seq)
|
| 56 |
-
return (
|
| 57 |
-
FEWSHOT
|
| 58 |
-
+ f"\nRNA: {seq}\n"
|
| 59 |
-
+ f"Dot-bracket (exactly {n} characters using only '(' ')' '.'):\n"
|
| 60 |
-
)
|
| 61 |
-
|
| 62 |
def _char_token_id(tokenizer, ch: str) -> int:
|
| 63 |
ids = tokenizer.encode(ch, add_special_tokens=False)
|
| 64 |
for tid in ids:
|
|
@@ -67,261 +32,148 @@ def _char_token_id(tokenizer, ch: str) -> int:
|
|
| 67 |
for tid in range(len(tokenizer)):
|
| 68 |
if tokenizer.decode([tid]) == ch:
|
| 69 |
return tid
|
| 70 |
-
raise ValueError(f"Could not find token id for
|
| 71 |
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
if
|
| 76 |
-
if a == "G" and b == "C": return True
|
| 77 |
-
if a == "C" and b == "G": return True
|
| 78 |
-
if allow_gu and ((a == "G" and b == "U") or (a == "U" and b == "G")):
|
| 79 |
return True
|
| 80 |
return False
|
| 81 |
|
| 82 |
-
def _precompute_can_open(seq
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
pos_by_base = {"A": [], "U": [], "G": [], "C": []}
|
| 90 |
-
for idx in range(n-1, -1, -1):
|
| 91 |
-
base = seq[idx]
|
| 92 |
-
# update future lists (right side of idx)
|
| 93 |
-
if idx + 1 < n:
|
| 94 |
-
pos_by_base[seq[idx+1]].append(idx+1)
|
| 95 |
-
# Check if any future partner exists with min loop
|
| 96 |
-
min_j = idx + min_loop_len + 1
|
| 97 |
-
ok = False
|
| 98 |
-
for b, lst in pos_by_base.items():
|
| 99 |
-
if any(j >= min_j and _can_pair(base, b, allow_gu) for j in lst):
|
| 100 |
-
ok = True
|
| 101 |
break
|
| 102 |
-
|
| 103 |
-
return can_open
|
| 104 |
-
|
| 105 |
-
# --------- Top-p sampling ----------
|
| 106 |
-
def _top_p_sample_from_logits(logits: torch.Tensor, top_p: float, temperature: float) -> torch.Tensor:
|
| 107 |
-
if temperature <= 0:
|
| 108 |
-
# fall back to greedy
|
| 109 |
-
return torch.argmax(logits, dim=-1)
|
| 110 |
-
logits = logits / temperature
|
| 111 |
-
probs = torch.softmax(logits, dim=-1)
|
| 112 |
-
|
| 113 |
-
# sort
|
| 114 |
-
sorted_probs, sorted_idx = torch.sort(probs, descending=True, dim=-1)
|
| 115 |
-
cumulative = torch.cumsum(sorted_probs, dim=-1)
|
| 116 |
-
|
| 117 |
-
# mask tokens beyond top_p
|
| 118 |
-
cutoff = (cumulative > top_p).float()
|
| 119 |
-
# ensure at least one token kept
|
| 120 |
-
cutoff[..., 0] = 0.0
|
| 121 |
-
keep = 1.0 - cutoff
|
| 122 |
-
|
| 123 |
-
filtered_probs = sorted_probs * keep
|
| 124 |
-
# renormalize
|
| 125 |
-
filtered_probs = filtered_probs / filtered_probs.sum(dim=-1, keepdim=True)
|
| 126 |
|
| 127 |
-
|
| 128 |
-
next_sorted = torch.multinomial(filtered_probs, num_samples=1)
|
| 129 |
-
next_ids = sorted_idx.gather(-1, next_sorted)
|
| 130 |
-
return next_ids.squeeze(-1)
|
| 131 |
-
|
| 132 |
-
# --- Finite-State logits processor with RNA constraints ---
|
| 133 |
class BalancedParenProcessor(LogitsProcessor):
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
- Forbid '(' if no feasible partner ahead (RNA pairing + min loop)
|
| 139 |
-
- Dot bias (+bias to dot logit)
|
| 140 |
-
- Penalize long runs of parentheses
|
| 141 |
-
"""
|
| 142 |
-
def __init__(self, lp_id: int, rp_id: int, dot_id: int, total_len: int,
|
| 143 |
-
can_open, paren_run_window=5, paren_run_penalty=0.5, dot_bias=0.8):
|
| 144 |
-
self.lp_id = int(lp_id)
|
| 145 |
-
self.rp_id = int(rp_id)
|
| 146 |
-
self.dot_id = int(dot_id)
|
| 147 |
-
self.total_len = int(total_len)
|
| 148 |
-
|
| 149 |
self.step = 0
|
| 150 |
self.depth = 0
|
| 151 |
-
self.
|
| 152 |
-
self.
|
| 153 |
-
self.dot_bias
|
| 154 |
-
self.
|
| 155 |
-
|
| 156 |
-
def update_with_chosen(self, token_id: int):
|
| 157 |
-
# called after each step by the generator
|
| 158 |
-
self.recent.append(token_id)
|
| 159 |
-
if token_id == self.lp_id:
|
| 160 |
-
self.depth += 1
|
| 161 |
-
elif token_id == self.rp_id:
|
| 162 |
-
self.depth = max(0, self.depth - 1)
|
| 163 |
-
self.step += 1
|
| 164 |
-
|
| 165 |
-
def _recent_is_paren_run(self):
|
| 166 |
-
if not self.recent:
|
| 167 |
-
return False
|
| 168 |
-
return all(t in (self.lp_id, self.rp_id) for t in self.recent)
|
| 169 |
|
| 170 |
def __call__(self, input_ids, scores):
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
remaining
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
pos = self.step
|
| 179 |
-
|
| 180 |
-
if must_close_all:
|
| 181 |
-
allowed = [self.rp_id]
|
| 182 |
else:
|
| 183 |
-
|
| 184 |
-
if self.depth > 0:
|
| 185 |
allowed.append(self.rp_id)
|
| 186 |
-
|
| 187 |
-
# Allow '(' only if (a) room to close by end and (b) feasible partner ahead
|
| 188 |
-
if remaining - 1 > self.depth and pos < self.total_len and self.can_open[pos]:
|
| 189 |
allowed.append(self.lp_id)
|
| 190 |
-
|
| 191 |
-
# '.' generally allowed
|
| 192 |
allowed.append(self.dot_id)
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
scores
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
# penalize long paren run
|
| 201 |
-
if self._recent_is_paren_run():
|
| 202 |
-
scores[..., self.lp_id] = scores[..., self.lp_id] - self.paren_run_penalty
|
| 203 |
-
scores[..., self.rp_id] = scores[..., self.rp_id] - self.paren_run_penalty
|
| 204 |
-
|
| 205 |
return scores
|
| 206 |
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
)
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
| 233 |
-
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
|
| 239 |
-
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
| 248 |
-
|
| 249 |
-
|
| 250 |
-
|
|
|
|
|
|
|
|
|
|
| 251 |
return db
|
| 252 |
|
| 253 |
-
# ---
|
| 254 |
-
def dotbracket_to_structural(dot_str
|
| 255 |
-
if not dot_str:
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
append_once(label)
|
| 274 |
-
i = j
|
| 275 |
-
continue
|
| 276 |
-
if c == '(':
|
| 277 |
-
append_once("<stem>")
|
| 278 |
-
depth += 1
|
| 279 |
-
else: # ')'
|
| 280 |
-
append_once("<stem>")
|
| 281 |
-
depth = max(depth - 1, 0)
|
| 282 |
-
i += 1
|
| 283 |
-
|
| 284 |
res.append("<end>")
|
| 285 |
return "".join(res)
|
| 286 |
|
| 287 |
-
# --- Gradio
|
| 288 |
-
def predict(seq
|
| 289 |
-
seq
|
| 290 |
-
if not seq or not set(seq)
|
| 291 |
return "Please enter an RNA sequence (A/U/C/G)."
|
| 292 |
-
|
| 293 |
-
db = _generate_db(seq)
|
| 294 |
return dotbracket_to_structural(db)
|
| 295 |
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
gr.
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
t = gr.Slider(0.1, 1.5, value=TEMPERATURE, step=0.05, label="Temperature")
|
| 304 |
-
p = gr.Slider(0.5, 1.0, value=TOP_P, step=0.01, label="Top-p")
|
| 305 |
-
dbias = gr.Slider(0.0, 2.0, value=DOT_BIAS, step=0.05, label="Dot bias (+logit)")
|
| 306 |
-
looplen = gr.Slider(0, 5, value=MIN_LOOP_LEN, step=1, label="Min loop length")
|
| 307 |
-
wobble = gr.Checkbox(value=ALLOW_GU, label="Allow GU wobble")
|
| 308 |
-
btn = gr.Button("Submit", variant="primary")
|
| 309 |
-
|
| 310 |
-
def _predict_with_knobs(seq, temperature, topp, dot_bias, min_loop, allow_gu):
|
| 311 |
-
global TEMPERATURE, TOP_P, DOT_BIAS, MIN_LOOP_LEN, ALLOW_GU
|
| 312 |
-
TEMPERATURE = float(temperature)
|
| 313 |
-
TOP_P = float(topp)
|
| 314 |
-
DOT_BIAS = float(dot_bias)
|
| 315 |
-
MIN_LOOP_LEN = int(min_loop)
|
| 316 |
-
# this affects precompute_can_open on next call
|
| 317 |
-
ALLOW_GU = bool(allow_gu)
|
| 318 |
-
return predict(seq)
|
| 319 |
-
|
| 320 |
-
btn.click(
|
| 321 |
-
_predict_with_knobs,
|
| 322 |
-
inputs=[seq_in, t, p, dbias, looplen, wobble],
|
| 323 |
-
outputs=[out],
|
| 324 |
-
)
|
| 325 |
|
| 326 |
-
if __name__
|
| 327 |
demo.launch()
|
|
|
|
| 1 |
import os
|
|
|
|
|
|
|
| 2 |
import torch
|
| 3 |
import gradio as gr
|
| 4 |
from functools import lru_cache
|
|
|
|
| 5 |
from transformers import (
|
| 6 |
AutoTokenizer,
|
| 7 |
AutoModelForCausalLM,
|
|
|
|
| 9 |
LogitsProcessorList,
|
| 10 |
)
|
| 11 |
|
|
|
|
| 12 |
# Hugging Face model repo to load; override with the MODEL_ID env var.
MODEL_ID = os.getenv("MODEL_ID", "llm-rna-api-rmit/rna-structure-model")
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
@lru_cache(maxsize=1)
|
| 15 |
def _load_model_and_tokenizer():
|
| 16 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
| 23 |
model.eval()
|
| 24 |
return tokenizer, model, device
|
| 25 |
|
| 26 |
+
# --- Utility helpers ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
def _char_token_id(tokenizer, ch: str) -> int:
|
| 28 |
ids = tokenizer.encode(ch, add_special_tokens=False)
|
| 29 |
for tid in ids:
|
|
|
|
| 32 |
for tid in range(len(tokenizer)):
|
| 33 |
if tokenizer.decode([tid]) == ch:
|
| 34 |
return tid
|
| 35 |
+
raise ValueError(f"Could not find token id for {ch}")
|
| 36 |
|
| 37 |
+
def _can_pair(a, b, allow_gu=True):
|
| 38 |
+
if (a,b) in [("A","U"),("U","A"),("G","C"),("C","G")]:
|
| 39 |
+
return True
|
| 40 |
+
if allow_gu and (a,b) in [("G","U"),("U","G")]:
|
|
|
|
|
|
|
|
|
|
| 41 |
return True
|
| 42 |
return False
|
| 43 |
|
| 44 |
+
def _precompute_can_open(seq, min_loop=3, allow_gu=True):
|
| 45 |
+
n=len(seq)
|
| 46 |
+
can=[False]*n
|
| 47 |
+
for i in range(n):
|
| 48 |
+
for j in range(i+min_loop+1,n):
|
| 49 |
+
if _can_pair(seq[i],seq[j],allow_gu):
|
| 50 |
+
can[i]=True
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
break
|
| 52 |
+
return can
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
+
# --- constrained processor ---
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
class BalancedParenProcessor(LogitsProcessor):
    """Constrained-decoding processor for dot-bracket generation.

    At every step the vocabulary is masked down to the structural tokens
    '(' ')' '.' that keep the output well-formed:

      - when the remaining length exactly equals the open-bracket depth,
        only ')' is legal (everything must close);
      - ')' is legal only while depth > 0;
      - '(' is legal only with room to close before the end and a
        feasible downstream partner (precomputed in *can_open*);
      - '.' is otherwise always legal.

    On top of the mask, '.' gets a constant logit bonus and both bracket
    tokens are penalized after *window* consecutive bracket emissions.

    The caller must report each sampled token via :meth:`update` so the
    internal position/depth state stays in sync with the decode loop.
    """

    def __init__(self, lp_id, rp_id, dot_id, total_len, can_open,
                 dot_bias=0.8, paren_penalty=0.5, window=5):
        # Token ids for '(', ')' and '.'.
        self.lp_id, self.rp_id, self.dot_id = lp_id, rp_id, dot_id
        self.total_len = total_len        # structure length to emit
        self.step = 0                     # tokens emitted so far
        self.depth = 0                    # currently open '(' count
        self.history = []                 # token ids seen via update()
        self.can_open = can_open          # per-position '(' feasibility
        self.dot_bias = dot_bias
        self.paren_penalty = paren_penalty
        self.window = window

    def __call__(self, input_ids, scores):
        steps_left = self.total_len - self.step
        position = self.step

        # Decide which structural tokens are legal at this position.
        if self.depth > 0 and steps_left == self.depth:
            # Exactly enough room left to close every open bracket.
            legal = [self.rp_id]
        else:
            legal = []
            if self.depth > 0:
                legal.append(self.rp_id)
            # '(' needs room to close (steps_left - 1 > depth) and a
            # feasible downstream pairing partner at this position.
            if (steps_left - 1 > self.depth
                    and position < len(self.can_open)
                    and self.can_open[position]):
                legal.append(self.lp_id)
            legal.append(self.dot_id)

        # -inf everywhere except the legal ids (0.0 keeps the score).
        veto = torch.full_like(scores, float("-inf"))
        veto[:, legal] = 0.0
        scores = scores + veto

        # Nudge generation toward unpaired bases.
        scores[:, self.dot_id] += self.dot_bias

        # Discourage long unbroken runs of bracket tokens.
        recent = self.history[-self.window:]
        if len(self.history) >= self.window and all(
                t in (self.lp_id, self.rp_id) for t in recent):
            scores[:, self.lp_id] -= self.paren_penalty
            scores[:, self.rp_id] -= self.paren_penalty
        return scores

    def update(self, tok):
        """Record the token actually sampled; advance position/depth."""
        if tok == self.lp_id:
            self.depth += 1
        elif tok == self.rp_id:
            self.depth = max(0, self.depth - 1)
        self.history.append(tok)
        self.step += 1
|
| 97 |
+
|
| 98 |
+
def _top_p_sample(logits, top_p=0.9, temperature=0.8):
|
| 99 |
+
logits=logits/temperature
|
| 100 |
+
probs=torch.softmax(logits,dim=-1)
|
| 101 |
+
sorted_probs,sorted_idx=torch.sort(probs,descending=True)
|
| 102 |
+
cum=torch.cumsum(sorted_probs,dim=-1)
|
| 103 |
+
mask=cum>top_p
|
| 104 |
+
mask[...,0]=False
|
| 105 |
+
sorted_probs[mask]=0
|
| 106 |
+
sorted_probs/=sorted_probs.sum(dim=-1,keepdim=True)
|
| 107 |
+
idx=torch.multinomial(sorted_probs,1)
|
| 108 |
+
return sorted_idx.gather(-1,idx).squeeze(-1)
|
| 109 |
+
|
| 110 |
+
# --- generator ---
|
| 111 |
+
def _generate_db(seq):
    """Generate a dot-bracket string of len(seq) for RNA sequence *seq*.

    Token-by-token constrained decoding: a BalancedParenProcessor masks
    illegal structural tokens at each step, then nucleus sampling picks
    among the survivors. The decoded text is filtered to '(' ')' '.'
    and defensively padded/trimmed to exactly len(seq) characters.

    Fix vs. previous revision: the per-step forward passes now run
    under ``torch.no_grad()`` — without it every decode step built an
    autograd graph, wasting memory at inference time.
    """
    tok, model, device = _load_model_and_tokenizer()
    n = len(seq)
    prompt = f"RNA: {seq}\nDot-bracket (exactly {n} characters using only '(' ')' '.'):\n"
    lp = _char_token_id(tok, "(")
    rp = _char_token_id(tok, ")")
    dot = _char_token_id(tok, ".")
    can = _precompute_can_open(seq)
    proc = BalancedParenProcessor(lp, rp, dot, n, can)
    procs = LogitsProcessorList([proc])
    inputs = tok(prompt, return_tensors="pt").to(device)
    cur = inputs["input_ids"]
    generated = []
    # Inference only: no_grad avoids building autograd graphs each step.
    with torch.no_grad():
        for _ in range(n):
            out = model(cur)
            logits = out.logits[:, -1, :]
            for p in procs:
                logits = p(cur, logits)
            next_id = _top_p_sample(logits, 0.9, 0.8).to(device)
            tokid = next_id.item()
            generated.append(tokid)
            # Keep the processor's step/depth state in sync with output.
            proc.update(tokid)
            cur = torch.cat([cur, next_id.view(1, 1)], dim=1)
    text = tok.decode(generated, skip_special_tokens=True)
    # Keep only structural chars; pad/trim defensively to exactly n.
    db = "".join(c for c in text if c in "().")[:n]
    if len(db) != n:
        db = (db + "." * n)[:n]
    return db
|
| 138 |
|
| 139 |
+
# --- structural element translation ---
|
| 140 |
+
def dotbracket_to_structural(dot_str):
    """Translate a dot-bracket string into coarse structural-element tags.

    Runs of '.' map to <external_loop> (at depth 0), <hairpin> (run
    immediately followed by ')') or <internal_loop> (otherwise); every
    bracket contributes a <stem>. Consecutive identical tags collapse
    to one, and the result is wrapped in <start>...<end>.
    """
    if not dot_str:
        return "<start><external_loop><end>"

    tags = ["<start>"]

    def emit(tag):
        # Collapse repeats so each run maps to a single element tag.
        if tags[-1] != tag:
            tags.append(tag)

    depth = 0
    idx = 0
    total = len(dot_str)
    while idx < total:
        ch = dot_str[idx]
        if ch == ".":
            # Consume the whole run of unpaired bases at once.
            run_end = idx
            while run_end < total and dot_str[run_end] == ".":
                run_end += 1
            follower = dot_str[run_end] if run_end < total else None
            if depth == 0:
                emit("<external_loop>")
            elif follower == ")":
                emit("<hairpin>")
            else:
                emit("<internal_loop>")
            idx = run_end
            continue
        # Both '(' and ')' are part of a stem; only depth bookkeeping
        # differs (')' never drives depth below zero).
        emit("<stem>")
        depth = depth + 1 if ch == "(" else max(0, depth - 1)
        idx += 1
    tags.append("<end>")
    return "".join(tags)
|
| 161 |
|
| 162 |
+
# --- Gradio wrapper ---
|
| 163 |
+
def predict(seq):
    """Gradio entry point: RNA sequence in, structural-element string out.

    Normalizes the input (strip + uppercase), rejects anything outside
    the A/U/C/G alphabet with a user-facing message, otherwise decodes a
    dot-bracket structure and translates it to element tags.
    """
    cleaned = (seq or "").strip().upper()
    # Reject empty input or any character outside the RNA alphabet.
    if not cleaned or not set(cleaned) <= {"A", "U", "C", "G"}:
        return "Please enter an RNA sequence (A/U/C/G)."
    return dotbracket_to_structural(_generate_db(cleaned))
|
| 169 |
|
| 170 |
+
# Build the web UI: one textbox in, structural-element string out.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=4, label="RNA Sequence (A/U/C/G)", value="GGGAAUCC"),
    outputs=gr.Textbox(lines=6, label="Predicted Structural Elements"),
    title="RNA Structure Predictor",
    description="Outputs <start>, <stem>, <hairpin>, <internal_loop>, <external_loop>, <end>."
)

# Launch only when executed as a script (e.g. `python app.py` on a Space).
if __name__ == "__main__":
    demo.launch()
|