# app.py
import os
import torch
import traceback
from functools import lru_cache
from typing import Tuple
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
# -------------------------
# Config
# -------------------------
MODEL_ID = os.getenv("MODEL_ID", "llm-rna-api-rmit/rna-structure-model")
# Optionally set HF_TOKEN if using private HF repo:
HF_TOKEN = os.getenv("HF_TOKEN", None)
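# Example launch with overrides (placeholder values):
#   MODEL_ID=your-org/your-model HF_TOKEN=hf_... python app.py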
# Global placeholders (populated by init_model)
TOKENIZER = None
MODEL = None
# -------------------------
# Utility helpers
# -------------------------
@lru_cache(maxsize=1)
def _load_model_and_tokenizer() -> Tuple[AutoTokenizer, AutoModelForCausalLM]:
"""
Load tokenizer + model once. Use float16 if CUDA present.
"""
device = "cuda" if torch.cuda.is_available() else "cpu"
    # Newer transformers releases use `token=` (`use_auth_token=` is deprecated).
    use_auth = {"token": HF_TOKEN} if HF_TOKEN else {}
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, **use_auth)
model = AutoModelForCausalLM.from_pretrained(
MODEL_ID,
torch_dtype=torch.float16 if device == "cuda" else torch.float32,
device_map="auto" if device == "cuda" else None,
**use_auth
)
return tokenizer, model
def _char_token_id(tokenizer, ch: str) -> int:
# Prefer exact single-char token if it exists
ids = tokenizer.encode(ch, add_special_tokens=False)
for tid in ids:
if tokenizer.decode([tid]) == ch:
return tid
# Fallback: scan vocab (conservative)
vocab_size = getattr(tokenizer, "vocab_size", None) or len(tokenizer)
for tid in range(vocab_size):
try:
if tokenizer.decode([tid]) == ch:
return tid
except Exception:
continue
raise ValueError(f"Could not find token id for {ch!r}")
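# e.g. _char_token_id(tok, "(") returns an id whose single-token decode is exactly
# "(", falling back to a full vocabulary scan if encode() splits or prefixes it.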
def _can_pair(a, b, allow_gu=True):
if (a, b) in [("A","U"),("U","A"),("G","C"),("C","G")]:
return True
if allow_gu and (a, b) in [("G","U"),("U","G")]:
return True
return False
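# e.g. _can_pair("G", "C") -> True (Watson-Crick); _can_pair("G", "U") -> True
# only when allow_gu is set (wobble pair); _can_pair("A", "G") -> False.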
def _precompute_can_open(seq, min_loop=3, allow_gu=True):
    """For each position i, flag whether any base at least min_loop + 1 positions downstream can pair with seq[i]."""
    n = len(seq)
can = [False] * n
for i in range(n):
for j in range(i + min_loop + 1, n):
if _can_pair(seq[i], seq[j], allow_gu):
can[i] = True
break
return can
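# e.g. for seq = "GCAAAGC" with min_loop=3, can[0] is True (the 'G' at 0 can
# reach the 'C' at position 6), while positions with no pairable base at least
# min_loop + 1 steps downstream stay False.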
# -------------------------
# Constrained Logits Processor
# -------------------------
class BalancedParenProcessor(LogitsProcessor):
"""
Restricts next token to one of '(' or ')' or '.', tracking balance and remaining positions.
"""
def __init__(self, lp_id, rp_id, dot_id, total_len, can_open,
dot_bias=0.0, paren_penalty=0.0, window=5):
self.lp_id, self.rp_id, self.dot_id = lp_id, rp_id, dot_id
self.total_len = total_len
self.step = 0
self.depth = 0
self.history = []
self.can_open = can_open
self.dot_bias = dot_bias
self.paren_penalty = paren_penalty
self.window = window
    def __call__(self, input_ids, scores):
        mask = torch.full_like(scores, float("-inf"))
        remaining = self.total_len - self.step
        allowed = []
        # If every remaining position is needed to close open pairs, force ')'.
        must_close = (remaining == self.depth and self.depth > 0)
        pos = self.step
        if must_close:
            allowed = [self.rp_id]
        else:
            if self.depth > 0:
                allowed.append(self.rp_id)
            # Opening here needs room for its own ')' plus all currently open
            # pairs (remaining - 2 >= depth) and a pairable downstream base.
            if remaining - 2 >= self.depth and pos < len(self.can_open) and self.can_open[pos]:
                allowed.append(self.lp_id)
            allowed.append(self.dot_id)
mask[:, allowed] = 0.0
scores = scores + mask
if self.dot_bias != 0.0:
scores[:, self.dot_id] += self.dot_bias
if self.paren_penalty and len(self.history) >= self.window and all(
t in (self.lp_id, self.rp_id) for t in self.history[-self.window:]
):
scores[:, self.lp_id] -= self.paren_penalty
scores[:, self.rp_id] -= self.paren_penalty
return scores
    def update(self, tok):
        # Advance the constraint state; call once per emitted token.
if tok == self.lp_id:
self.depth += 1
elif tok == self.rp_id:
self.depth = max(0, self.depth - 1)
self.history.append(tok)
self.step += 1
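# Sanity helper (illustrative; not wired into the app): checks the invariant
# BalancedParenProcessor is meant to enforce -- the nesting depth never goes
# negative and ends at zero.
def _is_balanced(db: str) -> bool:
    depth = 0
    for c in db:
        if c == "(":
            depth += 1
        elif c == ")":
            depth -= 1
            if depth < 0:
                return False
    return depth == 0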
# -------------------------
# Sampling helpers
# -------------------------
def _top_p_sample(logits, top_p=0.9, temperature=0.8):
    """Nucleus (top-p) sampling over [batch, vocab] logits; returns token ids of shape [batch]."""
    logits = logits / temperature
    probs = torch.softmax(logits, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumsum = torch.cumsum(sorted_probs, dim=-1)
    mask = cumsum > top_p
    mask[..., 0] = False  # always keep at least the most probable token
sorted_probs[mask] = 0
sorted_probs = sorted_probs / (sorted_probs.sum(dim=-1, keepdim=True) + 1e-12)
idx = torch.multinomial(sorted_probs, 1)
return sorted_idx.gather(-1, idx).squeeze(-1)
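# Illustrative call (hypothetical logits): _top_p_sample(torch.tensor([[2.0, 1.0, -1.0]]))
# returns a shape-[1] tensor holding the sampled token id, drawn only from the
# highest-probability tokens whose cumulative mass stays within top_p (the top
# token is always kept).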
# -------------------------
# Core generation (uses loaded model/tokenizer)
# -------------------------
def _generate_db(seq: str, top_p=0.8, temperature=0.7, min_loop=2, greedy=False) -> str:
    """Generate a dot-bracket string of exactly len(seq) characters, one constrained token at a time."""
    if TOKENIZER is None or MODEL is None:
        # Lazy retry in case import-time initialization failed.
        init_model()
tok = TOKENIZER
model = MODEL
n = len(seq)
prompt = f"RNA: {seq}\nDot-bracket (exactly {n} characters using only '(' ')' '.'):\n"
lp = _char_token_id(tok, "(")
rp = _char_token_id(tok, ")")
dot = _char_token_id(tok, ".")
can = _precompute_can_open(seq, min_loop=min_loop, allow_gu=True)
proc = BalancedParenProcessor(lp, rp, dot, n, can, dot_bias=0.0, paren_penalty=0.0)
procs = LogitsProcessorList([proc])
inputs = tok(prompt, return_tensors="pt")
inputs = {k: v.to(model.device) for k, v in inputs.items()}
cur = inputs["input_ids"]
generated = []
with torch.no_grad():
for _ in range(n):
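            # Full forward pass each step (no KV cache): simple but O(n^2) in
            # sequence length, which is acceptable for short RNA inputs.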
out = model(cur)
logits = out.logits[:, -1, :]
for p in procs:
logits = p(cur, logits)
            if greedy:
                # Greedy: argmax over the masked logits (softmax is monotone, so it is redundant).
                next_id = torch.argmax(logits, dim=-1)
else:
next_id = _top_p_sample(logits, top_p=top_p, temperature=temperature)
tokid = int(next_id.item()) if isinstance(next_id, torch.Tensor) else int(next_id)
generated.append(tokid)
proc.update(tokid)
cur = torch.cat([cur, next_id.view(1, 1).to(cur.device)], dim=1)
text = tok.decode(generated, skip_special_tokens=True)
db = "".join(c for c in text if c in "().")[:n]
if len(db) != n:
db = (db + "." * n)[:n]
return db
# -------------------------
# Structural translation
# -------------------------
def dotbracket_to_structural(dot_str: str) -> str:
    """Translate a dot-bracket string into a run-length-collapsed sequence of structural-element tags."""
    if not dot_str or not isinstance(dot_str, str):
        return "<start><external_loop><end>"
    res = ["<start>"]
    depth = 0
    i = 0
    n = len(dot_str)
def add(tag):
if res[-1] != tag:
res.append(tag)
while i < n:
c = dot_str[i]
if c == ".":
j = i
while j < n and dot_str[j] == ".":
j += 1
            nextc = dot_str[j] if j < n else None
            # Heuristic labelling: a run at depth 0 is an external loop; a run
            # immediately closed by ')' is a hairpin; anything else is an internal loop.
            tag = "<external_loop>" if depth == 0 else ("<hairpin>" if nextc == ")" else "<internal_loop>")
            add(tag)
            i = j
            continue
if c == "(":
add("<stem>"); depth += 1
else: # ')'
add("<stem>"); depth = max(0, depth - 1)
i += 1
res.append("<end>")
return "".join(res)
# -------------------------
# Public API (evaluation-friendly)
# -------------------------
def init_model():
"""
Initialize tokenizer & model (call once). Safe to call multiple times.
"""
global TOKENIZER, MODEL
TOKENIZER, MODEL = _load_model_and_tokenizer()
# Try moving model to device if appropriate (some HF device_map configs don't like .to())
try:
device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL.to(device)
except Exception:
pass
MODEL.eval()
print("Model initialized.")
def predict_structure(seq: str,
top_p: float = 0.8,
temperature: float = 0.7,
greedy: bool = False,
return_dotbracket: bool = False) -> str:
"""
The evaluation harness expects a function like this. It returns the
structural-element string by default. If return_dotbracket=True, returns dot-bracket.
"""
try:
seq = (seq or "").strip().upper()
if not seq or not set(seq) <= {"A", "U", "C", "G"}:
return "Please enter an RNA sequence (A/U/C/G)."
db = _generate_db(seq, top_p=top_p, temperature=temperature, greedy=greedy)
if return_dotbracket:
return db
return dotbracket_to_structural(db)
except Exception as e:
traceback.print_exc()
return f"ERROR: {type(e).__name__}: {e}"
# -------------------------
# Gradio UI for Hugging Face Space
# -------------------------
# Initialize model at import to speed first prediction on Spaces (cached)
try:
init_model()
except Exception:
    # Fail gracefully on Spaces if the model can't load at import time;
    # _generate_db retries initialization on first use.
traceback.print_exc()
def _ui_predict(seq, top_p, temp, greedy, show_db):
    # Generate once and derive both outputs from the same dot-bracket; calling
    # predict_structure twice would sample two inconsistent structures.
    db = predict_structure(seq, top_p=top_p, temperature=temp, greedy=greedy, return_dotbracket=True)
    if db.startswith("ERROR:") or db.startswith("Please enter"):
        return db, ""
    pred_struct = dotbracket_to_structural(db)
    return pred_struct, (db if show_db else "")
with gr.Blocks(title="RNA Structure Predictor (Constrained Generation)") as demo:
gr.Markdown(
"""
# RNA Structure Predictor
Generates a dot-bracket structure constrained to `(`, `)` and `.` and converts it to structural elements:
`<start>, <stem>, <hairpin>, <internal_loop>, <external_loop>, <end>`.
"""
)
with gr.Row():
seq_in = gr.Textbox(lines=2, label="RNA sequence (A/U/C/G)", value="GGGAAUCC")
with gr.Column(scale=1):
top_p = gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.01, label="top_p")
temp = gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.01, label="temperature")
greedy = gr.Checkbox(value=False, label="Greedy (disable sampling)")
show_db = gr.Checkbox(value=True, label="Show dot-bracket output")
run_btn = gr.Button("Predict")
out_struct = gr.Textbox(lines=6, label="Predicted structural elements")
out_db = gr.Textbox(lines=3, label="Dot-bracket (optional)")
run_btn.click(fn=_ui_predict,
inputs=[seq_in, top_p, temp, greedy, show_db],
outputs=[out_struct, out_db])
gr.Markdown(
"Notes: The model uses a constrained logits processor to ensure balanced parentheses and valid dot-bracket length. "
"You can tune `top_p`/`temperature` or enable greedy for deterministic output."
)
if __name__ == "__main__":
demo.launch()