sharktide commited on
Commit
5e08f68
·
verified ·
1 Parent(s): edecd42

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -23
app.py CHANGED
@@ -2,25 +2,33 @@ from fastapi import FastAPI, File, UploadFile, Request
2
  import tensorflow as tf
3
  import numpy as np
4
  from PIL import Image
5
- import cv2 # Make sure OpenCV is available for resizing
6
  from fastapi.responses import JSONResponse
7
  from fastapi.middleware.cors import CORSMiddleware
8
- #from slowapi import Limiter
9
- #from slowapi import Limiter
10
- #from slowapi.util import get_remote_address
11
- #from slowapi.errors import RateLimitExceeded
12
 
13
  # Load your trained model
14
  model = tf.keras.models.load_model('recyclebot.keras')
15
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  # Define class names for predictions (this should be the same as in your local code)
17
  CLASSES = ['Glass', 'Metal', 'Paperboard', 'Plastic-Polystyrene', 'Plastic-Regular']
18
 
19
  # Create FastAPI app
20
  app = FastAPI()
21
 
22
- #limiter = Limiter(key_func=get_remote_address)
23
-
24
  app.add_middleware(
25
  CORSMiddleware,
26
  allow_origins=["*"], # Allow all origins (or specify specific origins)
@@ -29,10 +37,7 @@ app.add_middleware(
29
  allow_headers=["*"], # Allow all headers
30
  )
31
 
32
-
33
-
34
-
35
- # Preprocessing the image (resize, reshape without normalization)
36
  def preprocess_image(image_file):
37
  # Load image using PIL
38
  image = Image.open(image_file)
@@ -48,22 +53,25 @@ def preprocess_image(image_file):
48
 
49
  return image
50
 
51
-
 
 
 
 
 
 
 
 
 
 
52
 
53
  @app.post("/predict")
54
- #@limiter.limit("10/minute")
55
- async def predict(file: UploadFile = File(...)): #async def predict(request: Request, file: UploadFile = File(...)):
56
  try:
57
  img_array = preprocess_image(file.file) # Preprocess the image
58
  prediction1 = model.predict(img_array) # Get predictions
59
-
60
- weight_1 = 0.6
61
- weight_2 = 0.4
62
 
63
- # Get the index of the highest probability class (like np.argmax on local machine)
64
  predicted_class_idx = np.argmax(prediction1, axis=1)[0] # Get predicted class index
65
-
66
- # Map the predicted index to the class name (like final_class = CLASSES[np.argmax(final_preds)])
67
  predicted_class = CLASSES[predicted_class_idx] # Convert to class name
68
 
69
  return JSONResponse(content={"prediction": predicted_class})
@@ -71,15 +79,36 @@ async def predict(file: UploadFile = File(...)): #async def predict(request: R
71
  except Exception as e:
72
  return JSONResponse(content={"error": str(e)}, status_code=400)
73
 
 
 
 
 
 
 
74
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  @app.get("/working")
77
  async def working():
78
  return JSONResponse(content={"Status": "Working"})
79
 
80
-
81
-
82
- #To manually run FastAPI (though Hugging Face will typically do this)
83
  if __name__ == "__main__":
84
  import uvicorn
85
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
2
  import tensorflow as tf
3
  import numpy as np
4
  from PIL import Image
5
+ import cv2
6
  from fastapi.responses import JSONResponse
7
  from fastapi.middleware.cors import CORSMiddleware
8
+ from torchvision import transforms
9
+ from transformers import AutoModelForImageSegmentation
 
 
10
 
11
  # Load your trained model
12
  model = tf.keras.models.load_model('recyclebot.keras')
13
 
14
+ # Load background removal model
15
+ birefnet = AutoModelForImageSegmentation.from_pretrained(
16
+ "ZhengPeng7/BiRefNet", trust_remote_code=True
17
+ )
18
+
19
+ # Transform for the background removal model
20
+ transform_image = transforms.Compose([
21
+ transforms.Resize((1024, 1024)),
22
+ transforms.ToTensor(),
23
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
24
+ ])
25
+
26
  # Define class names for predictions (this should be the same as in your local code)
27
  CLASSES = ['Glass', 'Metal', 'Paperboard', 'Plastic-Polystyrene', 'Plastic-Regular']
28
 
29
  # Create FastAPI app
30
  app = FastAPI()
31
 
 
 
32
  app.add_middleware(
33
  CORSMiddleware,
34
  allow_origins=["*"], # Allow all origins (or specify specific origins)
 
37
  allow_headers=["*"], # Allow all headers
38
  )
39
 
40
+ # Preprocess the image (resize, reshape without normalization)
 
 
 
41
  def preprocess_image(image_file):
42
  # Load image using PIL
43
  image = Image.open(image_file)
 
53
 
54
  return image
55
 
56
+ # Background removal function
57
+ def remove_background(image):
58
+ image_size = image.size
59
+ input_images = transform_image(image).unsqueeze(0)
60
+ with torch.no_grad():
61
+ preds = birefnet(input_images)[-1].sigmoid()
62
+ pred = preds[0].squeeze()
63
+ pred_pil = transforms.ToPILImage()(pred)
64
+ mask = pred_pil.resize(image_size)
65
+ image.putalpha(mask)
66
+ return image
67
 
68
  @app.post("/predict")
69
+ async def predict(file: UploadFile = File(...)):
 
70
  try:
71
  img_array = preprocess_image(file.file) # Preprocess the image
72
  prediction1 = model.predict(img_array) # Get predictions
 
 
 
73
 
 
74
  predicted_class_idx = np.argmax(prediction1, axis=1)[0] # Get predicted class index
 
 
75
  predicted_class = CLASSES[predicted_class_idx] # Convert to class name
76
 
77
  return JSONResponse(content={"prediction": predicted_class})
 
79
  except Exception as e:
80
  return JSONResponse(content={"error": str(e)}, status_code=400)
81
 
82
+ @app.post("/predict/recyclebot0accuracy")
83
+ async def predict_recyclebot0accuracy(file: UploadFile = File(...)):
84
+ try:
85
+ # Load and remove background from image
86
+ image = Image.open(file.file).convert("RGB")
87
+ image = remove_background(image)
88
 
89
+ # Save the image with a transparent background (to use in further processing)
90
+ image_path = "processed_image.jpg"
91
+ image.save(image_path, "JPEG")
92
+
93
+ # Preprocess the image with the background removed
94
+ img_array = preprocess_image(image_path)
95
+
96
+ # Get predictions
97
+ prediction1 = model.predict(img_array)
98
+
99
+ predicted_class_idx = np.argmax(prediction1, axis=1)[0] # Get predicted class index
100
+ predicted_class = CLASSES[predicted_class_idx] # Convert to class name
101
+
102
+ return JSONResponse(content={"prediction": predicted_class})
103
+
104
+ except Exception as e:
105
+ return JSONResponse(content={"error": str(e)}, status_code=400)
106
 
107
  @app.get("/working")
108
  async def working():
109
  return JSONResponse(content={"Status": "Working"})
110
 
111
+ # To manually run FastAPI
 
 
112
  if __name__ == "__main__":
113
  import uvicorn
114
  uvicorn.run(app, host="0.0.0.0", port=7860)