Noursine commited on
Commit
baa5379
·
verified ·
1 Parent(s): d5f63b0

Create main.py

Browse files
Files changed (1) hide show
  1. main.py +63 -0
main.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # main.py
2
+
3
+ from fastapi import FastAPI, File, UploadFile, Form
4
+ from fastapi.responses import JSONResponse
5
+ import uvicorn
6
+ import shutil
7
+ import os
8
+ import uuid
9
+ import cv2
10
+ import numpy as np
11
+ import base64
12
+
13
+ from detectron_infer import predict as detectron_predict
14
+ from yolo_infer import predict_yolo
15
+
16
+ app = FastAPI()
17
+
18
+ UPLOAD_DIR = "/tmp/uploads"
19
+ os.makedirs(UPLOAD_DIR, exist_ok=True)
20
+ @app.get("/")
21
+ async def root():
22
+ return {"message": "API is up and running!"}
23
+
24
+ @app.post("/predict/")
25
+ async def predict_endpoint(
26
+ file: UploadFile = File(...),
27
+ model: str = Form(...), # either "detectron" or "yolo"
28
+ ):
29
+ # Save uploaded image
30
+ file_ext = file.filename.split('.')[-1]
31
+ filename = f"{uuid.uuid4()}.{file_ext}"
32
+ file_path = os.path.join(UPLOAD_DIR, filename)
33
+
34
+ with open(file_path, "wb") as buffer:
35
+ shutil.copyfileobj(file.file, buffer)
36
+
37
+ try:
38
+ if model == "Anomaly":
39
+ result_img, predictions = detectron_predict(file_path)
40
+ elif model == "Numbering":
41
+ result_img, predictions = predict_yolo(file_path)
42
+ else:
43
+ return JSONResponse({"error": "Invalid model choice"}, status_code=400)
44
+
45
+ # Encode image to bytes (optional)
46
+ _, img_encoded = cv2.imencode(".jpg", result_img)
47
+ img_bytes = img_encoded.tobytes()
48
+
49
+ img_b64 = base64.b64encode(img_bytes).decode('utf-8')
50
+
51
+ return {
52
+ "predictions": predictions,
53
+ "image_base64": img_b64
54
+ }
55
+ except Exception as e:
56
+ return JSONResponse({"error": str(e)}, status_code=500)
57
+ finally:
58
+ os.remove(file_path) # Clean up uploaded file
59
+
60
+
61
+ # Only for testing locally
62
+ if __name__ == "__main__":
63
+ uvicorn.run(app, host="0.0.0.0", port=8000)