ahadhassan commited on
Commit
eeac6a0
·
verified ·
1 Parent(s): bc3ee4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +105 -12
app.py CHANGED
@@ -1,51 +1,144 @@
1
  from fastapi.responses import StreamingResponse, JSONResponse
2
- from fastapi import FastAPI, File, UploadFile
3
  from ndvi_predictor import load_model, normalize_rgb, predict_ndvi, create_visualization
 
4
  from PIL import Image
5
  from io import BytesIO
6
  import numpy as np
7
  import zipfile
 
 
 
 
 
8
 
9
  app = FastAPI()
10
- model = load_model("ndvi_best_model.keras")
 
 
 
11
 
12
  @app.get("/")
13
  async def root():
14
- return {"message": "Welcome to the NDVI prediction API!"}
15
 
16
- @app.post("/predict/")
17
  async def predict_ndvi_api(file: UploadFile = File(...)):
 
18
  try:
19
  contents = await file.read()
20
  img = Image.open(BytesIO(contents)).convert("RGB")
21
-
22
  norm_img = normalize_rgb(np.array(img))
23
- pred_ndvi = predict_ndvi(model, norm_img)
24
-
25
  # Visualization image as PNG
26
  vis_img_bytes = create_visualization(norm_img, pred_ndvi)
27
  vis_img_bytes.seek(0)
28
-
29
  # NDVI band as .npy
30
  ndvi_bytes = BytesIO()
31
  np.save(ndvi_bytes, pred_ndvi)
32
  ndvi_bytes.seek(0)
33
-
34
  # Create a ZIP containing both files
35
-
36
  zip_buf = BytesIO()
37
  with zipfile.ZipFile(zip_buf, "w") as zip_file:
38
  zip_file.writestr("ndvi_image.png", vis_img_bytes.read())
39
  ndvi_bytes.seek(0)
40
  zip_file.writestr("ndvi_band.npy", ndvi_bytes.read())
41
-
42
  zip_buf.seek(0)
43
-
44
  return StreamingResponse(
45
  zip_buf,
46
  media_type="application/x-zip-compressed",
47
  headers={"Content-Disposition": "attachment; filename=ndvi_output.zip"}
48
  )
 
 
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  except Exception as e:
51
  return JSONResponse(status_code=500, content={"error": str(e)})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi.responses import StreamingResponse, JSONResponse
2
+ from fastapi import FastAPI, File, UploadFile, HTTPException
3
  from ndvi_predictor import load_model, normalize_rgb, predict_ndvi, create_visualization
4
+ from yolo_predictor import load_yolo_model, predict_yolo, predict_pipeline
5
  from PIL import Image
6
  from io import BytesIO
7
  import numpy as np
8
  import zipfile
9
+ import json
10
+ import rasterio
11
+ from rasterio.transform import from_bounds
12
+ import tempfile
13
+ import os
14
 
15
  app = FastAPI()
16
+
17
+ # Load models at startup
18
+ ndvi_model = load_model("ndvi_best_model.keras")
19
+ yolo_model = load_yolo_model("4c_6c_regression.pt")
20
 
21
  @app.get("/")
22
  async def root():
23
+ return {"message": "Welcome to the NDVI and YOLO prediction API!"}
24
 
25
+ @app.post("/predict_ndvi/")
26
  async def predict_ndvi_api(file: UploadFile = File(...)):
27
+ """Predict NDVI from RGB image"""
28
  try:
29
  contents = await file.read()
30
  img = Image.open(BytesIO(contents)).convert("RGB")
 
31
  norm_img = normalize_rgb(np.array(img))
32
+ pred_ndvi = predict_ndvi(ndvi_model, norm_img)
33
+
34
  # Visualization image as PNG
35
  vis_img_bytes = create_visualization(norm_img, pred_ndvi)
36
  vis_img_bytes.seek(0)
37
+
38
  # NDVI band as .npy
39
  ndvi_bytes = BytesIO()
40
  np.save(ndvi_bytes, pred_ndvi)
41
  ndvi_bytes.seek(0)
42
+
43
  # Create a ZIP containing both files
 
44
  zip_buf = BytesIO()
45
  with zipfile.ZipFile(zip_buf, "w") as zip_file:
46
  zip_file.writestr("ndvi_image.png", vis_img_bytes.read())
47
  ndvi_bytes.seek(0)
