Sync from GitHub (preserve manual model files)
Browse files
StreamlitApp/utils/predict.py
CHANGED
|
@@ -4,14 +4,12 @@ import numpy as np
|
|
| 4 |
import torch
|
| 5 |
import streamlit as st
|
| 6 |
from torch import nn
|
| 7 |
-
from transformers import
|
| 8 |
-
|
| 9 |
|
| 10 |
MODEL_INPUT_DIM = 1024
|
| 11 |
MODEL_ARCH = "FastMLP"
|
| 12 |
PROTBERT_MODEL_NAME = "Rostlab/prot_bert"
|
| 13 |
|
| 14 |
-
|
| 15 |
class FastMLP(nn.Module):
|
| 16 |
def __init__(self, input_dim=MODEL_INPUT_DIM):
|
| 17 |
super(FastMLP, self).__init__()
|
|
@@ -93,7 +91,8 @@ def load_model():
|
|
| 93 |
# Use an explicit slow tokenizer to avoid fast-backend conversion issues on Spaces.
|
| 94 |
tokenizer = BertTokenizer.from_pretrained(PROTBERT_MODEL_NAME, do_lower_case=False)
|
| 95 |
|
| 96 |
-
|
|
|
|
| 97 |
encoder.eval()
|
| 98 |
|
| 99 |
return {
|
|
|
|
| 4 |
import torch
|
| 5 |
import streamlit as st
|
| 6 |
from torch import nn
|
| 7 |
+
from transformers import BertModel, BertTokenizer
|
|
|
|
| 8 |
|
| 9 |
MODEL_INPUT_DIM = 1024
|
| 10 |
MODEL_ARCH = "FastMLP"
|
| 11 |
PROTBERT_MODEL_NAME = "Rostlab/prot_bert"
|
| 12 |
|
|
|
|
| 13 |
class FastMLP(nn.Module):
|
| 14 |
def __init__(self, input_dim=MODEL_INPUT_DIM):
|
| 15 |
super(FastMLP, self).__init__()
|
|
|
|
| 91 |
# Use an explicit slow tokenizer to avoid fast-backend conversion issues on Spaces.
|
| 92 |
tokenizer = BertTokenizer.from_pretrained(PROTBERT_MODEL_NAME, do_lower_case=False)
|
| 93 |
|
| 94 |
+
# Use explicit BERT class to avoid AutoModel config auto-detection issues.
|
| 95 |
+
encoder = BertModel.from_pretrained(PROTBERT_MODEL_NAME).to(device)
|
| 96 |
encoder.eval()
|
| 97 |
|
| 98 |
return {
|