sukhmani1303 commited on
Commit
aee217c
·
verified ·
1 Parent(s): af3072b

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +51 -43
app.py CHANGED
@@ -59,20 +59,54 @@ def initialize_classifier():
59
  logger.error(f"Failed to initialize classifier: {str(e)}")
60
  raise RuntimeError(f"Failed to initialize classifier: {str(e)}")
61
 
62
- # Validate image
63
- def validate_image(image):
64
- if image is None:
65
- return False, "No image provided"
66
- if not isinstance(image, np.ndarray):
67
- image = np.array(image)
68
- if image.shape[0] < 10 or image.shape[1] < 10:
69
- return False, "Image too small"
70
- return True, image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
  # Core prediction function
73
  def predict_core(image, return_raw=False):
74
  try:
75
- is_valid, result = validate_image(image)
76
  if not is_valid:
77
  logger.error(f"Invalid image: {result}")
78
  return {"error": result, "status": "failed"}
@@ -80,11 +114,7 @@ def predict_core(image, return_raw=False):
80
  image = result
81
  logger.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
82
 
83
- # Ensure image is in RGB format
84
- if len(image.shape) == 3 and image.shape[2] == 3:
85
- # Convert RGB to BGR for OpenCV
86
- image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
87
-
88
  result = classifier.predict(image, return_probs=True)
89
  logger.info(f"Raw prediction result: {result}")
90
 
@@ -106,22 +136,19 @@ def predict_core(image, return_raw=False):
106
  logger.info(f"Prediction output: {output}")
107
  return output
108
  except Exception as e:
109
- logger.error(f"Prediction failed: {str(e)}\n{traceback.format_exc()}")
110
- return {"error": f"{str(e)}\n{traceback.format_exc()}", "status": "failed"}
 
111
 
112
  # UI prediction function
113
  def predict(image):
114
  if image is None:
115
  return "Error: No image provided"
116
 
117
- # Convert PIL Image to numpy array
118
- if hasattr(image, 'convert'):
119
- image = image.convert('RGB')
120
- image = np.array(image)
121
-
122
  result = predict_core(image)
123
  if result["status"] == "failed":
124
  return f"Error: {result['error']}"
 
