JaySenpai committed
Commit 9fe8d2b · verified · 1 parent: e8beaae

Update app/main.py

Files changed (1): app/main.py (+8, -11)
app/main.py CHANGED
@@ -1,24 +1,22 @@
 import os
-
-# Set a writable cache directory inside the Space container
-os.environ["TRANSFORMERS_CACHE"] = "/code/cache"
-
-
 from fastapi import FastAPI
 from pydantic import BaseModel
-from transformers import BertTokenizer, BertForSequenceClassification
+from transformers import AutoTokenizer, AutoModelForSequenceClassification
 from sklearn.preprocessing import LabelEncoder
 import torch
 import numpy as np
 
+# Set cache
+os.environ["TRANSFORMERS_CACHE"] = "/code/cache"
+
 app = FastAPI()
 
-model = BertForSequenceClassification.from_pretrained(
+model = AutoModelForSequenceClassification.from_pretrained(
     "JaySenpai/bert-model",
     cache_dir="/code/cache",
-    use_safetensors=True
+    use_safetensors=True
 )
-tokenizer = BertTokenizer.from_pretrained(
+tokenizer = AutoTokenizer.from_pretrained(
     "JaySenpai/bert-model",
     cache_dir="/code/cache"
 )
@@ -32,8 +30,7 @@ class TextInput(BaseModel):
 
 @app.post("/predict")
 async def predict(data: TextInput):
-    text = data.text
-    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
+    inputs = tokenizer(data.text, return_tensors="pt", truncation=True, padding=True)
     with torch.no_grad():
         outputs = model(**inputs)
     probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
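For reference, the updated endpoint can be exercised with a plain HTTP client. A minimal sketch, assuming the app is served locally with uvicorn on port 8000 (the URL and sample text are illustrative, not part of this commit):

import requests

# POST a JSON body matching the TextInput schema (a single `text` field).
resp = requests.post(
    "http://127.0.0.1:8000/predict",  # assumed local uvicorn address
    json={"text": "sample input to classify"},
)
print(resp.json())

The switch from BertTokenizer/BertForSequenceClassification to the Auto* classes does not change this interface; it only lets transformers resolve the concrete tokenizer and model classes from the checkpoint's config.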