# Hugging Face Space: app.py (page status banner removed)
| # app.py | |
| import os | |
| import torch | |
| import traceback | |
| from functools import lru_cache | |
| from typing import Tuple | |
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList | |
# -------------------------
# Config
# -------------------------
# Hugging Face Hub repo id of the RNA structure model; overridable via env.
MODEL_ID = os.getenv("MODEL_ID", "llm-rna-api-rmit/rna-structure-model")
# Optionally set HF_TOKEN if using private HF repo:
HF_TOKEN = os.getenv("HF_TOKEN", None)
# Global placeholders (populated by init_model)
TOKENIZER = None
MODEL = None
| # ------------------------- | |
| # Utility helpers | |
| # ------------------------- | |
def _load_model_and_tokenizer() -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
    """
    Load the tokenizer and model exactly once.

    On CUDA machines the model is loaded in float16 with device_map="auto";
    on CPU it falls back to float32 with no device map.
    """
    on_gpu = torch.cuda.is_available()
    # Pass the auth token only when one is configured (private HF repos).
    auth_kwargs = {"use_auth_token": HF_TOKEN} if HF_TOKEN else {}
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, **auth_kwargs)
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16 if on_gpu else torch.float32,
        device_map="auto" if on_gpu else None,
        **auth_kwargs,
    )
    return tokenizer, model
| def _char_token_id(tokenizer, ch: str) -> int: | |
| # Prefer exact single-char token if it exists | |
| ids = tokenizer.encode(ch, add_special_tokens=False) | |
| for tid in ids: | |
| if tokenizer.decode([tid]) == ch: | |
| return tid | |
| # Fallback: scan vocab (conservative) | |
| vocab_size = getattr(tokenizer, "vocab_size", None) or len(tokenizer) | |
| for tid in range(vocab_size): | |
| try: | |
| if tokenizer.decode([tid]) == ch: | |
| return tid | |
| except Exception: | |
| continue | |
| raise ValueError(f"Could not find token id for {ch!r}") | |
| def _can_pair(a, b, allow_gu=True): | |
| if (a, b) in [("A","U"),("U","A"),("G","C"),("C","G")]: | |
| return True | |
| if allow_gu and (a, b) in [("G","U"),("U","G")]: | |
| return True | |
| return False | |
def _precompute_can_open(seq, min_loop=3, allow_gu=True):
    """
    For each position i, report whether seq[i] could open a base pair with some
    partner j >= i + min_loop + 1 (i.e. leaving at least min_loop unpaired
    bases between them).
    """
    n = len(seq)
    return [
        any(_can_pair(seq[i], seq[j], allow_gu) for j in range(i + min_loop + 1, n))
        for i in range(n)
    ]
| # ------------------------- | |
| # Constrained Logits Processor | |
| # ------------------------- | |
class BalancedParenProcessor(LogitsProcessor):
    """
    Masks logits so generation emits only '(' / ')' / '.', keeping brackets
    balanced and guaranteeing every open bracket can still be closed within
    the remaining positions.
    """

    def __init__(self, lp_id, rp_id, dot_id, total_len, can_open,
                 dot_bias=0.0, paren_penalty=0.0, window=5):
        self.lp_id = lp_id
        self.rp_id = rp_id
        self.dot_id = dot_id
        self.total_len = total_len
        self.step = 0        # structure characters emitted so far
        self.depth = 0       # currently unmatched '(' count
        self.history = []    # token ids emitted so far
        self.can_open = can_open
        self.dot_bias = dot_bias
        self.paren_penalty = paren_penalty
        self.window = window

    def __call__(self, input_ids, scores):
        remaining = self.total_len - self.step
        pos = self.step
        # Work out which of the three structure tokens are legal here.
        if self.depth > 0 and remaining == self.depth:
            # Every remaining slot is needed to close already-open brackets.
            allowed = [self.rp_id]
        else:
            allowed = []
            if self.depth > 0:
                allowed.append(self.rp_id)
            # Opening needs room for its own ')' on top of the open ones, and
            # position pos must have a compatible downstream partner.
            if remaining - 2 >= self.depth and pos < len(self.can_open) and self.can_open[pos]:
                allowed.append(self.lp_id)
            allowed.append(self.dot_id)
        mask = torch.full_like(scores, float("-inf"))
        mask[:, allowed] = 0.0
        scores = scores + mask
        if self.dot_bias != 0.0:
            scores[:, self.dot_id] += self.dot_bias
        # Optionally discourage long unbroken runs of parentheses.
        recent = self.history[-self.window:]
        if self.paren_penalty and len(self.history) >= self.window and all(
            t in (self.lp_id, self.rp_id) for t in recent
        ):
            scores[:, self.lp_id] -= self.paren_penalty
            scores[:, self.rp_id] -= self.paren_penalty
        return scores

    def update(self, tok):
        """Advance the balance/step/history state after *tok* is committed."""
        if tok == self.lp_id:
            self.depth += 1
        elif tok == self.rp_id:
            self.depth = max(0, self.depth - 1)
        self.history.append(tok)
        self.step += 1
