Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -1,58 +1,31 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
(e.g., '(((..)))'), it outputs or processes RNA secondary structure using *structural element notation*.
|
| 5 |
-
|
| 6 |
-
REQUIREMENTS:
|
| 7 |
-
1. Convert from dot-bracket → structural elements:
|
| 8 |
-
- '(' and ')' (paired bases) should be grouped and labeled as <stem>
|
| 9 |
-
- contiguous '.' regions inside parentheses should be labeled as <hairpin> (if within a stem)
|
| 10 |
-
- contiguous '.' regions outside all parentheses should be labeled as <external_loop>
|
| 11 |
-
- unpaired regions between stems inside parentheses (bulges or internal loops) can be labeled <internal_loop>
|
| 12 |
-
- At start and end of the sequence, prepend and append <start> and <end>
|
| 13 |
-
|
| 14 |
-
2. Example transformation:
|
| 15 |
-
Input:
|
| 16 |
-
RNA: "GCGCGAAAACGCGC"
|
| 17 |
-
Dot-bracket: "(((((....)))))"
|
| 18 |
-
Output:
|
| 19 |
-
Structural notation: "<start><stem><hairpin><stem><end>"
|
| 20 |
-
|
| 21 |
-
3. Implementation details:
|
| 22 |
-
- The program should scan the dot-bracket string left to right.
|
| 23 |
-
- Detect transitions between paired/unpaired regions.
|
| 24 |
-
- Use a stack or counter to track nested stems if needed.
|
| 25 |
-
- Output the element sequence as a string (like '<stem><hairpin><stem><end>').
|
| 26 |
-
|
| 27 |
-
4. Preserve all existing code functionality (file I/O, RNA sequence handling, etc.)
|
| 28 |
-
but replace or augment the output generation with the new structural-element mapping.
|
| 29 |
-
|
| 30 |
-
OPTIONAL:
|
| 31 |
-
- If the code plots or visualizes structures, update the labels to use element names.
|
| 32 |
-
- If multiple structures are processed, apply the transformation for each.
|
| 33 |
-
|
| 34 |
-
COMMENT:
|
| 35 |
-
Insert the conversion logic into a function like:
|
| 36 |
-
def dotbracket_to_structural(dot_str: str) -> str:
|
| 37 |
-
...
|
| 38 |
-
return structural_str
|
| 39 |
-
"""
|
| 40 |
-
|
| 41 |
import gradio as gr
|
|
|
|
| 42 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 43 |
-
import torch, re
|
| 44 |
-
|
| 45 |
-
MODEL_ID = "llm-rna-api-rmit/rna-structure-model" # your uploaded model
|
| 46 |
|
| 47 |
-
|
| 48 |
-
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
|
| 49 |
|
| 50 |
DB_FULL = re.compile(r"^[().]+$")
|
| 51 |
DB_SCAN = re.compile(r"[().]{5,}")
|
| 52 |
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
outputs = model.generate(
|
| 57 |
**inputs,
|
| 58 |
max_new_tokens=max_new_tokens,
|
|
@@ -74,98 +47,63 @@ def _extract_dotbracket(text, length):
|
|
| 74 |
return None
|
| 75 |
|
| 76 |
def dotbracket_to_structural(dot_str: str) -> str:
|
| 77 |
-
"""
|
| 78 |
-
Convert a dot-bracket string to structural-element notation.
|
| 79 |
-
|
| 80 |
-
Heuristic rules (left-to-right scan):
|
| 81 |
-
- '(' and ')' => <stem>
|
| 82 |
-
- '.' with depth == 0 => <external_loop>
|
| 83 |
-
- '.' with depth > 0:
|
| 84 |
-
lookahead to next non-dot:
|
| 85 |
-
- next == ')' => <hairpin>
|
| 86 |
-
- next == '(' (or None) => <internal_loop>
|
| 87 |
-
Groups contiguous regions and wraps with <start> ... <end>.
|
| 88 |
-
"""
|
| 89 |
n = len(dot_str)
|
| 90 |
res = ["<start>"]
|
| 91 |
depth = 0
|
| 92 |
i = 0
|
| 93 |
|
| 94 |
def append_once(tag: str):
|
| 95 |
-
if
|
| 96 |
res.append(tag)
|
| 97 |
|
| 98 |
while i < n:
|
| 99 |
c = dot_str[i]
|
| 100 |
-
|
| 101 |
if c == '.':
|
| 102 |
-
# consume the entire '.' run
|
| 103 |
j = i
|
| 104 |
while j < n and dot_str[j] == '.':
|
| 105 |
j += 1
|
| 106 |
next_char = dot_str[j] if j < n else None
|
| 107 |
-
|
| 108 |
if depth == 0:
|
| 109 |
label = "<external_loop>"
|
| 110 |
else:
|
| 111 |
-
|
| 112 |
-
# If we see closing parentheses after the dots, treat as hairpin apex.
|
| 113 |
-
# If we see another '(', treat as internal loop/bulge/multiloop entry.
|
| 114 |
-
if next_char == ')':
|
| 115 |
-
label = "<hairpin>"
|
| 116 |
-
else:
|
| 117 |
-
label = "<internal_loop>"
|
| 118 |
-
|
| 119 |
append_once(label)
|
| 120 |
i = j
|
| 121 |
continue
|
| 122 |
-
|
| 123 |
-
# Paired region: '(' or ')'
|
| 124 |
-
# We label both as stem; adjust depth appropriately.
|
| 125 |
if c == '(':
|
| 126 |
append_once("<stem>")
|
| 127 |
depth += 1
|
| 128 |
-
|
| 129 |
append_once("<stem>")
|
| 130 |
-
# Close after labeling so that dots immediately following at lower depth
|
| 131 |
-
# are recognized correctly in the next iteration.
|
| 132 |
depth = max(depth - 1, 0)
|
| 133 |
-
|
| 134 |
i += 1
|
| 135 |
|
| 136 |
res.append("<end>")
|
| 137 |
return "".join(res)
|
| 138 |
|
| 139 |
-
def predict(seq):
|
| 140 |
seq = (seq or "").strip().upper()
|
| 141 |
if not seq or not set(seq) <= {"A","U","C","G"}:
|
| 142 |
return "Please enter an RNA sequence (A/U/C/G)."
|
| 143 |
|
| 144 |
n = len(seq)
|
| 145 |
prompt = f"RNA: {seq}\nDot-bracket structure:"
|
| 146 |
-
text = _generate(prompt, max_new_tokens=n +
|
| 147 |
|
| 148 |
-
# Try to extract a dot-bracket string of the correct length
|
| 149 |
db = _extract_dotbracket(text, n)
|
| 150 |
if db is None:
|
| 151 |
-
# fall back to filtered characters; if still wrong length, echo raw text
|
| 152 |
db_chars = [c for c in text if c in "()."]
|
| 153 |
db = "".join(db_chars) if len(db_chars) == n else None
|
| 154 |
if db is None:
|
| 155 |
-
return text.strip()
|
| 156 |
|
| 157 |
-
|
| 158 |
-
structural = dotbracket_to_structural(db)
|
| 159 |
-
return structural
|
| 160 |
|
| 161 |
demo = gr.Interface(
|
| 162 |
fn=predict,
|
| 163 |
inputs=gr.Textbox(lines=4, label="RNA Sequence (A/U/C/G)"),
|
| 164 |
outputs=gr.Textbox(lines=6, label="Predicted Structural Elements"),
|
| 165 |
title="RNA Structure Predictor",
|
| 166 |
-
description="
|
| 167 |
)
|
| 168 |
|
| 169 |
-
if __name__ == "__main__":
|
| 170 |
-
demo.launch()
|
| 171 |
-
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import re
|
| 3 |
+
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import gradio as gr
|
| 5 |
+
from functools import lru_cache
|
| 6 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
|
|
|
|
|
|
|
|
|
| 7 |
|
| 8 |
+
# Model identifier handed to the `from_pretrained` loaders; the MODEL_ID
# environment variable overrides the built-in default.
MODEL_ID = os.getenv("MODEL_ID", "llm-rna-api-rmit/rna-structure-model")

# Matches a string consisting only of dot-bracket characters '(', ')' and '.'
# (one or more of them).
DB_FULL = re.compile(r"^[().]+$")
# Matches a run of five or more dot-bracket characters anywhere in a text.
DB_SCAN = re.compile(r"[().]{5,}")
|
| 12 |
|
| 13 |
+
@lru_cache(maxsize=1)
def _load_model_and_tokenizer():
    """Load the tokenizer and causal LM for ``MODEL_ID`` and pick a device.

    The result is memoised by ``lru_cache``, so repeated calls return the
    objects created by the first call instead of loading them again.

    Returns:
        A ``(tokenizer, model, device)`` tuple, where ``device`` is the
        string ``"cuda"`` when CUDA is available and ``"cpu"`` otherwise.
    """
    has_cuda = torch.cuda.is_available()
    device = "cuda" if has_cuda else "cpu"

    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

    # Half precision with automatic placement on GPU; full precision and
    # default placement on CPU.
    if has_cuda:
        weights_dtype, placement = torch.float16, "auto"
    else:
        weights_dtype, placement = torch.float32, None

    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=weights_dtype,
        device_map=placement,
    )
    model.eval()
    return tokenizer, model, device
|
| 24 |
+
|
| 25 |
+
def _generate(prompt, max_new_tokens=256, temperature=0.0):
|
| 26 |
+
tokenizer, model, device = _load_model_and_tokenizer()
|
| 27 |
+
with torch.inference_mode():
|
| 28 |
+
inputs = tokenizer(prompt, return_tensors="pt").to(device)
|
| 29 |
outputs = model.generate(
|
| 30 |
**inputs,
|
| 31 |
max_new_tokens=max_new_tokens,
|
|
|
|
| 47 |
return None
|
| 48 |
|
| 49 |
def dotbracket_to_structural(dot_str: str) -> str:
    """Convert a dot-bracket string to structural-element notation.

    The string is scanned left to right while tracking bracket nesting depth:

    - ``(`` and ``)`` produce ``<stem>``.
    - A run of ``.`` at depth 0 produces ``<external_loop>``.
    - A run of ``.`` at depth > 0 produces ``<hairpin>`` only when it is
      directly enclosed by a pair, i.e. preceded by ``(`` and followed by
      ``)``. Every other nested run (between ``(`` and ``(``, ``)`` and
      ``(``, ``)`` and ``)``, or running to the end of an unbalanced
      string) produces ``<internal_loop>``.

    Consecutive identical elements are collapsed into a single tag and the
    result is wrapped in ``<start>`` ... ``<end>``.

    Args:
        dot_str: Dot-bracket structure, e.g. ``"(((((....)))))"``. Any
            character other than ``.`` and ``(`` is handled like ``)``.

    Returns:
        The element sequence as one string, e.g.
        ``"<start><stem><hairpin><stem><end>"``.
    """
    n = len(dot_str)
    res = ["<start>"]
    depth = 0
    i = 0

    def append_once(tag: str) -> None:
        # res always holds at least "<start>", so res[-1] is safe.
        if res[-1] != tag:
            res.append(tag)

    while i < n:
        c = dot_str[i]

        if c == ".":
            # Consume the whole run of unpaired positions at once.
            j = i
            while j < n and dot_str[j] == ".":
                j += 1

            if depth == 0:
                label = "<external_loop>"
            else:
                prev_char = dot_str[i - 1] if i > 0 else None
                next_char = dot_str[j] if j < n else None
                # Looking only at the following ')' would mislabel runs such
                # as the ".." in "((.(...)..))", which sit between two
                # closing brackets and are therefore not a hairpin apex.
                enclosed = prev_char == "(" and next_char == ")"
                label = "<hairpin>" if enclosed else "<internal_loop>"

            append_once(label)
            i = j
            continue

        append_once("<stem>")
        if c == "(":
            depth += 1
        else:
            # Clamp at zero so unbalanced closers cannot drive depth negative.
            depth = max(depth - 1, 0)
        i += 1

    res.append("<end>")
    return "".join(res)
|
| 83 |
|
| 84 |
+
def predict(seq: str):
    """Predict structural-element notation for an RNA sequence.

    The input is upper-cased and all whitespace is removed, including line
    breaks inside the sequence, so a sequence pasted across several lines of
    the multi-line textbox is accepted rather than rejected as invalid.

    Args:
        seq: RNA sequence text; ``None`` and empty input are tolerated.

    Returns:
        One of three strings:

        - a validation message when the cleaned input is empty or contains
          characters other than A, U, C and G;
        - the structural-element string when a dot-bracket string with the
          same length as the sequence is obtained from the model output;
        - otherwise the stripped raw model output.
    """
    # str.split() without arguments splits on any whitespace run, so joining
    # the pieces drops spaces, tabs and newlines anywhere in the input.
    seq = "".join((seq or "").split()).upper()
    if not seq or not set(seq) <= {"A", "U", "C", "G"}:
        return "Please enter an RNA sequence (A/U/C/G)."

    n = len(seq)
    prompt = f"RNA: {seq}\nDot-bracket structure:"
    # Leave a little headroom beyond one output symbol per nucleotide.
    text = _generate(prompt, max_new_tokens=n + 32, temperature=0.0)

    db = _extract_dotbracket(text, n)
    if db is None:
        # Fallback: keep only dot-bracket characters from the whole output
        # and accept them if exactly one remains per nucleotide.
        db_chars = [c for c in text if c in "()."]
        if len(db_chars) != n:
            return text.strip()
        db = "".join(db_chars)

    return dotbracket_to_structural(db)
|
|
|
|
|
|
|
| 101 |
|
| 102 |
# Gradio UI: one multi-line textbox in, one multi-line textbox out, backed by
# `predict`.
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=4, label="RNA Sequence (A/U/C/G)"),
    outputs=gr.Textbox(lines=6, label="Predicted Structural Elements"),
    title="RNA Structure Predictor",
    # The description is rendered as Markdown/HTML, so bare "<stem>"-style
    # names would be parsed as unknown HTML tags and vanish from the page.
    # Markdown code spans keep them visible verbatim.
    description=(
        "Outputs structural-element notation: `<start>`, `<stem>`, `<hairpin>`, "
        "`<internal_loop>`, `<external_loop>`, `<end>`."
    ),
)
|
| 109 |
|
|
|
|
|
|
|
|
|