Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from enformer_pytorch import Enformer | |
| from einops import rearrange | |
# Build Enformer with the published architecture: 1536-dim trunk, 11
# transformer layers, 8 attention heads, 896 output bins, and one output head
# per species (5313 human tracks, 1643 mouse tracks).
# enformer-pytorch's ``Enformer`` constructor takes an ``EnformerConfig``
# object, so hyperparameters are passed through ``from_hparams``. Passing them
# as keyword arguments to ``Enformer(...)`` directly fails at import time.
model = Enformer.from_hparams(
    dim=1536,
    depth=11,
    heads=8,
    output_heads=dict(human=5313, mouse=1643),
    target_length=896,
)
# NOTE: the weights above are randomly initialised, so predictions carry no
# biological meaning until pretrained weights are loaded. For example:
#   model = Enformer.from_pretrained("EleutherAI/enformer-191k")
# or, for an offline Space with a weights file uploaded next to the app:
#   model.load_state_dict(torch.load("enformer-191k.pth", map_location="cpu"))
model.eval()  # inference mode; the app never trains
def one_hot_encode(sequence, length=196608):
    """One-hot encode DNA text into a fixed-length ``(length, 4)`` float32 array.

    Columns are ordered A, C, G, T. Rows for ``N``, for any other character
    that is not A/C/G/T, and for padding past the end of the sequence are all
    zero. ``N`` is deliberately *not* mapped to ``A``, which would invent bases.

    Args:
        sequence: DNA text, case-insensitive. Line breaks and other whitespace
            are ignored, so pasted multi-line text does not shift positions.
            FASTA header lines (starting with ``>``) are skipped. ``None`` or
            ``""`` yields an all-zero array.
        length: Number of output rows. The cleaned sequence is truncated to
            this many bases, or zero-padded at the end if shorter.

    Returns:
        ``np.ndarray`` of dtype float32 and shape ``(length, 4)``.
    """
    mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
    one_hot = np.zeros((length, 4), dtype=np.float32)

    # Skip FASTA headers, then drop all whitespace. Truncate only after
    # cleaning, so removed characters never count toward ``length``.
    body = (
        line
        for line in (sequence or "").splitlines()
        if not line.lstrip().startswith(">")
    )
    cleaned = "".join("".join(body).split()).upper()[:length]
    if not cleaned:
        return one_hot

    # Byte -> column lookup table; -1 marks "not A/C/G/T" (left as zeros).
    lut = np.full(256, -1, dtype=np.int64)
    for base, column in mapping.items():
        lut[ord(base)] = column
    # errors="replace" turns each non-ASCII character into a single '?' byte,
    # so there is still exactly one byte per sequence position.
    codes = lut[np.frombuffer(cleaned.encode("ascii", errors="replace"), dtype=np.uint8)]
    rows = np.flatnonzero(codes >= 0)
    one_hot[rows, codes[rows]] = 1.0
    return one_hot
def _plot_tracks(values, num_tracks):
    """Return a bar-chart Figure of the first ``num_tracks`` entries of ``values``."""
    from matplotlib.figure import Figure

    # Build the figure without pyplot. pyplot keeps every figure alive in
    # global state (a leak across requests) and is not thread-safe, while
    # Gradio may run handlers on worker threads.
    fig = Figure(figsize=(10, 4))
    ax = fig.subplots()
    shown = values[:num_tracks]
    positions = list(range(len(shown)))
    ax.bar(positions, shown)
    ax.set_xticks(positions)
    # NOTE(review): output columns are generic track indices; "Tissue" is a
    # placeholder label, not looked up from any track metadata.
    ax.set_xticklabels([f"Tissue {i}" for i in positions])
    ax.set_title("Predicted Gene Expression")
    ax.set_ylabel("Signal")
    fig.tight_layout()
    return fig


def predict_expression(dna_sequence, num_tracks=10):
    """Run Enformer on ``dna_sequence`` and plot the first few predicted tracks.

    Args:
        dna_sequence: DNA text, encoded with ``one_hot_encode`` into its
            default fixed-length window.
        num_tracks: How many leading output tracks to plot (default 10).

    Returns:
        A matplotlib ``Figure`` with a bar chart of each track's prediction
        averaged over all output bins.
    """
    encoded = one_hot_encode(dna_sequence)  # (length, 4)
    # Enformer takes one-hot input as (batch, length, 4) and transposes to
    # channels-first internally, so no rearrangement is done here.
    input_tensor = torch.from_numpy(encoded).unsqueeze(0)  # (1, length, 4)
    with torch.no_grad():
        output = model(input_tensor)
    # A multi-head Enformer returns {"human": ..., "mouse": ...}.
    if isinstance(output, dict):
        output = output["human"]
    # (1, bins, tracks) -> average over bins -> (tracks,)
    avg_expression = output[0].mean(dim=0).cpu().numpy()
    return _plot_tracks(avg_expression, num_tracks)
# Gradio app. The encoder uses a fixed 196,608 bp window: longer input is
# truncated and shorter input is zero-padded, so the UI text states that
# length instead of a rounded "200kb".
demo = gr.Interface(
    fn=predict_expression,
    inputs=gr.Textbox(lines=6, label="Paste DNA Sequence (196,608 bp)"),
    outputs=gr.Plot(label="Predicted Expression Tracks (first 10 tissues)"),
    title="Gene Expression Prediction with Enformer",
    description=(
        "Paste a DNA sequence (Enformer reads 196,608 bp; longer input is "
        "truncated, shorter input is zero-padded) and see predicted "
        "expression levels using Enformer."
    ),
)

# Launch only when run as a script (e.g. `python app.py` on Spaces), not as
# an import side effect.
if __name__ == "__main__":
    demo.launch()