| # ------------------------- | |
| # Sampling helpers | |
| # ------------------------- | |
| def _top_p_sample(logits, top_p=0.9, temperature=0.8): | |
| logits = logits / temperature | |
| probs = torch.softmax(logits, dim=-1) | |
| sorted_probs, sorted_idx = torch.sort(probs, descending=True) | |
| cumsum = torch.cumsum(sorted_probs, dim=-1) | |
| mask = cumsum > top_p | |
| mask[..., 0] = False | |
| sorted_probs[mask] = 0 | |
| sorted_probs = sorted_probs / (sorted_probs.sum(dim=-1, keepdim=True) + 1e-12) | |
| idx = torch.multinomial(sorted_probs, 1) | |
| return sorted_idx.gather(-1, idx).squeeze(-1) | |
| # ------------------------- | |
| # Core generation (uses loaded model/tokenizer) | |
| # ------------------------- | |
def _generate_db(seq: str, top_p=0.8, temperature=0.7, min_loop=2, greedy=False) -> str:
    """
    Generate a dot-bracket string of exactly len(seq) characters for *seq*,
    constraining every decoding step to '(' / ')' / '.' via
    BalancedParenProcessor.

    Args:
        seq: RNA sequence (A/U/C/G); the caller is expected to validate it.
        top_p, temperature: nucleus-sampling knobs (ignored when greedy=True).
        min_loop: minimum unpaired bases required inside a hairpin loop.
        greedy: pick the argmax token instead of sampling.

    Returns:
        A dot-bracket string, padded/truncated to exactly len(seq) chars.

    Raises:
        RuntimeError: if init_model() has not populated TOKENIZER/MODEL yet.
    """
    if TOKENIZER is None or MODEL is None:
        raise RuntimeError("Model not initialized — call init_model() first.")
    tok = TOKENIZER
    model = MODEL
    n = len(seq)
    prompt = f"RNA: {seq}\nDot-bracket (exactly {n} characters using only '(' ')' '.'):\n"
    lp = _char_token_id(tok, "(")
    rp = _char_token_id(tok, ")")
    dot = _char_token_id(tok, ".")
    can = _precompute_can_open(seq, min_loop=min_loop, allow_gu=True)
    proc = BalancedParenProcessor(lp, rp, dot, n, can, dot_bias=0.0, paren_penalty=0.0)
    procs = LogitsProcessorList([proc])
    inputs = tok(prompt, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    cur = inputs["input_ids"]
    generated = []
    past = None
    with torch.no_grad():
        for _ in range(n):
            # Use the KV cache: after the first step, feed only the newest
            # token instead of re-running the whole prefix every iteration
            # (O(n) forward passes over 1 token vs the original O(n^2)).
            step_input = cur if past is None else cur[:, -1:]
            out = model(step_input, past_key_values=past, use_cache=True)
            past = out.past_key_values
            logits = out.logits[:, -1, :]
            for p in procs:
                logits = p(cur, logits)
            if greedy:
                # argmax directly on the masked logits; softmax is monotone,
                # so the original softmax-then-argmax was redundant work.
                next_id = torch.argmax(logits, dim=-1)
            else:
                next_id = _top_p_sample(logits, top_p=top_p, temperature=temperature)
            tokid = int(next_id.item()) if isinstance(next_id, torch.Tensor) else int(next_id)
            generated.append(tokid)
            proc.update(tokid)
            cur = torch.cat([cur, next_id.view(1, 1).to(cur.device)], dim=1)
    text = tok.decode(generated, skip_special_tokens=True)
    # Keep only structure characters and force the exact target length.
    db = "".join(c for c in text if c in "().")[:n]
    if len(db) != n:
        db = (db + "." * n)[:n]
    return db
| # ------------------------- | |
| # Structural translation | |
| # ------------------------- | |
def dotbracket_to_structural(dot_str: str) -> str:
    """
    Translate a dot-bracket string into a run-length-collapsed tag sequence:
    <start>, <stem>, <hairpin>, <internal_loop>, <external_loop>, <end>.

    Empty or non-string input yields "<start><external_loop><end>".
    """
    if not dot_str or not isinstance(dot_str, str):
        return "<start><external_loop><end>"
    tags = ["<start>"]
    depth = 0
    i = 0
    n = len(dot_str)

    def emit(tag):
        # Collapse consecutive identical tags into one occurrence.
        if tags[-1] != tag:
            tags.append(tag)

    while i < n:
        ch = dot_str[i]
        if ch == ".":
            # Consume the whole run of unpaired bases at once.
            j = i
            while j < n and dot_str[j] == ".":
                j += 1
            follower = dot_str[j] if j < n else None
            if depth == 0:
                emit("<external_loop>")
            elif follower == ")":
                emit("<hairpin>")
            else:
                emit("<internal_loop>")
            i = j
            continue
        emit("<stem>")
        if ch == "(":
            depth += 1
        else:  # ')'
            depth = max(0, depth - 1)
        i += 1
    tags.append("<end>")
    return "".join(tags)
| # ------------------------- | |
| # Public API (evaluation-friendly) | |
| # ------------------------- | |
def init_model():
    """
    Populate the module-level TOKENIZER/MODEL globals (safe to call more
    than once; each call reloads from the hub cache).
    """
    global TOKENIZER, MODEL
    TOKENIZER, MODEL = _load_model_and_tokenizer()
    # Best-effort device placement: models loaded with device_map may reject
    # an explicit .to(), so failures here are deliberately ignored.
    try:
        MODEL.to("cuda" if torch.cuda.is_available() else "cpu")
    except Exception:
        pass
    MODEL.eval()
    print("Model initialized.")
def predict_structure(seq: str,
                      top_p: float = 0.8,
                      temperature: float = 0.7,
                      greedy: bool = False,
                      return_dotbracket: bool = False) -> str:
    """
    Public entry point used by the evaluation harness and the UI.

    Args:
        seq: RNA sequence over A/U/C/G (case-insensitive, surrounding
            whitespace tolerated).
        top_p, temperature: sampling knobs forwarded to generation.
        greedy: deterministic argmax decoding instead of sampling.
        return_dotbracket: return the raw dot-bracket instead of the
            structural-element string.

    Returns:
        The structural-element string by default, the dot-bracket when
        return_dotbracket=True, a validation message for bad input, or an
        "ERROR: ..." string if generation fails.
    """
    try:
        seq = (seq or "").strip().upper()
        if not seq or not set(seq) <= {"A", "U", "C", "G"}:
            return "Please enter an RNA sequence (A/U/C/G)."
        # Lazy (re)initialization: the import-time init_model() call may have
        # failed on Spaces; retry here so later requests can still succeed
        # instead of hitting RuntimeError forever.
        if TOKENIZER is None or MODEL is None:
            init_model()
        db = _generate_db(seq, top_p=top_p, temperature=temperature, greedy=greedy)
        if return_dotbracket:
            return db
        return dotbracket_to_structural(db)
    except Exception as e:
        traceback.print_exc()
        return f"ERROR: {type(e).__name__}: {e}"
| # ------------------------- | |
| # Gradio UI for Hugging Face Space | |
| # ------------------------- | |
# Initialize model at import to speed first prediction on Spaces (cached)
try:
    init_model()
except Exception:
    # Best-effort: if loading fails at import time, log the traceback and
    # continue so the Space UI still comes up (TOKENIZER/MODEL stay None).
    traceback.print_exc()
def _ui_predict(seq, top_p, temp, greedy, show_db):
    """
    Gradio wrapper that runs generation exactly ONCE and derives both outputs
    from the same dot-bracket.

    The original implementation called predict_structure twice (once for the
    structural elements, once for the dot-bracket), sampling two independent
    structures — so the two textboxes could describe different predictions.

    Returns:
        (structural_elements, dot_bracket_or_empty) tuple for the two
        output textboxes; on error/validation failure the message goes in
        the first slot and the second is empty.
    """
    db = predict_structure(seq, top_p=top_p, temperature=temp,
                           greedy=greedy, return_dotbracket=True)
    # Error or validation messages pass straight through to the first box.
    if isinstance(db, str) and (db.startswith("ERROR:") or db.startswith("Please enter")):
        return db, ""
    structural = dotbracket_to_structural(db)
    return structural, (db if show_db else "")
# UI layout: sequence input + sampling controls, a run button, and two
# output boxes (structural elements, optional dot-bracket).
with gr.Blocks(title="RNA Structure Predictor (Constrained Generation)") as demo:
    # Header / usage description shown at the top of the Space.
    gr.Markdown(
        """
        # RNA Structure Predictor
        Generates a dot-bracket structure constrained to `(`, `)` and `.` and converts it to structural elements:
        `<start>, <stem>, <hairpin>, <internal_loop>, <external_loop>, <end>`.
        """
    )
    with gr.Row():
        seq_in = gr.Textbox(lines=2, label="RNA sequence (A/U/C/G)", value="GGGAAUCC")
        with gr.Column(scale=1):
            top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.01, label="top_p")
            temp = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.01, label="temperature")
            greedy = gr.Checkbox(value=False, label="Greedy (disable sampling)")
            show_db = gr.Checkbox(value=True, label="Show dot-bracket output")
    run_btn = gr.Button("Predict")
    out_struct = gr.Textbox(lines=6, label="Predicted structural elements")
    out_db = gr.Textbox(lines=3, label="Dot-bracket (optional)")
    # Wire the button to the single-generation wrapper.
    run_btn.click(fn=_ui_predict,
                  inputs=[seq_in, top_p, temp, greedy, show_db],
                  outputs=[out_struct, out_db])
    gr.Markdown(
        "Notes: The model uses a constrained logits processor to ensure balanced parentheses and valid dot-bracket length. "
        "You can tune `top_p`/`temperature` or enable greedy for deterministic output."
    )

# Start the Gradio server when executed as a script.
if __name__ == "__main__":
    demo.launch()