1-1-3-8's picture
Update app.py
09334f5 verified
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch, re
# Hub repository id passed to both `from_pretrained` calls below.
MODEL_ID = "llm-rna-api-rmit/rna-structure-model" # your uploaded model
# Tokenizer and model are loaded once, at import time, and reused by
# `_generate` for every prediction.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
# Matches a string made up solely of dot-bracket characters "(", ")" and "."
# (at least one character).
DB_FULL = re.compile(r"^[().]+$")
# Finds runs of five or more consecutive dot-bracket characters inside
# longer text.
DB_SCAN = re.compile(r"[().]{5,}")
def _generate(prompt, max_new_tokens=512, temperature=0.0):
    """Run the language model on *prompt* and return the decoded output.

    Sampling is enabled only when ``temperature`` is greater than zero;
    otherwise ``do_sample`` is ``False`` and no temperature is forwarded.

    Args:
        prompt: Text that is tokenized and fed to the model.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature. Values <= 0 disable sampling.

    Returns:
        The first returned sequence, decoded with special tokens skipped.
        NOTE(review): for a causal LM this presumably still contains the
        prompt text in front of the completion -- confirm against callers.
    """
    sampling = temperature > 0
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": sampling,
        # The EOS token doubles as the padding token.
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    # `temperature` is a sampling-only flag: forwarding it (especially 0.0)
    # together with do_sample=False makes transformers emit a
    # generation-config validation warning on every call.
    if sampling:
        generation_kwargs["temperature"] = temperature

    with torch.no_grad():
        inputs = tokenizer(prompt, return_tensors="pt")
        outputs = model.generate(**inputs, **generation_kwargs)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
def _extract_dotbracket(text, length):
    """Pull a dot-bracket string of exactly *length* characters out of *text*.

    The whitespace-trimmed text is returned as-is when it is entirely
    dot-bracket characters and has the requested length. Otherwise the
    first run found by ``DB_SCAN`` whose length equals *length* is
    returned.

    Args:
        text: Raw text to search.
        length: Required number of characters in the result.

    Returns:
        The matching dot-bracket string, or ``None`` when nothing fits.
    """
    trimmed = text.strip()
    whole_match = DB_FULL.match(trimmed) if len(trimmed) == length else None
    if whole_match:
        return trimmed

    # Fall back to scanning the untrimmed text for an embedded run.
    runs = (hit.group(0) for hit in DB_SCAN.finditer(text))
    return next((run for run in runs if len(run) == length), None)
def predict(seq):
    """Predict the dot-bracket secondary structure for an RNA sequence.

    Args:
        seq: RNA sequence text. Case-insensitive; all whitespace (including
            line breaks inside a wrapped sequence) is ignored. ``None`` is
            treated as empty input.

    Returns:
        A dot-bracket string with one character per nucleotide when one can
        be recovered from the model output; otherwise the stripped raw model
        output. For empty input or input containing characters other than
        A/U/C/G, a short instruction message is returned instead.
    """
    # Remove every whitespace character, not only leading/trailing ones, so
    # a sequence pasted across several lines or with spaces is still valid.
    seq = "".join((seq or "").split()).upper()
    if not seq or not set(seq) <= {"A", "U", "C", "G"}:
        return "Please enter an RNA sequence (A/U/C/G)."

    n = len(seq)
    prompt = f"RNA: {seq}\nDot-bracket structure:"
    text = _generate(prompt, max_new_tokens=n + 10, temperature=0.0)

    # First attempt: keep only structure characters from the whole output.
    # The prompt template itself contains none of "(", ")" or ".".
    db = "".join(c for c in text if c in "().")
    if len(db) == n:
        return db

    # Second attempt: a contiguous run of structure characters whose length
    # matches the sequence.
    for m in DB_SCAN.finditer(text):
        cand = m.group(0)
        if len(cand) == n:
            return cand

    # Nothing usable was found; show the raw model output to the user.
    return text.strip()
# Gradio UI: one multi-line text input wired to `predict`, with the returned
# string shown in a multi-line text output.
demo = gr.Interface(
fn=predict,
inputs=gr.Textbox(lines=4, label="RNA Sequence (A/U/C/G)"),
outputs=gr.Textbox(lines=6, label="Predicted Dot-Bracket Structure"),
title="RNA Structure Predictor",
description="Uses your fine-tuned model to output RNA secondary structure in dot-bracket notation."
)
# Launch the Gradio app only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    demo.launch()