Hayloo9838 commited on
Commit
0067c59
·
verified ·
1 Parent(s): 3442a64

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -0
app.py ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ import torch
4
+ from transformers import AutoTokenizer, AutoModelForImageClassification
5
+ from PIL import Image
6
+ import torchvision.transforms as T
7
+ import requests
8
+ from io import BytesIO
9
+
10
+ app = FastAPI()
11
+
12
+ # load model once
13
+ model_name = "nateraw/vit-base-patch16-224-in21k"
14
+ model = AutoModelForImageClassification.from_pretrained(model_name)
15
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
16
+ transform = T.Compose([
17
+ T.Resize((224, 224)),
18
+ T.ToTensor(),
19
+ T.Normalize([0.5], [0.5])
20
+ ])
21
+
22
+ class ImageInput(BaseModel):
23
+ url: str
24
+
25
+ @app.get("/")
26
+ def read_root():
27
+ return {"status": "running"}
28
+
29
+ @app.post("/predict")
30
+ def predict(input: ImageInput):
31
+ img = Image.open(BytesIO(requests.get(input.url).content)).convert("RGB")
32
+ img_tensor = transform(img).unsqueeze(0)
33
+
34
+ with torch.no_grad():
35
+ logits = model(img_tensor).logits
36
+ pred = torch.argmax(logits, dim=1).item()
37
+
38
+ return {"class": int(pred)}