Subh775 committed
Commit 9608158 · verified · 1 Parent(s): 93f2355

Create app.py

Files changed (1)
  1. app.py +108 -0
app.py ADDED
@@ -0,0 +1,108 @@
import base64
import io
import cv2
import numpy as np
import torch
from fastapi import FastAPI
from fastapi.responses import FileResponse
from pydantic import BaseModel
from PIL import Image
import segmentation_models_pytorch as smp
from huggingface_hub import hf_hub_download

# --- CONFIGURATION ---
HF_MODEL_REPO_ID = "LeafNet75/Leaf-Annotate-v2"
DEVICE = "cpu"
IMG_SIZE = 256

# --- DATA MODELS FOR API (using Pydantic) ---
class InferenceRequest(BaseModel):
    image: str  # base64 encoded image string
    scribble_mask: str  # base64 encoded scribble mask string

class InferenceResponse(BaseModel):
    predicted_mask: str  # base64 encoded predicted mask string

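# For reference, the /predict endpoint defined below exchanges JSON bodies shaped
# roughly like this (values are illustrative placeholders, not real data):
#   request:  {"image": "data:image/png;base64,....", "scribble_mask": "data:image/png;base64,...."}
#   response: {"predicted_mask": "data:image/png;base64,...."}
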
# --- INITIALIZE FASTAPI APP ---
app = FastAPI()

# --- LOAD MODEL ON STARTUP ---
# The model is loaded once when the application starts to ensure fast inference times.
def load_model():
    print(f"Loading model '{HF_MODEL_REPO_ID}'...")
    model_path = hf_hub_download(repo_id=HF_MODEL_REPO_ID, filename="best_model.pth")

    model = smp.Unet(
        encoder_name="mobilenet_v2",
        encoder_weights=None,
        in_channels=4,  # 3 RGB channels + 1 scribble-mask channel
        classes=1,
    )
    model.load_state_dict(torch.load(model_path, map_location=DEVICE))
    model.to(DEVICE)
    model.eval()
    print("Model loaded successfully.")
    return model

model = load_model()

# --- HELPER FUNCTIONS ---
def base64_to_cv2(base64_string: str):
    # Remove the "data:image/..." header
    header, encoded = base64_string.split(",", 1)
    img_data = base64.b64decode(encoded)

    # Use Pillow to open the image data and convert to OpenCV (BGRA) format.
    # Force RGBA first so the conversion below also works for RGB or grayscale inputs.
    pil_image = Image.open(io.BytesIO(img_data)).convert("RGBA")
    return cv2.cvtColor(np.array(pil_image), cv2.COLOR_RGBA2BGRA)

def cv2_to_base64(image: np.ndarray):
    # Convert image back to a base64 string to send to the frontend
    _, buffer = cv2.imencode('.png', image)
    png_as_text = base64.b64encode(buffer).decode('utf-8')
    return f"data:image/png;base64,{png_as_text}"

# --- API ENDPOINTS ---
@app.get("/")
def read_root():
    # Serve the frontend HTML file
    return FileResponse('index.html')

@app.post("/predict", response_model=InferenceResponse)
async def predict(request: InferenceRequest):
    # 1. Decode input data
    image_cv = base64_to_cv2(request.image)
    scribble_cv = base64_to_cv2(request.scribble_mask)

    # Ensure scribble is grayscale
    if len(scribble_cv.shape) > 2 and scribble_cv.shape[2] > 1:
        scribble_cv = cv2.cvtColor(scribble_cv, cv2.COLOR_BGRA2GRAY)

    h, w, _ = image_cv.shape

    # 2. Preprocess the data for the model
    image_resized = cv2.resize(cv2.cvtColor(image_cv, cv2.COLOR_BGRA2RGB), (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_AREA)
    scribble_resized = cv2.resize(scribble_cv, (IMG_SIZE, IMG_SIZE), interpolation=cv2.INTER_NEAREST)

    image_tensor = torch.from_numpy(image_resized.astype(np.float32)).permute(2, 0, 1) / 255.0
    scribble_tensor = torch.from_numpy(scribble_resized.astype(np.float32)).unsqueeze(0) / 255.0

    input_tensor = torch.cat([image_tensor, scribble_tensor], dim=0).unsqueeze(0).to(DEVICE)

    # 3. Run Inference
    with torch.no_grad():
        output = model(input_tensor)

    # 4. Post-process the output
    probs = torch.sigmoid(output)
    binary_mask = (probs > 0.5).float().squeeze().cpu().numpy()

    # Resize mask to the original input canvas size
    output_mask_resized = cv2.resize(binary_mask, (w, h), interpolation=cv2.INTER_NEAREST)
    output_mask_uint8 = (output_mask_resized * 255).astype(np.uint8)

    # 5. Encode the result and return
    result_base64 = cv2_to_base64(output_mask_uint8)

    return InferenceResponse(predicted_mask=result_base64)
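
For reference, a minimal client sketch for the /predict endpoint above. Assumptions not taken from the commit: the app is served locally with `uvicorn app:app` on the default port, the placeholder files leaf.png and scribble.png exist, and the requests package is installed; adjust the URL for a deployed Space.

import base64
import requests

API_URL = "http://127.0.0.1:8000/predict"  # assumption: local uvicorn default host/port

def file_to_data_url(path: str) -> str:
    # Encode a PNG file as the "data:image/png;base64,..." string the API expects
    with open(path, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return f"data:image/png;base64,{encoded}"

payload = {
    "image": file_to_data_url("leaf.png"),               # placeholder input image
    "scribble_mask": file_to_data_url("scribble.png"),   # placeholder scribble mask
}

response = requests.post(API_URL, json=payload)
response.raise_for_status()

# The response matches InferenceResponse: strip the data-URL header and decode the PNG
mask_b64 = response.json()["predicted_mask"].split(",", 1)[1]
with open("predicted_mask.png", "wb") as f:
    f.write(base64.b64decode(mask_b64))
print("Saved predicted_mask.png")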