pranit144 commited on
Commit
ed57c13
·
verified ·
1 Parent(s): ddc8ff7

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +43 -43
main.py CHANGED
@@ -1,43 +1,43 @@
1
- from fastapi import FastAPI, UploadFile, File, HTTPException
2
- from vit_classifier import ViTClassifier
3
- import io
4
- import uvicorn
5
-
6
- app = FastAPI(title="ViT Model Deployment")
7
-
8
- # Initialize model on startup (or lazy load)
9
- # We'll rely on the singleton pattern in ViTClassifier to load it once.
10
- print("Initializing model...")
11
- ViTClassifier.get_instance()
12
-
13
- @app.get("/")
14
- def read_root():
15
- return {"status": "online", "message": "ViT Model API is running"}
16
-
17
- @app.post("/predict")
18
- async def predict(file: UploadFile = File(...)):
19
- if not file.content_type.startswith("image/"):
20
- raise HTTPException(status_code=400, detail="File must be an image")
21
-
22
- try:
23
- # Read image data
24
- image_data = await file.read()
25
- image_file = io.BytesIO(image_data)
26
-
27
- # Get prediction
28
- classifier = ViTClassifier.get_instance()
29
- predicted_class, confidence = classifier.predict(image_file)
30
-
31
- if predicted_class is None:
32
- raise HTTPException(status_code=500, detail="Model failed to predict")
33
-
34
- return {
35
- "status": "success",
36
- "prediction": predicted_class,
37
- "confidence": confidence
38
- }
39
- except Exception as e:
40
- raise HTTPException(status_code=500, detail=str(e))
41
-
42
- if __name__ == "__main__":
43
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ from fastapi import FastAPI, UploadFile, File, HTTPException
2
+ from vit_classifier import ViTClassifier
3
+ import io
4
+ import uvicorn
5
+
6
+ app = FastAPI(title="ViT Model Deployment")
7
+
8
+ # Initialize model on startup
9
+ print("Initializing model...")
10
+ ViTClassifier.get_instance()
11
+
12
+ @app.get("/")
13
+ def read_root():
14
+ return {"status": "online", "message": "ViT Model API is running"}
15
+
16
+ @app.post("/predict")
17
+ async def predict(file: UploadFile = File(...)):
18
+ if not file.content_type.startswith("image/"):
19
+ raise HTTPException(status_code=400, detail="File must be an image")
20
+
21
+ try:
22
+ image_data = await file.read()
23
+ image_file = io.BytesIO(image_data)
24
+
25
+ classifier = ViTClassifier.get_instance()
26
+ predicted_class, confidence, all_probs = classifier.predict(image_file)
27
+
28
+ if predicted_class is None:
29
+ raise HTTPException(status_code=500, detail="Model failed to predict")
30
+
31
+ return {
32
+ "status": "success",
33
+ "prediction": predicted_class,
34
+ "confidence": confidence,
35
+ "probabilities": all_probs
36
+ }
37
+
38
+ except Exception as e:
39
+ raise HTTPException(status_code=500, detail=str(e))
40
+
41
+
42
+ if __name__ == "__main__":
43
+ uvicorn.run(app, host="0.0.0.0", port=8000)