"""Gradio demo that predicts RNA secondary structure in dot-bracket notation.

A fine-tuned causal language model is prompted with an RNA sequence, and its
generated continuation is parsed for a dot-bracket string whose length equals
the sequence length.
"""

import re

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_ID = "llm-rna-api-rmit/rna-structure-model"  # your uploaded model

# Loaded once at import time so every request reuses the same weights.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)

# A string consisting solely of dot-bracket characters.
DB_FULL = re.compile(r"^[().]+$")
# Contiguous runs of at least 5 dot-bracket characters inside noisier text.
DB_SCAN = re.compile(r"[().]{5,}")

RNA_ALPHABET = frozenset("ACGU")


def _generate(prompt, max_new_tokens=512, temperature=0.0):
    """Run the model on *prompt* and return only the newly generated text.

    Args:
        prompt: Text fed to the model.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: 0 (or less) selects greedy decoding; a positive value
            enables sampling at that temperature.

    Returns:
        The decoded continuation, without the prompt and without special tokens.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if temperature > 0:
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = temperature
    else:
        # Greedy decoding ignores temperature; passing 0.0 would only make
        # transformers emit a warning (and is invalid when sampling).
        gen_kwargs["do_sample"] = False
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)
    # For causal LMs generate() returns prompt + continuation; slice off the
    # prompt so callers never parse or display the echoed input.
    prompt_len = inputs["input_ids"].shape[1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)


def _extract_dotbracket(text, length):
    """Find a dot-bracket string of exactly *length* characters in *text*.

    Strategies, in order:
      1. the whole stripped text, if it is pure dot-bracket;
      2. the first contiguous dot-bracket run (5+ chars) of the right length;
      3. all dot-bracket characters in *text* concatenated, which recovers a
         structure the model split with whitespace or line breaks.

    Returns:
        The matching string, or ``None`` if no strategy yields *length* chars.
    """
    stripped = text.strip()
    if len(stripped) == length and DB_FULL.match(stripped):
        return stripped
    for match in DB_SCAN.finditer(text):
        candidate = match.group(0)
        if len(candidate) == length:
            return candidate
    joined = "".join(c for c in text if c in "().")
    if len(joined) == length:
        return joined
    return None


def predict(seq):
    """Predict the dot-bracket secondary structure of an RNA sequence.

    Whitespace anywhere in *seq* (e.g. line breaks from a pasted multi-line
    sequence) is ignored, and letters are upper-cased before validation.

    Returns:
        The dot-bracket string when one matching the sequence length can be
        extracted; otherwise the raw generated text (or an explanatory
        message if the model produced nothing, or if the input is invalid).
    """
    seq = "".join((seq or "").split()).upper()
    if not seq or not set(seq) <= RNA_ALPHABET:
        return "Please enter an RNA sequence (A/U/C/G)."
    n = len(seq)
    prompt = f"RNA: {seq}\nDot-bracket structure:"
    # n + 10 tokens leaves headroom beyond one token per structure character.
    text = _generate(prompt, max_new_tokens=n + 10, temperature=0.0)
    db = _extract_dotbracket(text, n)
    if db is not None:
        return db
    # Show what the model actually produced so the user can inspect it.
    return text.strip() or "The model did not produce a dot-bracket structure."


demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=4, label="RNA Sequence (A/U/C/G)"),
    outputs=gr.Textbox(lines=6, label="Predicted Dot-Bracket Structure"),
    title="RNA Structure Predictor",
    description="Uses your fine-tuned model to output RNA secondary structure in dot-bracket notation.",
)

if __name__ == "__main__":
    demo.launch()