125
  return (
126
  f"**Classified as**: {result['prediction']}\n"
127
  f"**Confidence**: {result['confidence']:.4f}\n"
@@ -135,11 +162,6 @@ def predict_raw(image):
135
  if image is None:
136
  return json.dumps({"error": "No image provided", "status": "failed"}, indent=2)
137
 
138
- # Convert PIL Image to numpy array
139
- if hasattr(image, 'convert'):
140
- image = image.convert('RGB')
141
- image = np.array(image)
142
-
143
  result = predict_core(image, return_raw=True)
144
  return json.dumps(result, indent=2)
145
 
@@ -156,21 +178,14 @@ app = FastAPI()
156
  @app.post("/api/predict/")
157
  async def predict_api(file: UploadFile = File(...)):
158
  try:
159
- # Validate file type
160
  if not file.content_type.startswith('image/'):
161
  raise HTTPException(status_code=400, detail="File must be an image")
162
 
163
- # Read file contents
164
  contents = await file.read()
165
-
166
- # Convert to PIL Image first
167
  image = Image.open(io.BytesIO(contents))
168
  image = image.convert('RGB')
169
 
170
- # Convert to numpy array
171
- image_array = np.array(image)
172
-
173
- result = predict_core(image_array)
174
  return JSONResponse(content=result)
175
  except Exception as e:
176
  logger.error(f"API prediction failed: {str(e)}")
@@ -182,21 +197,14 @@ async def predict_api(file: UploadFile = File(...)):
182
  @app.post("/api/predict_raw/")
183
  async def predict_raw_api(file: UploadFile = File(...)):
184
  try:
185
- # Validate file type
186
  if not file.content_type.startswith('image/'):
187
  raise HTTPException(status_code=400, detail="File must be an image")
188
 
189
- # Read file contents
190
  contents = await file.read()
191
-
192
- # Convert to PIL Image first
193
  image = Image.open(io.BytesIO(contents))
194
  image = image.convert('RGB')
195
 
196
- # Convert to numpy array
197
- image_array = np.array(image)
198
-
199
- result = predict_core(image_array, return_raw=True)
200
  return JSONResponse(content=result)
201
  except Exception as e:
202
  logger.error(f"API raw prediction failed: {str(e)}")
 
59
  logger.error(f"Failed to initialize classifier: {str(e)}")
60
  raise RuntimeError(f"Failed to initialize classifier: {str(e)}")
61
 
62
+ # Enhanced image validation and preprocessing
63
+ def validate_and_preprocess_image(image):
64
+ try:
65
+ if image is None:
66
+ return False, "No image provided"
67
+
68
+ # Convert PIL to numpy if needed
69
+ if hasattr(image, 'convert'):
70
+ image = image.convert('RGB')
71
+ image = np.array(image)
72
+
73
+ if not isinstance(image, np.ndarray):
74
+ image = np.array(image)
75
+
76
+ # Basic validation
77
+ if len(image.shape) < 2:
78
+ return False, "Invalid image format"
79
+
80
+ if image.shape[0] < 10 or image.shape[1] < 10:
81
+ return False, "Image too small (minimum 10x10 pixels)"
82
+
83
+ # Ensure 3-channel RGB
84
+ if len(image.shape) == 2:
85
+ image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
86
+ elif len(image.shape) == 3 and image.shape[2] == 4:
87
+ image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
88
+ elif len(image.shape) == 3 and image.shape[2] == 3:
89
+ # Already RGB, but ensure proper format
90
+ pass
91
+ else:
92
+ return False, f"Unsupported image shape: {image.shape}"
93
+
94
+ # Normalize to 0-255 range if needed
95
+ if image.dtype == np.float32 or image.dtype == np.float64:
96
+ if image.max() <= 1.0:
97
+ image = (image * 255).astype(np.uint8)
98
+
99
+ logger.info(f"Preprocessed image shape: {image.shape}, dtype: {image.dtype}")
100
+ return True, image
101
+
102
+ except Exception as e:
103
+ logger.error(f"Image preprocessing failed: {str(e)}")
104
+ return False, f"Image preprocessing failed: {str(e)}"
105
 
106
  # Core prediction function
107
  def predict_core(image, return_raw=False):
108
  try:
109
+ is_valid, result = validate_and_preprocess_image(image)
110
  if not is_valid:
111
  logger.error(f"Invalid image: {result}")
112
  return {"error": result, "status": "failed"}
 
114
  image = result
115
  logger.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
116
 
117
+ # Use the classifier's predict method
 
 
 
 
118
  result = classifier.predict(image, return_probs=True)
119
  logger.info(f"Raw prediction result: {result}")
120
 
 
136
  logger.info(f"Prediction output: {output}")
137
  return output
138
  except Exception as e:
139
+ error_msg = f"Prediction failed: {str(e)}"
140
+ logger.error(f"{error_msg}\n{traceback.format_exc()}")
141
+ return {"error": error_msg, "status": "failed"}
142
 
143
  # UI prediction function
144
  def predict(image):
145
  if image is None:
146
  return "Error: No image provided"
147
 
 
 
 
 
 
148
  result = predict_core(image)
149
  if result["status"] == "failed":
150
  return f"Error: {result['error']}"
151
+
152
  return (
153
  f"**Classified as**: {result['prediction']}\n"
154
  f"**Confidence**: {result['confidence']:.4f}\n"
 
162
  if image is None:
163
  return json.dumps({"error": "No image provided", "status": "failed"}, indent=2)
164
 
 
 
 
 
 
165
  result = predict_core(image, return_raw=True)
166
  return json.dumps(result, indent=2)
167
 
 
178
  @app.post("/api/predict/")
179
  async def predict_api(file: UploadFile = File(...)):
180
  try:
 
181
  if not file.content_type.startswith('image/'):
182
  raise HTTPException(status_code=400, detail="File must be an image")
183
 
 
184
  contents = await file.read()
 
 
185
  image = Image.open(io.BytesIO(contents))
186
  image = image.convert('RGB')
187
 
188
+ result = predict_core(image)
 
 
 
189
  return JSONResponse(content=result)
190
  except Exception as e:
191
  logger.error(f"API prediction failed: {str(e)}")
 
197
  @app.post("/api/predict_raw/")
198
  async def predict_raw_api(file: UploadFile = File(...)):
199
  try:
 
200
  if not file.content_type.startswith('image/'):
201
  raise HTTPException(status_code=400, detail="File must be an image")
202
 
 
203
  contents = await file.read()
 
 
204
  image = Image.open(io.BytesIO(contents))
205
  image = image.convert('RGB')
206
 
207
+ result = predict_core(image, return_raw=True)
 
 
 
208
  return JSONResponse(content=result)
209
  except Exception as e:
210
  logger.error(f"API raw prediction failed: {str(e)}")