brocks1234's picture
Update handler.py
efd9157 verified
import torch
import numpy as np
from transformers import AutoModel, AutoConfig
class EndpointHandler:
    """Callable request handler: DNA string in, selected track predictions out.

    On construction the handler loads a config and model from ``path`` (with
    remote code trusted) and moves the model to GPU when one is available.
    Each call one-hot encodes a single DNA sequence, runs the model without
    gradients, and returns a few selected tracks of the ``'human'`` output.
    """

    # Every request is truncated / right-padded with 'N' to this many bases.
    TARGET_LEN = 196608

    # Track indices selected from the last axis of the 'human' output.
    # Per the original author's notes (tracks relevant for APRIL):
    # 4479 (B-cell CAGE), 4828 (HUVEC CAGE), 5111 (K562).
    TARGET_TRACKS = (4479, 4828, 5111)

    def __init__(self, path=""):
        """Load the config and model found at ``path`` and prepare for inference.

        Args:
            path: Location handed to ``from_pretrained`` for both the config
                and the model.
        """
        # Explicitly trust remote code to load modeling_enformer.py.
        # NOTE(review): trust_remote_code executes Python shipped with the
        # model repository -- only point ``path`` at a repository you trust.
        self.config = AutoConfig.from_pretrained(path, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(path, trust_remote_code=True)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device).eval()

    def __call__(self, data):
        """Run the model on ``data["inputs"]`` and return the selected tracks.

        Args:
            data: Mapping with an optional ``"inputs"`` entry holding the DNA
                sequence as a string. A missing entry is treated as an empty
                sequence, i.e. a window made entirely of padding.

        Returns:
            Nested Python list obtained from ``output['human'][:, :, tracks]``
            where ``tracks`` is ``TARGET_TRACKS``. The original author notes
            the full human output as ``[1, 896, 5313]``, which would make this
            ``[1, 896, 3]``.

        Raises:
            TypeError: If ``data["inputs"]`` is present but is not a string.
        """
        sequence = data.get("inputs", "")
        if not isinstance(sequence, str):
            # Fail with a clear message instead of an incidental error from
            # slicing / ``ljust`` further down.
            raise TypeError(
                f'"inputs" must be a DNA sequence string, got {type(sequence).__name__}'
            )

        one_hot = self._one_hot_encode(sequence, self.TARGET_LEN)

        # Convert to tensor [Batch, Length, Channels].
        inputs = torch.from_numpy(one_hot).unsqueeze(0).to(self.device)

        with torch.no_grad():
            output = self.model(inputs)

        human_out = output["human"]

        # Select only the configured tracks to keep the response size manageable.
        result = human_out[:, :, list(self.TARGET_TRACKS)]

        return result.cpu().numpy().tolist()

    @staticmethod
    def _one_hot_encode(sequence, target_len):
        """Return a ``(target_len, 4)`` float32 one-hot array for ``sequence``.

        The sequence is truncated to ``target_len`` characters and right-padded
        with ``'N'``. Channels are ordered A, C, G, T; matching is
        case-insensitive and any other character becomes an all-zero row.
        """
        padded = sequence[:target_len].ljust(target_len, "N")

        # One row per byte value; only the eight A/C/G/T byte values
        # (upper and lower case) are non-zero.
        lookup = np.zeros((256, 4), dtype=np.float32)
        for channel, base in enumerate("ACGT"):
            lookup[ord(base), channel] = 1.0
            lookup[ord(base.lower()), channel] = 1.0

        # errors="replace" turns each non-ASCII character into a single '?',
        # so the length is preserved and such characters encode as zeros.
        codes = np.frombuffer(padded.encode("ascii", errors="replace"), dtype=np.uint8)

        # Integer-array indexing gathers all rows in one vectorized step and
        # yields a fresh, writable array suitable for ``torch.from_numpy``.
        return lookup[codes]