Spaces:
Sleeping
Sleeping
ThanaritKanjanametawat commited on
Commit ·
582b2f2
1
Parent(s): 2287a5c
change the device to cpu only 3
Browse files- ModelDriver.py +3 -2
ModelDriver.py
CHANGED
|
@@ -2,6 +2,7 @@ from transformers import RobertaTokenizer, RobertaForSequenceClassification, Rob
|
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
|
|
|
|
| 5 |
device = torch.device("cpu")
|
| 6 |
class MLP(nn.Module):
|
| 7 |
def __init__(self, input_dim):
|
|
@@ -27,7 +28,7 @@ def extract_features(text):
|
|
| 27 |
def RobertaSentinelOpenGPTInference(input_text):
|
| 28 |
features = extract_features(input_text)
|
| 29 |
loaded_model = MLP(768).to(device)
|
| 30 |
-
loaded_model.load_state_dict(torch.load("MLPDictStates/RobertaSentinelOpenGPT.pth"))
|
| 31 |
|
| 32 |
# Define the tokenizer and model for feature extraction
|
| 33 |
with torch.no_grad():
|
|
@@ -40,7 +41,7 @@ def RobertaSentinelOpenGPTInference(input_text):
|
|
| 40 |
def RobertaSentinelCSAbstractInference(input_text):
|
| 41 |
features = extract_features(input_text)
|
| 42 |
loaded_model = MLP(768).to(device)
|
| 43 |
-
loaded_model.load_state_dict(torch.load("MLPDictStates/RobertaSentinelCSAbstract.pth"))
|
| 44 |
|
| 45 |
# Define the tokenizer and model for feature extraction
|
| 46 |
with torch.no_grad():
|
|
|
|
| 2 |
import torch
|
| 3 |
import torch.nn as nn
|
| 4 |
|
| 5 |
+
|
| 6 |
device = torch.device("cpu")
|
| 7 |
class MLP(nn.Module):
|
| 8 |
def __init__(self, input_dim):
|
|
|
|
| 28 |
def RobertaSentinelOpenGPTInference(input_text):
|
| 29 |
features = extract_features(input_text)
|
| 30 |
loaded_model = MLP(768).to(device)
|
| 31 |
+
loaded_model.load_state_dict(torch.load("MLPDictStates/RobertaSentinelOpenGPT.pth", map_location=device))
|
| 32 |
|
| 33 |
# Define the tokenizer and model for feature extraction
|
| 34 |
with torch.no_grad():
|
|
|
|
| 41 |
def RobertaSentinelCSAbstractInference(input_text):
|
| 42 |
features = extract_features(input_text)
|
| 43 |
loaded_model = MLP(768).to(device)
|
| 44 |
+
loaded_model.load_state_dict(torch.load("MLPDictStates/RobertaSentinelCSAbstract.pth", map_location=device))
|
| 45 |
|
| 46 |
# Define the tokenizer and model for feature extraction
|
| 47 |
with torch.no_grad():
|