1-1-3-8 commited on
Commit
b0f91f5
·
verified ·
1 Parent(s): 4242716

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -91
app.py CHANGED
@@ -1,58 +1,31 @@
1
- """
2
- TASK:
3
- Modify this RNA secondary-structure analysis code so that instead of using dot-bracket notation
4
- (e.g., '(((..)))'), it outputs or processes RNA secondary structure using *structural element notation*.
5
-
6
- REQUIREMENTS:
7
- 1. Convert from dot-bracket → structural elements:
8
- - '(' and ')' (paired bases) should be grouped and labeled as <stem>
9
- - contiguous '.' regions inside parentheses should be labeled as <hairpin> (if within a stem)
10
- - contiguous '.' regions outside all parentheses should be labeled as <external_loop>
11
- - unpaired regions between stems inside parentheses (bulges or internal loops) can be labeled <internal_loop>
12
- - At start and end of the sequence, prepend and append <start> and <end>
13
-
14
- 2. Example transformation:
15
- Input:
16
- RNA: "GCGCGAAAACGCGC"
17
- Dot-bracket: "(((((....)))))"
18
- Output:
19
- Structural notation: "<start><stem><hairpin><stem><end>"
20
-
21
- 3. Implementation details:
22
- - The program should scan the dot-bracket string left to right.
23
- - Detect transitions between paired/unpaired regions.
24
- - Use a stack or counter to track nested stems if needed.
25
- - Output the element sequence as a string (like '<stem><hairpin><stem><end>').
26
-
27
- 4. Preserve all existing code functionality (file I/O, RNA sequence handling, etc.)
28
- but replace or augment the output generation with the new structural-element mapping.
29
-
30
- OPTIONAL:
31
- - If the code plots or visualizes structures, update the labels to use element names.
32
- - If multiple structures are processed, apply the transformation for each.
33
-
34
- COMMENT:
35
- Insert the conversion logic into a function like:
36
- def dotbracket_to_structural(dot_str: str) -> str:
37
- ...
38
- return structural_str
39
- """
40
-
41
  import gradio as gr
 
42
  from transformers import AutoTokenizer, AutoModelForCausalLM
43
- import torch, re
44
-
45
- MODEL_ID = "llm-rna-api-rmit/rna-structure-model" # your uploaded model
46
 
47
- tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
48
- model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
49
 
50
  DB_FULL = re.compile(r"^[().]+$")
51
  DB_SCAN = re.compile(r"[().]{5,}")
52
 
