Fine-tuned ESM2 Protein Classifier (edipropred)

This repository contains a fine-tuned ESM2 model for binary protein sequence classification, published on Hugging Face as pushpendrag/edipropred. The model predicts a binary label from a protein's amino-acid sequence.

Model Description

  • Base Model: ESM2-t33-650M-UR50D (Fine-tuned)
  • Fine-tuning Task: Binary protein classification.
  • Architecture: The model consists of the ESM2 backbone with a linear classification head.
  • Input: Protein amino acid sequences.
  • Output: Binary classification labels (0 or 1).
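The head described above can be illustrated in isolation (a minimal sketch: the dummy tensor stands in for the per-token layer-33 representations, and 1280 is the hidden size of ESM2-t33-650M):

```python
import torch
import torch.nn as nn

embedding_dim, num_classes = 1280, 2   # ESM2-t33-650M hidden size, binary output
head = nn.Linear(embedding_dim, num_classes)

# Dummy per-token representations: (batch, seq_len, embedding_dim)
reps = torch.randn(1, 50, embedding_dim)

# Mean-pool over the sequence dimension, then project to class logits.
pooled = reps.mean(dim=1)              # (1, 1280)
logits = head(pooled)                  # (1, 2)
print(logits.shape)                    # torch.Size([1, 2])
```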

Repository Contents

  • final_full_model_object.pth: The trained classifier, saved as a full model object (backbone + classification head).
  • esm_alphabet.pth: The ESM2 alphabet (used as a tokenizer).
  • README.md: This file.

Usage

Installation

  1. Install the required libraries:

    pip install torch fair-esm huggingface_hub
    

Loading the Model from Hugging Face

import os
import warnings

import torch
import torch.nn as nn
from huggingface_hub import hf_hub_download
from huggingface_hub import hf_hub_download

warnings.filterwarnings("ignore")

# =========================================================
# Device
# =========================================================
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# =========================================================
# Hugging Face repository (MODEL SOURCE)
# =========================================================
HF_REPO_ID = "pushpendrag/edipropred"
HF_MODEL_FILE = "final_full_model_object.pth"
HF_ALPHABET_FILE = "esm_alphabet.pth"

# Local cache directory
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
CACHE_DIR = os.path.join(BASE_DIR, "hf_cache")
os.makedirs(CACHE_DIR, exist_ok=True)

# =========================================================
# Model definition (MUST match trained model)
# =========================================================
class ProteinClassifier(nn.Module):
    def __init__(self, esm_model, embedding_dim, num_classes):
        super().__init__()
        self.esm_model = esm_model
        self.fc = nn.Linear(embedding_dim, num_classes)

    def forward(self, tokens):
        with torch.no_grad():
            out = self.esm_model(tokens, repr_layers=[33])
        emb = out["representations"][33].mean(1)
        return self.fc(emb)

# =========================================================
# Download model files from Hugging Face
# =========================================================
print("🔄 Checking / downloading model files from Hugging Face...")

MODEL_PATH = hf_hub_download(
    repo_id=HF_REPO_ID,
    filename=HF_MODEL_FILE,
    cache_dir=CACHE_DIR
)

ALPHABET_PATH = hf_hub_download(
    repo_id=HF_REPO_ID,
    filename=HF_ALPHABET_FILE,
    cache_dir=CACHE_DIR
)

# =========================================================
# Load model + alphabet
# =========================================================
print("🔄 Loading trained ESM2-t33 model...")

alphabet = torch.load(ALPHABET_PATH, map_location="cpu", weights_only=False)
batch_converter = alphabet.get_batch_converter()

# Full model object: requires the ProteinClassifier class defined above.
classifier = torch.load(MODEL_PATH, map_location=device, weights_only=False)
classifier = classifier.to(device)
classifier.eval()

print("✅ Model loaded successfully")
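With the classifier and alphabet loaded, inference on a single sequence might look like the sketch below. The `predict` helper and its signature are not part of the repository; it only illustrates the standard ESM call pattern, where the batch converter takes `(name, sequence)` pairs and returns labels, strings, and a token tensor:

```python
import torch

def predict(classifier, batch_converter, sequence, device):
    """Return the predicted binary label (0 or 1) for one protein sequence."""
    # The ESM batch converter expects (name, sequence) pairs.
    _, _, tokens = batch_converter([("query", sequence)])
    tokens = tokens.to(device)
    with torch.no_grad():
        logits = classifier(tokens)           # shape: (1, num_classes)
        probs = torch.softmax(logits, dim=-1)
    return int(probs.argmax(dim=-1).item())

# Example call (after the loading code above):
# label = predict(classifier, batch_converter, "MKTAYIAKQRQISFVKSHFSRQ", device)
```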