Spaces:
Runtime error
Runtime error
updated
Browse files
app.py
CHANGED
|
@@ -2,18 +2,27 @@ import gradio as gr
|
|
| 2 |
import torch
|
| 3 |
import numpy as np
|
| 4 |
import matplotlib.pyplot as plt
|
| 5 |
-
from enformer_pytorch import Enformer
|
| 6 |
from einops import rearrange
|
| 7 |
|
| 8 |
-
#
|
| 9 |
-
model =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
model.eval()
|
| 11 |
|
| 12 |
-
#
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
| 14 |
mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
|
| 15 |
one_hot = np.zeros((length, 4), dtype=np.float32)
|
| 16 |
-
sequence = sequence.upper().replace("N", "A")
|
| 17 |
for i, base in enumerate(sequence[:length]):
|
| 18 |
if base in mapping:
|
| 19 |
one_hot[i, mapping[base]] = 1.0
|
|
@@ -22,31 +31,30 @@ def one_hot_encode(sequence, length=196_608):
|
|
| 22 |
# Prediction function
|
| 23 |
def predict_expression(dna_sequence):
|
| 24 |
encoded = one_hot_encode(dna_sequence)
|
| 25 |
-
input_tensor = torch.tensor(encoded).unsqueeze(0) # (1, length, 4)
|
| 26 |
-
input_tensor = rearrange(input_tensor, 'b l c -> b c l') # (1, 4, length)
|
| 27 |
-
|
| 28 |
with torch.no_grad():
|
| 29 |
output = model(input_tensor)
|
| 30 |
-
|
| 31 |
-
avg_expr = expression[0].mean(dim=0).numpy() # average across sequence positions
|
| 32 |
|
| 33 |
-
# Plot first 10
|
| 34 |
-
plt.figure(figsize=(
|
| 35 |
-
plt.bar(range(10),
|
| 36 |
plt.xticks(range(10), [f"Tissue {i}" for i in range(10)])
|
| 37 |
-
plt.
|
| 38 |
-
plt.
|
| 39 |
plt.tight_layout()
|
| 40 |
|
| 41 |
return plt.gcf()
|
| 42 |
|
| 43 |
-
# Gradio
|
| 44 |
demo = gr.Interface(
|
| 45 |
fn=predict_expression,
|
| 46 |
-
inputs=gr.Textbox(lines=
|
| 47 |
-
outputs=gr.Plot(label="Predicted
|
| 48 |
-
title="Gene Expression
|
| 49 |
-
description="Paste a DNA sequence
|
| 50 |
)
|
| 51 |
|
| 52 |
demo.launch()
|
|
|
|
| 2 |
import torch
|
| 3 |
import numpy as np
|
| 4 |
import matplotlib.pyplot as plt
|
| 5 |
+
from enformer_pytorch import Enformer
|
| 6 |
from einops import rearrange
|
| 7 |
|
| 8 |
+
# Build an Enformer with the published architecture hyperparameters.
# enformer-pytorch creates models through `Enformer.from_hparams` (or
# `Enformer.from_pretrained`); the class constructor takes a config object,
# so calling it with `num_channels=` / `num_classes=` keywords fails at import.
# NOTE: `from_hparams` yields randomly initialised weights; predictions are
# only meaningful once pretrained weights have been loaded.
model = Enformer.from_hparams(
    dim=1536,
    depth=11,
    heads=8,
    # One output head per organism; the "human" head carries the 5313 tracks.
    output_heads=dict(human=5313, mouse=1643),
    target_length=896,
)
# The app only runs inference, so switch the module to evaluation mode.
model.eval()
|
| 17 |
|
| 18 |
+
# Optionally load pretrained weights if available locally or upload to HF Spaces manually
|
| 19 |
+
# model.load_state_dict(torch.load("enformer-191k.pth")) # optional for offline Spaces
|
| 20 |
+
|
| 21 |
+
# Helper function to one-hot encode DNA
|
| 22 |
+
def one_hot_encode(sequence, length=196608):
|
| 23 |
mapping = {'A': 0, 'C': 1, 'G': 2, 'T': 3}
|
| 24 |
one_hot = np.zeros((length, 4), dtype=np.float32)
|
| 25 |
+
sequence = sequence.upper().replace("N", "A")
|
| 26 |
for i, base in enumerate(sequence[:length]):
|
| 27 |
if base in mapping:
|
| 28 |
one_hot[i, mapping[base]] = 1.0
|
|
|
|
| 31 |
# Prediction function
|
| 32 |
def predict_expression(dna_sequence):
    """Run the model on a DNA string and plot the first 10 predicted tracks.

    Args:
        dna_sequence: DNA text; it is handed to ``one_hot_encode`` unchanged.

    Returns:
        A matplotlib ``Figure`` holding a bar chart with one bar per track
        (first 10 tracks), each bar being the prediction averaged over the
        first axis of the per-sample output.
    """
    # Imported locally so the figure can be built without pyplot's global
    # figure registry, which is never emptied in a long-running server.
    from matplotlib.figure import Figure

    encoded = one_hot_encode(dna_sequence)
    # enformer-pytorch accepts one-hot input as (batch, length, 4) and moves
    # the channel axis itself, so only a batch dimension is added here.
    input_tensor = torch.tensor(encoded).unsqueeze(0)

    with torch.no_grad():
        output = model(input_tensor)

    # The model returns a dict keyed by output head; use the human tracks.
    # A bare tensor is still accepted in case a single head was selected.
    if isinstance(output, dict):
        output = output["human"]
    avg_expression = output[0].mean(dim=0).cpu().numpy()

    # Plot first 10 expression predictions
    n_tracks = 10
    positions = list(range(n_tracks))
    fig = Figure(figsize=(10, 4))
    ax = fig.subplots()
    ax.bar(positions, avg_expression[:n_tracks])
    ax.set_xticks(positions)
    ax.set_xticklabels([f"Tissue {i}" for i in positions])
    ax.set_title("Predicted Gene Expression")
    ax.set_ylabel("Signal")
    fig.tight_layout()

    return fig
|
| 50 |
|
| 51 |
+
# Gradio UI: a multi-line text box feeds `predict_expression`, and the figure
# it returns is rendered in a plot component.
demo = gr.Interface(
    title="Gene Expression Prediction with Enformer",
    description="Paste a 200kb DNA sequence and see predicted expression levels using Enformer.",
    fn=predict_expression,
    inputs=gr.Textbox(
        label="Paste DNA Sequence (200k bp)",
        lines=6,
    ),
    outputs=gr.Plot(
        label="Predicted Expression Tracks (first 10 tissues)",
    ),
)

# Start the web server for the interface.
demo.launch()
|