53
- def _generate(prompt, max_new_tokens=512, temperature=0.0):
54
- with torch.no_grad():
55
- inputs = tokenizer(prompt, return_tensors="pt")
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  outputs = model.generate(
57
  **inputs,
58
  max_new_tokens=max_new_tokens,
@@ -74,98 +47,63 @@ def _extract_dotbracket(text, length):
74
  return None
75
 
76
  def dotbracket_to_structural(dot_str: str) -> str:
77
- """
78
- Convert a dot-bracket string to structural-element notation.
79
-
80
- Heuristic rules (left-to-right scan):
81
- - '(' and ')' => <stem>
82
- - '.' with depth == 0 => <external_loop>
83
- - '.' with depth > 0:
84
- lookahead to next non-dot:
85
- - next == ')' => <hairpin>
86
- - next == '(' (or None) => <internal_loop>
87
- Groups contiguous regions and wraps with <start> ... <end>.
88
- """
89
  n = len(dot_str)
90
  res = ["<start>"]
91
  depth = 0
92
  i = 0
93
 
94
  def append_once(tag: str):
95
- if not res or res[-1] != tag:
96
  res.append(tag)
97
 
98
  while i < n:
99
  c = dot_str[i]
100
-
101
  if c == '.':
102
- # consume the entire '.' run
103
  j = i
104
  while j < n and dot_str[j] == '.':
105
  j += 1
106
  next_char = dot_str[j] if j < n else None
107
-
108
  if depth == 0:
109
  label = "<external_loop>"
110
  else:
111
- # Inside a stemmed region:
112
- # If we see closing parentheses after the dots, treat as hairpin apex.
113
- # If we see another '(', treat as internal loop/bulge/multiloop entry.
114
- if next_char == ')':
115
- label = "<hairpin>"
116
- else:
117
- label = "<internal_loop>"
118
-
119
  append_once(label)
120
  i = j
121
  continue
122
-
123
- # Paired region: '(' or ')'
124
- # We label both as stem; adjust depth appropriately.
125
  if c == '(':
126
  append_once("<stem>")
127
  depth += 1
128
- elif c == ')':
129
  append_once("<stem>")
130
- # Close after labeling so that dots immediately following at lower depth
131
- # are recognized correctly in the next iteration.
132
  depth = max(depth - 1, 0)
133
-
134
  i += 1
135
 
136
  res.append("<end>")
137
  return "".join(res)
138
 
139
- def predict(seq):
140
  seq = (seq or "").strip().upper()
141
  if not seq or not set(seq) <= {"A","U","C","G"}:
142
  return "Please enter an RNA sequence (A/U/C/G)."
143
 
144
  n = len(seq)
145
  prompt = f"RNA: {seq}\nDot-bracket structure:"
146
- text = _generate(prompt, max_new_tokens=n + 20, temperature=0.0)
147
 
148
- # Try to extract a dot-bracket string of the correct length
149
  db = _extract_dotbracket(text, n)
150
  if db is None:
151
- # fall back to filtered characters; if still wrong length, echo raw text
152
  db_chars = [c for c in text if c in "()."]
153
  db = "".join(db_chars) if len(db_chars) == n else None
154
  if db is None:
155
- return text.strip() # preserve existing behavior on extraction failure
156
 
157
- # Convert to structural-element notation
158
- structural = dotbracket_to_structural(db)
159
- return structural
160
 
161
  demo = gr.Interface(
162
  fn=predict,
163
  inputs=gr.Textbox(lines=4, label="RNA Sequence (A/U/C/G)"),
164
  outputs=gr.Textbox(lines=6, label="Predicted Structural Elements"),
165
  title="RNA Structure Predictor",
166
- description="Uses your fine-tuned model to output RNA secondary structure as structural elements (e.g., <start><stem><hairpin><stem><end>)."
167
  )
168
 
169
- if __name__ == "__main__":
170
- demo.launch()
171
-
 
1
+ import os
2
+ import re
3
+ import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  import gradio as gr
5
+ from functools import lru_cache
6
  from transformers import AutoTokenizer, AutoModelForCausalLM
 
 
 
7
 
8
+ MODEL_ID = os.getenv("MODEL_ID", "llm-rna-api-rmit/rna-structure-model")
 
9
 
10
  DB_FULL = re.compile(r"^[().]+$")
11
  DB_SCAN = re.compile(r"[().]{5,}")
12
 
13
+ @lru_cache(maxsize=1)
14
+ def _load_model_and_tokenizer():
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
17
+ model = AutoModelForCausalLM.from_pretrained(
18
+ MODEL_ID,
19
+ torch_dtype=torch.float16 if device == "cuda" else torch.float32,
20
+ device_map="auto" if device == "cuda" else None,
21
+ )
22
+ model.eval()
23
+ return tokenizer, model, device
24
+
25
+ def _generate(prompt, max_new_tokens=256, temperature=0.0):
26
+ tokenizer, model, device = _load_model_and_tokenizer()
27
+ with torch.inference_mode():
28
+ inputs = tokenizer(prompt, return_tensors="pt").to(device)
29
  outputs = model.generate(
30
  **inputs,
31
  max_new_tokens=max_new_tokens,
 
47
  return None
48
 
49
  def dotbracket_to_structural(dot_str: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
 
50
  n = len(dot_str)
51
  res = ["<start>"]
52
  depth = 0
53
  i = 0
54
 
55
  def append_once(tag: str):
56
+ if res[-1] != tag:
57
  res.append(tag)
58
 
59
  while i < n:
60
  c = dot_str[i]
 
61
  if c == '.':
 
62
  j = i
63
  while j < n and dot_str[j] == '.':
64
  j += 1
65
  next_char = dot_str[j] if j < n else None
 
66
  if depth == 0:
67
  label = "<external_loop>"
68
  else:
69
+ label = "<hairpin>" if next_char == ')' else "<internal_loop>"
 
 
 
 
 
 
 
70
  append_once(label)
71
  i = j
72
  continue
 
 
 
73
  if c == '(':
74
  append_once("<stem>")
75
  depth += 1
76
+ else: # ')'
77
  append_once("<stem>")
 
 
78
  depth = max(depth - 1, 0)
 
79
  i += 1
80
 
81
  res.append("<end>")
82
  return "".join(res)
83
 
84
+ def predict(seq: str):
85
  seq = (seq or "").strip().upper()
86
  if not seq or not set(seq) <= {"A","U","C","G"}:
87
  return "Please enter an RNA sequence (A/U/C/G)."
88
 
89
  n = len(seq)
90
  prompt = f"RNA: {seq}\nDot-bracket structure:"
91
+ text = _generate(prompt, max_new_tokens=n + 32, temperature=0.0)
92
 
 
93
  db = _extract_dotbracket(text, n)
94
  if db is None:
 
95
  db_chars = [c for c in text if c in "()."]
96
  db = "".join(db_chars) if len(db_chars) == n else None
97
  if db is None:
98
+ return text.strip()
99
 
100
+ return dotbracket_to_structural(db)
 
 
101
 
102
  demo = gr.Interface(
103
  fn=predict,
104
  inputs=gr.Textbox(lines=4, label="RNA Sequence (A/U/C/G)"),
105
  outputs=gr.Textbox(lines=6, label="Predicted Structural Elements"),
106
  title="RNA Structure Predictor",
107
+ description="Outputs structural-element notation: <start>, <stem>, <hairpin>, <internal_loop>, <external_loop>, <end>."
108
  )
109