Spaces:
Build error
Build error
Upload app.py with huggingface_hub
Browse files
app.py
CHANGED
|
@@ -59,20 +59,54 @@ def initialize_classifier():
|
|
| 59 |
logger.error(f"Failed to initialize classifier: {str(e)}")
|
| 60 |
raise RuntimeError(f"Failed to initialize classifier: {str(e)}")
|
| 61 |
|
| 62 |
-
#
|
| 63 |
-
def
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
|
| 72 |
# Core prediction function
|
| 73 |
def predict_core(image, return_raw=False):
|
| 74 |
try:
|
| 75 |
-
is_valid, result =
|
| 76 |
if not is_valid:
|
| 77 |
logger.error(f"Invalid image: {result}")
|
| 78 |
return {"error": result, "status": "failed"}
|
|
@@ -80,11 +114,7 @@ def predict_core(image, return_raw=False):
|
|
| 80 |
image = result
|
| 81 |
logger.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
|
| 82 |
|
| 83 |
-
#
|
| 84 |
-
if len(image.shape) == 3 and image.shape[2] == 3:
|
| 85 |
-
# Convert RGB to BGR for OpenCV
|
| 86 |
-
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
| 87 |
-
|
| 88 |
result = classifier.predict(image, return_probs=True)
|
| 89 |
logger.info(f"Raw prediction result: {result}")
|
| 90 |
|
|
@@ -106,22 +136,19 @@ def predict_core(image, return_raw=False):
|
|
| 106 |
logger.info(f"Prediction output: {output}")
|
| 107 |
return output
|
| 108 |
except Exception as e:
|
| 109 |
-
|
| 110 |
-
|
|
|
|
| 111 |
|
| 112 |
# UI prediction function
|
| 113 |
def predict(image):
|
| 114 |
if image is None:
|
| 115 |
return "Error: No image provided"
|
| 116 |
|
| 117 |
-
# Convert PIL Image to numpy array
|
| 118 |
-
if hasattr(image, 'convert'):
|
| 119 |
-
image = image.convert('RGB')
|
| 120 |
-
image = np.array(image)
|
| 121 |
-
|
| 122 |
result = predict_core(image)
|
| 123 |
if result["status"] == "failed":
|
| 124 |
return f"Error: {result['error']}"
|
|
|
|
| 125 |
return (
|
| 126 |
f"**Classified as**: {result['prediction']}\n"
|
| 127 |
f"**Confidence**: {result['confidence']:.4f}\n"
|
|
@@ -135,11 +162,6 @@ def predict_raw(image):
|
|
| 135 |
if image is None:
|
| 136 |
return json.dumps({"error": "No image provided", "status": "failed"}, indent=2)
|
| 137 |
|
| 138 |
-
# Convert PIL Image to numpy array
|
| 139 |
-
if hasattr(image, 'convert'):
|
| 140 |
-
image = image.convert('RGB')
|
| 141 |
-
image = np.array(image)
|
| 142 |
-
|
| 143 |
result = predict_core(image, return_raw=True)
|
| 144 |
return json.dumps(result, indent=2)
|
| 145 |
|
|
@@ -156,21 +178,14 @@ app = FastAPI()
|
|
| 156 |
@app.post("/api/predict/")
|
| 157 |
async def predict_api(file: UploadFile = File(...)):
|
| 158 |
try:
|
| 159 |
-
# Validate file type
|
| 160 |
if not file.content_type.startswith('image/'):
|
| 161 |
raise HTTPException(status_code=400, detail="File must be an image")
|
| 162 |
|
| 163 |
-
# Read file contents
|
| 164 |
contents = await file.read()
|
| 165 |
-
|
| 166 |
-
# Convert to PIL Image first
|
| 167 |
image = Image.open(io.BytesIO(contents))
|
| 168 |
image = image.convert('RGB')
|
| 169 |
|
| 170 |
-
|
| 171 |
-
image_array = np.array(image)
|
| 172 |
-
|
| 173 |
-
result = predict_core(image_array)
|
| 174 |
return JSONResponse(content=result)
|
| 175 |
except Exception as e:
|
| 176 |
logger.error(f"API prediction failed: {str(e)}")
|
|
@@ -182,21 +197,14 @@ async def predict_api(file: UploadFile = File(...)):
|
|
| 182 |
@app.post("/api/predict_raw/")
|
| 183 |
async def predict_raw_api(file: UploadFile = File(...)):
|
| 184 |
try:
|
| 185 |
-
# Validate file type
|
| 186 |
if not file.content_type.startswith('image/'):
|
| 187 |
raise HTTPException(status_code=400, detail="File must be an image")
|
| 188 |
|
| 189 |
-
# Read file contents
|
| 190 |
contents = await file.read()
|
| 191 |
-
|
| 192 |
-
# Convert to PIL Image first
|
| 193 |
image = Image.open(io.BytesIO(contents))
|
| 194 |
image = image.convert('RGB')
|
| 195 |
|
| 196 |
-
|
| 197 |
-
image_array = np.array(image)
|
| 198 |
-
|
| 199 |
-
result = predict_core(image_array, return_raw=True)
|
| 200 |
return JSONResponse(content=result)
|
| 201 |
except Exception as e:
|
| 202 |
logger.error(f"API raw prediction failed: {str(e)}")
|
|
|
|
| 59 |
logger.error(f"Failed to initialize classifier: {str(e)}")
|
| 60 |
raise RuntimeError(f"Failed to initialize classifier: {str(e)}")
|
| 61 |
|
| 62 |
+
# Enhanced image validation and preprocessing
|
| 63 |
+
def validate_and_preprocess_image(image):
|
| 64 |
+
try:
|
| 65 |
+
if image is None:
|
| 66 |
+
return False, "No image provided"
|
| 67 |
+
|
| 68 |
+
# Convert PIL to numpy if needed
|
| 69 |
+
if hasattr(image, 'convert'):
|
| 70 |
+
image = image.convert('RGB')
|
| 71 |
+
image = np.array(image)
|
| 72 |
+
|
| 73 |
+
if not isinstance(image, np.ndarray):
|
| 74 |
+
image = np.array(image)
|
| 75 |
+
|
| 76 |
+
# Basic validation
|
| 77 |
+
if len(image.shape) < 2:
|
| 78 |
+
return False, "Invalid image format"
|
| 79 |
+
|
| 80 |
+
if image.shape[0] < 10 or image.shape[1] < 10:
|
| 81 |
+
return False, "Image too small (minimum 10x10 pixels)"
|
| 82 |
+
|
| 83 |
+
# Ensure 3-channel RGB
|
| 84 |
+
if len(image.shape) == 2:
|
| 85 |
+
image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
|
| 86 |
+
elif len(image.shape) == 3 and image.shape[2] == 4:
|
| 87 |
+
image = cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
|
| 88 |
+
elif len(image.shape) == 3 and image.shape[2] == 3:
|
| 89 |
+
# Already RGB, but ensure proper format
|
| 90 |
+
pass
|
| 91 |
+
else:
|
| 92 |
+
return False, f"Unsupported image shape: {image.shape}"
|
| 93 |
+
|
| 94 |
+
# Normalize to 0-255 range if needed
|
| 95 |
+
if image.dtype == np.float32 or image.dtype == np.float64:
|
| 96 |
+
if image.max() <= 1.0:
|
| 97 |
+
image = (image * 255).astype(np.uint8)
|
| 98 |
+
|
| 99 |
+
logger.info(f"Preprocessed image shape: {image.shape}, dtype: {image.dtype}")
|
| 100 |
+
return True, image
|
| 101 |
+
|
| 102 |
+
except Exception as e:
|
| 103 |
+
logger.error(f"Image preprocessing failed: {str(e)}")
|
| 104 |
+
return False, f"Image preprocessing failed: {str(e)}"
|
| 105 |
|
| 106 |
# Core prediction function
|
| 107 |
def predict_core(image, return_raw=False):
|
| 108 |
try:
|
| 109 |
+
is_valid, result = validate_and_preprocess_image(image)
|
| 110 |
if not is_valid:
|
| 111 |
logger.error(f"Invalid image: {result}")
|
| 112 |
return {"error": result, "status": "failed"}
|
|
|
|
| 114 |
image = result
|
| 115 |
logger.info(f"Input image shape: {image.shape}, dtype: {image.dtype}")
|
| 116 |
|
| 117 |
+
# Use the classifier's predict method
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
result = classifier.predict(image, return_probs=True)
|
| 119 |
logger.info(f"Raw prediction result: {result}")
|
| 120 |
|
|
|
|
| 136 |
logger.info(f"Prediction output: {output}")
|
| 137 |
return output
|
| 138 |
except Exception as e:
|
| 139 |
+
error_msg = f"Prediction failed: {str(e)}"
|
| 140 |
+
logger.error(f"{error_msg}\n{traceback.format_exc()}")
|
| 141 |
+
return {"error": error_msg, "status": "failed"}
|
| 142 |
|
| 143 |
# UI prediction function
|
| 144 |
def predict(image):
|
| 145 |
if image is None:
|
| 146 |
return "Error: No image provided"
|
| 147 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
result = predict_core(image)
|
| 149 |
if result["status"] == "failed":
|
| 150 |
return f"Error: {result['error']}"
|
| 151 |
+
|
| 152 |
return (
|
| 153 |
f"**Classified as**: {result['prediction']}\n"
|
| 154 |
f"**Confidence**: {result['confidence']:.4f}\n"
|
|
|
|
| 162 |
if image is None:
|
| 163 |
return json.dumps({"error": "No image provided", "status": "failed"}, indent=2)
|
| 164 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
result = predict_core(image, return_raw=True)
|
| 166 |
return json.dumps(result, indent=2)
|
| 167 |
|
|
|
|
| 178 |
@app.post("/api/predict/")
|
| 179 |
async def predict_api(file: UploadFile = File(...)):
|
| 180 |
try:
|
|
|
|
| 181 |
if not file.content_type.startswith('image/'):
|
| 182 |
raise HTTPException(status_code=400, detail="File must be an image")
|
| 183 |
|
|
|
|
| 184 |
contents = await file.read()
|
|
|
|
|
|
|
| 185 |
image = Image.open(io.BytesIO(contents))
|
| 186 |
image = image.convert('RGB')
|
| 187 |
|
| 188 |
+
result = predict_core(image)
|
|
|
|
|
|
|
|
|
|
| 189 |
return JSONResponse(content=result)
|
| 190 |
except Exception as e:
|
| 191 |
logger.error(f"API prediction failed: {str(e)}")
|
|
|
|
| 197 |
@app.post("/api/predict_raw/")
|
| 198 |
async def predict_raw_api(file: UploadFile = File(...)):
|
| 199 |
try:
|
|
|
|
| 200 |
if not file.content_type.startswith('image/'):
|
| 201 |
raise HTTPException(status_code=400, detail="File must be an image")
|
| 202 |
|
|
|
|
| 203 |
contents = await file.read()
|
|
|
|
|
|
|
| 204 |
image = Image.open(io.BytesIO(contents))
|
| 205 |
image = image.convert('RGB')
|
| 206 |
|
| 207 |
+
result = predict_core(image, return_raw=True)
|
|
|
|
|
|
|
|
|
|
| 208 |
return JSONResponse(content=result)
|
| 209 |
except Exception as e:
|
| 210 |
logger.error(f"API raw prediction failed: {str(e)}")
|