48
  zip_file.writestr("ndvi_band.npy", ndvi_bytes.read())
49
+
50
  zip_buf.seek(0)
 
51
  return StreamingResponse(
52
  zip_buf,
53
  media_type="application/x-zip-compressed",
54
  headers={"Content-Disposition": "attachment; filename=ndvi_output.zip"}
55
  )
56
+ except Exception as e:
57
+ return JSONResponse(status_code=500, content={"error": str(e)})
58
 
59
+ @app.post("/predict_yolo/")
60
+ async def predict_yolo_api(file: UploadFile = File(...)):
61
+ """Predict YOLO results from 4-channel TIFF image"""
62
+ try:
63
+ # Save uploaded file temporarily
64
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.tif') as tmp_file:
65
+ contents = await file.read()
66
+ tmp_file.write(contents)
67
+ tmp_file_path = tmp_file.name
68
+
69
+ try:
70
+ # Predict using YOLO model
71
+ results = predict_yolo(yolo_model, tmp_file_path)
72
+
73
+ # Convert results to JSON-serializable format
74
+ results_dict = {
75
+ "boxes": {
76
+ "xyxyn": results.boxes.xyxyn.tolist() if results.boxes is not None else None,
77
+ "conf": results.boxes.conf.tolist() if results.boxes is not None else None,
78
+ "cls": results.boxes.cls.tolist() if results.boxes is not None else None
79
+ },
80
+ "classes": results.boxes.cls.tolist() if results.boxes is not None else None,
81
+ "names": results.names,
82
+ "growth_stages": getattr(results, 'growth_stages', None),
83
+ "orig_shape": results.orig_shape,
84
+ "speed": results.speed
85
+ }
86
+
87
+ # Handle growth stages if present
88
+ if hasattr(results, 'boxes') and hasattr(results.boxes, 'data'):
89
+ # Extract growth stages from the results if available
90
+ if len(results.boxes.data[0]) > 6: # Assuming growth stages are in the data
91
+ growth_stages = results.boxes.data[:, 6].tolist()
92
+ results_dict["growth_stages"] = growth_stages
93
+
94
+ return JSONResponse(content=results_dict)
95
+
96
+ finally:
97
+ # Clean up temporary file
98
+ os.unlink(tmp_file_path)
99
+
100
  except Exception as e:
101
  return JSONResponse(status_code=500, content={"error": str(e)})
102
+
103
+ @app.post("/predict_pipeline/")
104
+ async def predict_pipeline_api(file: UploadFile = File(...)):
105
+ """Full pipeline: RGB -> NDVI -> 4-channel -> YOLO prediction"""
106
+ try:
107
+ # Save uploaded file temporarily
108
+ with tempfile.NamedTemporaryFile(delete=False, suffix='.tif') as tmp_file:
109
+ contents = await file.read()
110
+ tmp_file.write(contents)
111
+ tmp_file_path = tmp_file.name
112
+
113
+ try:
114
+ # Run the full pipeline
115
+ results = predict_pipeline(ndvi_model, yolo_model, tmp_file_path)
116
+
117
+ # Convert results to JSON-serializable format
118
+ results_dict = {
119
+ "boxes": {
120
+ "xyxyn": results.boxes.xyxyn.tolist() if results.boxes is not None else None,
121
+ "conf": results.boxes.conf.tolist() if results.boxes is not None else None,
122
+ "cls": results.boxes.cls.tolist() if results.boxes is not None else None
123
+ },
124
+ "classes": results.boxes.cls.tolist() if results.boxes is not None else None,
125
+ "names": results.names,
126
+ "growth_stages": getattr(results, 'growth_stages', None),
127
+ "orig_shape": results.orig_shape,
128
+ "speed": results.speed
129
+ }
130
+
131
+ # Handle growth stages if present
132
+ if hasattr(results, 'boxes') and hasattr(results.boxes, 'data'):
133
+ if len(results.boxes.data[0]) > 6:
134
+ growth_stages = results.boxes.data[:, 6].tolist()
135
+ results_dict["growth_stages"] = growth_stages
136
+
137
+ return JSONResponse(content=results_dict)
138
+
139
+ finally:
140
+ # Clean up temporary file
141
+ os.unlink(tmp_file_path)
142
+
143
+ except Exception as e:
144
+ return JSONResponse(status_code=500, content={"error": str(e)})