Brightsun10 commited on
Commit
1dde15e
·
verified ·
1 Parent(s): d30a82c

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +62 -59
main.py CHANGED
@@ -1,59 +1,62 @@
1
- # app/main.py
2
- from fastapi import FastAPI, UploadFile, File
3
- from fastapi.middleware.cors import CORSMiddleware
4
- from fastapi.responses import JSONResponse, FileResponse
5
- from fastapi.staticfiles import StaticFiles
6
- from PIL import Image
7
- import io
8
- import torch
9
- import torchvision.transforms as transforms
10
- from torchvision import models
11
-
12
- app = FastAPI()
13
-
14
- # Enable CORS
15
- app.add_middleware(
16
- CORSMiddleware,
17
- allow_origins=["*"],
18
- allow_credentials=True,
19
- allow_methods=["*"],
20
- allow_headers=["*"],
21
- )
22
-
23
- app.mount("/static", StaticFiles(directory="/static"), name="static")
24
-
25
- @app.get("/")
26
- def read_root():
27
- return FileResponse("/static/index.html")
28
-
29
- # Load pretrained model (use Hugging Face internet to download)
30
- model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
31
- model.eval()
32
-
33
- with open("/imagenet_classes.txt") as f:
34
- labels = [line.strip() for line in f]
35
-
36
- transform = transforms.Compose([
37
- transforms.Resize(256),
38
- transforms.CenterCrop(224),
39
- transforms.ToTensor(),
40
- transforms.Normalize(
41
- mean=[0.485, 0.456, 0.406],
42
- std=[0.229, 0.224, 0.225]),
43
- ])
44
-
45
- @app.post("/predict")
46
- async def predict(file: UploadFile = File(...)):
47
- try:
48
- image_bytes = await file.read()
49
- img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
50
- img_tensor = transform(img).unsqueeze(0)
51
-
52
- with torch.no_grad():
53
- outputs = model(img_tensor)
54
- _, predicted = torch.max(outputs, 1)
55
- label = labels[predicted.item()]
56
-
57
- return JSONResponse(content={"prediction": label})
58
- except Exception as e:
59
- return JSONResponse(content={"error": str(e)}, status_code=500)
 
 
 
 
1
+ # app/main.py
2
+ from fastapi import FastAPI, UploadFile, File
3
+ from fastapi.middleware.cors import CORSMiddleware
4
+ from fastapi.responses import JSONResponse, FileResponse
5
+ from fastapi.staticfiles import StaticFiles
6
+ from PIL import Image
7
+ import io
8
+ import torch
9
+ import torchvision.transforms as transforms
10
+ from torchvision import models
11
+
12
+ app = FastAPI()
13
+
14
+ # Enable CORS
15
+ app.add_middleware(
16
+ CORSMiddleware,
17
+ allow_origins=["*"],
18
+ allow_credentials=True,
19
+ allow_methods=["*"],
20
+ allow_headers=["*"],
21
+ )
22
+
23
+ # Corrected path to relative static directory
24
+ app.mount("/static", StaticFiles(directory="static"), name="static")
25
+
26
+ @app.get("/")
27
+ def read_root():
28
+ return FileResponse("static/index.html")
29
+
30
+ # Load pretrained model
31
+ model = models.resnet50(weights=models.ResNet50_Weights.DEFAULT)
32
+ model.eval()
33
+
34
+ # Load labels
35
+ with open("imagenet_classes.txt") as f:
36
+ labels = [line.strip() for line in f]
37
+
38
+ # Define image transforms
39
+ transform = transforms.Compose([
40
+ transforms.Resize(256),
41
+ transforms.CenterCrop(224),
42
+ transforms.ToTensor(),
43
+ transforms.Normalize(
44
+ mean=[0.485, 0.456, 0.406],
45
+ std=[0.229, 0.224, 0.225]),
46
+ ])
47
+
48
+ @app.post("/predict")
49
+ async def predict(file: UploadFile = File(...)):
50
+ try:
51
+ image_bytes = await file.read()
52
+ img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
53
+ img_tensor = transform(img).unsqueeze(0)
54
+
55
+ with torch.no_grad():
56
+ outputs = model(img_tensor)
57
+ _, predicted = torch.max(outputs, 1)
58
+ label = labels[predicted.item()]
59
+
60
+ return JSONResponse(content={"prediction": label})
61
+ except Exception as e:
62
+ return JSONResponse(content={"error": str(e)}, status_code=500)