Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import torch, re | |
# Patterns used to recognise dot-bracket text in model output.
# DB_FULL: the whole string is made only of "(", ")" and "." characters.
DB_FULL = re.compile(r"^[().]+$")
# DB_SCAN: any run of five or more consecutive "(", ")" or "." characters.
DB_SCAN = re.compile(r"[().]{5,}")

# Repo id (or local path) of the fine-tuned causal LM. The tokenizer and weights
# are loaded once, at import time, and reused by every generation call.
MODEL_ID = "llm-rna-api-rmit/rna-structure-model"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID)
def _generate(prompt, max_new_tokens=512, temperature=0.0):
    """Run the module-level causal LM on *prompt* and return the new text.

    Args:
        prompt: Text fed to the model.
        max_new_tokens: Upper bound on the number of tokens to generate.
        temperature: Sampling temperature. A value ``<= 0`` selects greedy
            decoding; a positive value enables sampling at that temperature.

    Returns:
        The decoded completion with special tokens removed. The prompt itself
        is NOT included, so callers never mistake the echoed input for output.
    """
    do_sample = temperature > 0
    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.eos_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    # Only pass a temperature when sampling: greedy decoding ignores it, and
    # transformers warns about e.g. temperature=0.0 with do_sample=False.
    if do_sample:
        gen_kwargs["temperature"] = temperature

    # Keep the input tensors on whatever device the model lives on.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # For decoder-only models generate() returns prompt + completion; slice
    # off the prompt tokens so only the newly generated text is decoded.
    prompt_len = inputs["input_ids"].shape[-1]
    return tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
def _extract_dotbracket(text, length):
    """Return a dot-bracket string of exactly ``length`` characters from ``text``.

    Two strategies are tried in order:
      1. the whole text, with surrounding whitespace removed, if it consists
         solely of "(", ")" and "." and has the requested length;
      2. the first run of five or more consecutive "(", ")" or "." characters
         in the original text whose length equals ``length``.

    Returns ``None`` when neither strategy finds a match.
    """
    stripped = text.strip()
    whole_text_fits = len(stripped) == length and DB_FULL.match(stripped)
    if whole_text_fits:
        return stripped

    runs = (m.group(0) for m in DB_SCAN.finditer(text))
    return next((run for run in runs if len(run) == length), None)
def predict(seq):
    """Predict the dot-bracket secondary structure of an RNA sequence.

    Args:
        seq: Raw text from the input box. Case and ALL whitespace (including
            line breaks inside a pasted sequence) are ignored; ``None`` is
            treated as empty.

    Returns:
        A dot-bracket string of the same length as the sequence when one can
        be recovered from the model output. Otherwise it returns the raw model
        completion (stripped) so the user can see what was generated. If the
        input is not an A/U/C/G sequence, it returns a validation message.
    """
    # Remove every whitespace character, not just the ends: the input box is
    # multi-line, so pasted sequences often contain internal newlines/spaces.
    seq = "".join((seq or "").split()).upper()
    if not seq or not set(seq) <= {"A", "U", "C", "G"}:
        return "Please enter an RNA sequence (A/U/C/G)."

    n = len(seq)
    prompt = f"RNA: {seq}\nDot-bracket structure:"
    # n + 10 new tokens: presumably about one per base plus some slack; the
    # exact need depends on the tokenizer -- NOTE(review): confirm.
    text = _generate(prompt, max_new_tokens=n + 10, temperature=0.0)

    # If the decoded text echoes the prompt, inspect only the completion so the
    # user's own sequence is never shown back as the "prediction".
    completion = text[len(prompt):] if text.startswith(prompt) else text

    # Structure characters possibly interleaved with whitespace (e.g. tokens
    # decoded with spaces between them).
    db = "".join(c for c in completion if c in "().")
    if len(db) == n:
        return db

    # Otherwise look for a single contiguous run of exactly n characters.
    found = _extract_dotbracket(completion, n)
    if found is not None:
        return found

    # Nothing usable: surface the raw completion to aid debugging.
    return completion.strip() or "The model did not produce a dot-bracket structure."
# Gradio UI: a multi-line sequence box in, a text box with the predicted
# dot-bracket structure out, wired to predict().
demo = gr.Interface(
    title="RNA Structure Predictor",
    description="Uses your fine-tuned model to output RNA secondary structure in dot-bracket notation.",
    fn=predict,
    inputs=gr.Textbox(lines=4, label="RNA Sequence (A/U/C/G)"),
    outputs=gr.Textbox(lines=6, label="Predicted Dot-Bracket Structure"),
)

if __name__ == "__main__":
    # Start the web server only when executed as a script, not on import.
    demo.launch()