sharktide commited on
Commit
76cc3b5
·
verified ·
1 Parent(s): ac865bb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -22
app.py CHANGED
@@ -7,6 +7,11 @@ from fastapi.responses import JSONResponse
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from torchvision import transforms
9
  from transformers import AutoModelForImageSegmentation
 
 
 
 
 
10
 
11
  # Load your trained model
12
  model = tf.keras.models.load_model('recyclebot.keras')
@@ -39,35 +44,44 @@ app.add_middleware(
39
 
40
  # Preprocess the image (resize, reshape without normalization)
41
  def preprocess_image(image_file):
42
- # Load image using PIL
43
- image = Image.open(image_file)
44
-
45
- # Convert image to numpy array
46
- image = np.array(image)
47
-
48
- # Resize to the input shape expected by the model
49
- image = cv2.resize(image, (240, 240)) # Resize image to match model input
50
-
51
- # Reshape the image (similar to your local code)
52
- image = image.reshape(-1, 240, 240, 3) # Add the batch dimension for inference
53
-
54
- return image
 
 
 
 
55
 
56
  # Background removal function
57
  def remove_background(image):
58
- image_size = image.size
59
- input_images = transform_image(image).unsqueeze(0)
60
- with torch.no_grad():
61
- preds = birefnet(input_images)[-1].sigmoid()
62
- pred = preds[0].squeeze()
63
- pred_pil = transforms.ToPILImage()(pred)
64
- mask = pred_pil.resize(image_size)
65
- image.putalpha(mask)
66
- return image
 
 
 
 
67
 
68
  @app.post("/predict")
69
  async def predict(file: UploadFile = File(...)):
70
  try:
 
71
  img_array = preprocess_image(file.file) # Preprocess the image
72
  prediction1 = model.predict(img_array) # Get predictions
73
 
@@ -77,11 +91,13 @@ async def predict(file: UploadFile = File(...)):
77
  return JSONResponse(content={"prediction": predicted_class})
78
 
79
  except Exception as e:
 
80
  return JSONResponse(content={"error": str(e)}, status_code=400)
81
 
82
  @app.post("/predict/recyclebot0accuracy")
83
  async def predict_recyclebot0accuracy(file: UploadFile = File(...)):
84
  try:
 
85
  # Load and remove background from image
86
  image = Image.open(file.file).convert("RGB")
87
  image = remove_background(image)
@@ -102,6 +118,7 @@ async def predict_recyclebot0accuracy(file: UploadFile = File(...)):
102
  return JSONResponse(content={"prediction": predicted_class})
103
 
104
  except Exception as e:
 
105
  return JSONResponse(content={"error": str(e)}, status_code=400)
106
 
107
  @app.get("/working")
 
7
  from fastapi.middleware.cors import CORSMiddleware
8
  from torchvision import transforms
9
  from transformers import AutoModelForImageSegmentation
10
+ import logging
11
+
12
+ # Set up logging
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
 
16
  # Load your trained model
17
  model = tf.keras.models.load_model('recyclebot.keras')
 
44
 
45
  # Preprocess the image (resize, reshape without normalization)
46
  def preprocess_image(image_file):
47
+ try:
48
+ # Load image using PIL
49
+ image = Image.open(image_file)
50
+
51
+ # Convert image to numpy array
52
+ image = np.array(image)
53
+
54
+ # Resize to the input shape expected by the model
55
+ image = cv2.resize(image, (240, 240)) # Resize image to match model input
56
+
57
+ # Reshape the image (similar to your local code)
58
+ image = image.reshape(-1, 240, 240, 3) # Add the batch dimension for inference
59
+
60
+ return image
61
+ except Exception as e:
62
+ logger.error(f"Error in preprocess_image: {str(e)}")
63
+ raise
64
 
65
  # Background removal function
66
  def remove_background(image):
67
+ try:
68
+ image_size = image.size
69
+ input_images = transform_image(image).unsqueeze(0)
70
+ with torch.no_grad():
71
+ preds = birefnet(input_images)[-1].sigmoid()
72
+ pred = preds[0].squeeze()
73
+ pred_pil = transforms.ToPILImage()(pred)
74
+ mask = pred_pil.resize(image_size)
75
+ image.putalpha(mask)
76
+ return image
77
+ except Exception as e:
78
+ logger.error(f"Error in remove_background: {str(e)}")
79
+ raise
80
 
81
  @app.post("/predict")
82
  async def predict(file: UploadFile = File(...)):
83
  try:
84
+ logger.info("Received request for /predict")
85
  img_array = preprocess_image(file.file) # Preprocess the image
86
  prediction1 = model.predict(img_array) # Get predictions
87
 
 
91
  return JSONResponse(content={"prediction": predicted_class})
92
 
93
  except Exception as e:
94
+ logger.error(f"Error in /predict: {str(e)}")
95
  return JSONResponse(content={"error": str(e)}, status_code=400)
96
 
97
  @app.post("/predict/recyclebot0accuracy")
98
  async def predict_recyclebot0accuracy(file: UploadFile = File(...)):
99
  try:
100
+ logger.info("Received request for /predict/recyclebot0accuracy")
101
  # Load and remove background from image
102
  image = Image.open(file.file).convert("RGB")
103
  image = remove_background(image)
 
118
  return JSONResponse(content={"prediction": predicted_class})
119
 
120
  except Exception as e:
121
+ logger.error(f"Error in /predict/recyclebot0accuracy: {str(e)}")
122
  return JSONResponse(content={"error": str(e)}, status_code=400)
123
 
124
  @app.get("/working")