# app.py
"""Gradio Space: RNA secondary-structure prediction via constrained LM decoding.

Loads a causal LM, generates a dot-bracket string of exactly len(seq)
characters under a logits processor that enforces balanced parentheses and
biologically plausible pair openings, then maps the dot-bracket string to a
structural-element string for display.
"""

import os
import traceback
from functools import lru_cache
from typing import Tuple

import gradio as gr
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    LogitsProcessor,
    LogitsProcessorList,
)

# -------------------------
# Config
# -------------------------
MODEL_ID = os.getenv("MODEL_ID", "llm-rna-api-rmit/rna-structure-model")
# Optionally set HF_TOKEN if using private HF repo:
HF_TOKEN = os.getenv("HF_TOKEN", None)

# Global placeholders (populated by init_model)
TOKENIZER = None
MODEL = None


# -------------------------
# Utility helpers
# -------------------------
@lru_cache(maxsize=1)
def _load_model_and_tokenizer() -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
    """Load tokenizer + model once. Use float16 if CUDA present."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # NOTE(review): `use_auth_token` is deprecated in recent transformers
    # releases in favor of `token` — confirm the pinned version before changing.
    use_auth = {"use_auth_token": HF_TOKEN} if HF_TOKEN else {}
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, **use_auth)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        device_map="auto" if device == "cuda" else None,
        **use_auth,
    )
    return tokenizer, model


def _char_token_id(tokenizer, ch: str) -> int:
    """Return the single token id whose decoded form is exactly `ch`.

    Tries the encoder first; falls back to a linear scan of the vocabulary.

    Raises:
        ValueError: if no token decodes to exactly `ch`.
    """
    # Prefer exact single-char token if it exists
    ids = tokenizer.encode(ch, add_special_tokens=False)
    for tid in ids:
        if tokenizer.decode([tid]) == ch:
            return tid
    # Fallback: scan vocab (conservative — O(vocab), but runs at most once
    # per character thanks to the small fixed alphabet "()." ).
    vocab_size = getattr(tokenizer, "vocab_size", None) or len(tokenizer)
    for tid in range(vocab_size):
        try:
            if tokenizer.decode([tid]) == ch:
                return tid
        except Exception:
            continue
    raise ValueError(f"Could not find token id for {ch!r}")


def _can_pair(a, b, allow_gu=True):
    """True if bases a, b form a Watson-Crick pair (or a G-U wobble pair)."""
    if (a, b) in [("A", "U"), ("U", "A"), ("G", "C"), ("C", "G")]:
        return True
    if allow_gu and (a, b) in [("G", "U"), ("U", "G")]:
        return True
    return False


def _precompute_can_open(seq, min_loop=3, allow_gu=True):
    """For each position i, flag whether any downstream partner j exists
    with at least `min_loop` unpaired bases between them (j > i + min_loop)."""
    n = len(seq)
    can = [False] * n
    for i in range(n):
        for j in range(i + min_loop + 1, n):
            if _can_pair(seq[i], seq[j], allow_gu):
                can[i] = True
                break
    return can


# -------------------------
# Constrained Logits Processor
# -------------------------
class BalancedParenProcessor(LogitsProcessor):
    """
    Restricts next token to one of '(' or ')' or '.',
    tracking balance and remaining positions.
    """

    def __init__(self, lp_id, rp_id, dot_id, total_len, can_open,
                 dot_bias=0.0, paren_penalty=0.0, window=5):
        # Token ids for '(', ')' and '.'
        self.lp_id, self.rp_id, self.dot_id = lp_id, rp_id, dot_id
        self.total_len = total_len      # exact output length to enforce
        self.step = 0                   # characters emitted so far
        self.depth = 0                  # currently open '(' count
        self.history = []               # emitted token ids (for paren_penalty)
        self.can_open = can_open        # per-position "may open a pair" flags
        self.dot_bias = dot_bias        # additive logit bonus for '.'
        self.paren_penalty = paren_penalty  # discourage long all-paren runs
        self.window = window            # run length that triggers the penalty

    def __call__(self, input_ids, scores):
        # Start from an all -inf mask and re-enable only legal tokens.
        mask = torch.full_like(scores, float("-inf"))
        remaining = self.total_len - self.step
        allowed = []
        # If every remaining slot is needed to close open pairs, force ')'.
        must_close = (remaining == self.depth and self.depth > 0)
        pos = self.step
        if must_close:
            allowed = [self.rp_id]
        else:
            if self.depth > 0:
                allowed.append(self.rp_id)
            # Opening needs room for its own ')' plus all currently open ones.
            if remaining - 2 >= self.depth and pos < len(self.can_open) and self.can_open[pos]:
                allowed.append(self.lp_id)
            allowed.append(self.dot_id)
        mask[:, allowed] = 0.0
        scores = scores + mask
        if self.dot_bias != 0.0:
            scores[:, self.dot_id] += self.dot_bias
        # Penalize both paren tokens after `window` consecutive paren emissions.
        if self.paren_penalty and len(self.history) >= self.window and all(
            t in (self.lp_id, self.rp_id) for t in self.history[-self.window:]
        ):
            scores[:, self.lp_id] -= self.paren_penalty
            scores[:, self.rp_id] -= self.paren_penalty
        return scores

    def update(self, tok):
        """Advance internal state after a token has actually been sampled."""
        if tok == self.lp_id:
            self.depth += 1
        elif tok == self.rp_id:
            self.depth = max(0, self.depth - 1)
        self.history.append(tok)
        self.step += 1


# -------------------------
# Sampling helpers
# -------------------------
def _top_p_sample(logits, top_p=0.9, temperature=0.8):
    """Nucleus (top-p) sampling over the last-dim logits; returns token ids."""
    logits = logits / temperature
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumsum = torch.cumsum(sorted_probs, dim=-1)
    mask = cumsum > top_p
    mask[..., 0] = False  # always keep at least the top token
    sorted_probs[mask] = 0
    # Renormalize; epsilon guards against an all-zero row.
    sorted_probs = sorted_probs / (sorted_probs.sum(dim=-1, keepdim=True) + 1e-12)
    idx = torch.multinomial(sorted_probs, 1)
    return sorted_idx.gather(-1, idx).squeeze(-1)


# -------------------------
# Core generation (uses loaded model/tokenizer)
# -------------------------
def _generate_db(seq: str, top_p=0.8, temperature=0.7, min_loop=2, greedy=False) -> str:
    """Generate a dot-bracket string of exactly len(seq) characters.

    Raises:
        RuntimeError: if init_model() has not been called.
    """
    if TOKENIZER is None or MODEL is None:
        raise RuntimeError("Model not initialized — call init_model() first.")
    tok = TOKENIZER
    model = MODEL
    n = len(seq)
    prompt = f"RNA: {seq}\nDot-bracket (exactly {n} characters using only '(' ')' '.'):\n"
    lp = _char_token_id(tok, "(")
    rp = _char_token_id(tok, ")")
    dot = _char_token_id(tok, ".")
    can = _precompute_can_open(seq, min_loop=min_loop, allow_gu=True)
    proc = BalancedParenProcessor(lp, rp, dot, n, can, dot_bias=0.0, paren_penalty=0.0)
    procs = LogitsProcessorList([proc])
    inputs = tok(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    cur = inputs["input_ids"]
    generated = []
    with torch.no_grad():
        for _ in range(n):
            out = model(cur)
            logits = out.logits[:, -1, :]
            for p in procs:
                logits = p(cur, logits)
            if greedy:
                # Greedy: pick highest allowed token (argmax on raw logits is
                # equivalent to argmax on softmax — softmax is monotonic).
                next_id = torch.argmax(logits, dim=-1)
            else:
                next_id = _top_p_sample(logits, top_p=top_p, temperature=temperature)
            tokid = int(next_id.item()) if isinstance(next_id, torch.Tensor) else int(next_id)
            generated.append(tokid)
            proc.update(tokid)
            cur = torch.cat([cur, next_id.view(1, 1).to(cur.device)], dim=1)
    text = tok.decode(generated, skip_special_tokens=True)
    # Keep only structural characters; pad/trim defensively to exactly n.
    db = "".join(c for c in text if c in "().")[:n]
    if len(db) != n:
        db = (db + "." * n)[:n]
    return db


# -------------------------
# Structural translation
# -------------------------
def dotbracket_to_structural(dot_str: str) -> str:
    """Translate a dot-bracket string into a structural-element string.

    NOTE(review): every element tag below is an empty string — the original
    tag text (e.g. stem/loop markers) appears to have been stripped from this
    file, so as written the function always returns "". Restore the intended
    tag literals before relying on this output. Behavior preserved as-is.
    """
    if not dot_str or not isinstance(dot_str, str):
        return ""
    res = [""]
    depth = 0
    i = 0
    n = len(dot_str)

    def add(tag):
        # Collapse consecutive identical tags.
        if res[-1] != tag:
            res.append(tag)

    while i < n:
        c = dot_str[i]
        if c == ".":
            # Consume the whole run of dots, then classify it by context.
            j = i
            while j < n and dot_str[j] == ".":
                j += 1
            nextc = dot_str[j] if j < n else None
            tag = "" if depth == 0 else ("" if nextc == ")" else "")
            add(tag)
            i = j
            continue
        if c == "(":
            add("")
            depth += 1
        else:  # ')'
            add("")
            depth = max(0, depth - 1)
        i += 1
    res.append("")
    return "".join(res)


# -------------------------
# Public API (evaluation-friendly)
# -------------------------
def init_model():
    """Initialize tokenizer & model (call once). Safe to call multiple times."""
    global TOKENIZER, MODEL
    TOKENIZER, MODEL = _load_model_and_tokenizer()
    # Try moving model to device if appropriate
    # (some HF device_map configs don't like .to()).
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        MODEL.to(device)
    except Exception:
        pass
    MODEL.eval()
    print("Model initialized.")


def predict_structure(seq: str, top_p: float = 0.8, temperature: float = 0.7,
                      greedy: bool = False, return_dotbracket: bool = False) -> str:
    """
    The evaluation harness expects a function like this.
    Returns the structural-element string by default.
    If return_dotbracket=True, returns dot-bracket.
    Errors are returned as strings (never raised) to stay harness-friendly.
    """
    try:
        seq = (seq or "").strip().upper()
        if not seq or not set(seq) <= {"A", "U", "C", "G"}:
            return "Please enter an RNA sequence (A/U/C/G)."
        db = _generate_db(seq, top_p=top_p, temperature=temperature, greedy=greedy)
        if return_dotbracket:
            return db
        return dotbracket_to_structural(db)
    except Exception as e:
        traceback.print_exc()
        return f"ERROR: {type(e).__name__}: {e}"


# -------------------------
# Gradio UI for Hugging Face Space
# -------------------------
# Initialize model at import to speed first prediction on Spaces (cached)
try:
    init_model()
except Exception:
    # Fail gracefully in spaces if model can't be loaded at import-time;
    # it will attempt again when used
    traceback.print_exc()


def _ui_predict(seq, top_p, temp, greedy, show_db):
    """Wrap predict_structure to be Gradio-friendly; optionally show both outputs."""
    pred_struct = predict_structure(seq, top_p=top_p, temperature=temp,
                                    greedy=greedy, return_dotbracket=False)
    if isinstance(pred_struct, str) and pred_struct.startswith("ERROR:"):
        return pred_struct, ""
    # NOTE(review): this runs generation a second time; with sampling enabled
    # the dot-bracket shown may not match the structure above — confirm intent.
    db = predict_structure(seq, top_p=top_p, temperature=temp,
                           greedy=greedy, return_dotbracket=True)
    if show_db:
        return pred_struct, db
    else:
        return pred_struct, ""


with gr.Blocks(title="RNA Structure Predictor (Constrained Generation)") as demo:
    gr.Markdown(
        """
    # RNA Structure Predictor
    Generates a dot-bracket structure constrained to `(`, `)` and `.` and converts it to structural elements: `, , , , , `.
    """
    )
    with gr.Row():
        seq_in = gr.Textbox(lines=2, label="RNA sequence (A/U/C/G)", value="GGGAAUCC")
        with gr.Column(scale=1):
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.01, label="top_p")
            temp = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.01, label="temperature")
            greedy = gr.Checkbox(value=False, label="Greedy (disable sampling)")
            show_db = gr.Checkbox(value=True, label="Show dot-bracket output")
    run_btn = gr.Button("Predict")
    out_struct = gr.Textbox(lines=6, label="Predicted structural elements")
    out_db = gr.Textbox(lines=3, label="Dot-bracket (optional)")
    run_btn.click(fn=_ui_predict,
                  inputs=[seq_in, top_p, temp, greedy, show_db],
                  outputs=[out_struct, out_db])
    gr.Markdown(
        "Notes: The model uses a constrained logits processor to ensure balanced parentheses and valid dot-bracket length. "
        "You can tune `top_p`/`temperature` or enable greedy for deterministic output."
    )

if __name__ == "__main__":
    demo.launch()