m0ksh commited on
Commit
ebee819
·
verified ·
1 Parent(s): e51c298

Sync from GitHub (preserve manual model files)

Browse files
Files changed (1) hide show
  1. StreamlitApp/utils/predict.py +3 -4
StreamlitApp/utils/predict.py CHANGED
@@ -4,14 +4,12 @@ import numpy as np
4
  import torch
5
  import streamlit as st
6
  from torch import nn
7
- from transformers import AutoModel, BertTokenizer
8
-
9
 
10
  MODEL_INPUT_DIM = 1024
11
  MODEL_ARCH = "FastMLP"
12
  PROTBERT_MODEL_NAME = "Rostlab/prot_bert"
13
 
14
-
15
  class FastMLP(nn.Module):
16
  def __init__(self, input_dim=MODEL_INPUT_DIM):
17
  super(FastMLP, self).__init__()
@@ -93,7 +91,8 @@ def load_model():
93
  # Use an explicit slow tokenizer to avoid fast-backend conversion issues on Spaces.
94
  tokenizer = BertTokenizer.from_pretrained(PROTBERT_MODEL_NAME, do_lower_case=False)
95
 
96
- encoder = AutoModel.from_pretrained(PROTBERT_MODEL_NAME).to(device)
 
97
  encoder.eval()
98
 
99
  return {
 
4
  import torch
5
  import streamlit as st
6
  from torch import nn
7
+ from transformers import BertModel, BertTokenizer
 
8
 
9
  MODEL_INPUT_DIM = 1024
10
  MODEL_ARCH = "FastMLP"
11
  PROTBERT_MODEL_NAME = "Rostlab/prot_bert"
12
 
 
13
  class FastMLP(nn.Module):
14
  def __init__(self, input_dim=MODEL_INPUT_DIM):
15
  super(FastMLP, self).__init__()
 
91
  # Use an explicit slow tokenizer to avoid fast-backend conversion issues on Spaces.
92
  tokenizer = BertTokenizer.from_pretrained(PROTBERT_MODEL_NAME, do_lower_case=False)
93
 
94
+ # Use explicit BERT class to avoid AutoModel config auto-detection issues.
95
+ encoder = BertModel.from_pretrained(PROTBERT_MODEL_NAME).to(device)
96
  encoder.eval()
97
 
98
  return {