Janeka commited on
Commit
c2c25ca
·
verified ·
1 Parent(s): a54a912

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -56
app.py CHANGED
@@ -1,18 +1,10 @@
1
- from fastapi import FastAPI, UploadFile, File
2
- from fastapi.responses import Response
3
  import numpy as np
4
  from PIL import Image
5
- import io
6
- import cv2
7
  from transformers import pipeline
8
- import logging
9
-
10
- app = FastAPI(title="Advanced Background Remover")
11
-
12
- # Configure logging
13
- logging.basicConfig(level=logging.INFO)
14
- logger = logging.getLogger(__name__)
15
 
 
16
  MODELS = [
17
  {"name": "BRIA", "repo": "BRIA-AI/bria-rmbg", "weight": 1.0},
18
  {"name": "INSPyReNet", "repo": "mattmdjaga/INSPyReNet", "weight": 0.9},
@@ -23,63 +15,55 @@ MODELS = [
23
  ]
24
 
25
  def load_model(model_repo):
26
- try:
27
- return pipeline("image-segmentation", model_repo)
28
- except Exception as e:
29
- logger.error(f"Failed to load {model_repo}: {e}")
30
- return None
31
 
32
- def process_image(image: np.ndarray):
 
 
 
 
33
  masks = []
34
  weights = []
35
 
36
  for model in MODELS:
37
- pipe = load_model(model["repo"])
38
- if pipe:
39
- try:
40
- result = pipe(image)
41
- mask = result[0]['mask'] if isinstance(result, list) else result['mask']
42
- masks.append(mask)
43
- weights.append(model["weight"])
44
- except Exception as e:
45
- logger.warning(f"{model['name']} failed: {e}")
 
46
 
47
  if not masks:
48
- raise ValueError("All models failed")
49
 
50
  # Weighted average of masks
51
- total_weight = sum(weights)
52
  combined = np.zeros_like(masks[0], dtype=np.float32)
53
  for mask, weight in zip(masks, weights):
54
  combined += mask.astype(np.float32) * weight
55
- final_mask = (combined / total_weight).astype(np.uint8)
56
 
57
- return final_mask
58
-
59
- @app.post("/remove-background")
60
- async def remove_background(file: UploadFile = File(...)):
61
- try:
62
- # Read and convert image
63
- contents = await file.read()
64
- image = Image.open(io.BytesIO(contents)).convert("RGB")
65
- image_np = np.array(image)
66
-
67
- # Process image
68
- mask = process_image(image_np)
69
-
70
- # Apply mask
71
- background = Image.new('RGB', image.size, (0, 0, 0))
72
- result = Image.composite(image, background, Image.fromarray(mask))
73
-
74
- # Return result
75
- img_byte_arr = io.BytesIO()
76
- result.save(img_byte_arr, format='PNG')
77
- return Response(content=img_byte_arr.getvalue(), media_type="image/png")
78
 
79
- except Exception as e:
80
- logger.error(f"Error: {e}")
81
- return {"error": str(e)}, 500
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- @app.get("/")
84
- def health_check():
85
- return {"status": "healthy", "models": [m["name"] for m in MODELS]}
 
1
+ import gradio as gr
 
2
  import numpy as np
3
  from PIL import Image
 
 
4
  from transformers import pipeline
5
+ import cv2
 
 
 
 
 
 
6
 
7
+ # Model sequence with weights
8
  MODELS = [
9
  {"name": "BRIA", "repo": "BRIA-AI/bria-rmbg", "weight": 1.0},
10
  {"name": "INSPyReNet", "repo": "mattmdjaga/INSPyReNet", "weight": 0.9},
 
15
  ]
16
 
17
  def load_model(model_repo):
18
+ return pipeline("image-segmentation", model_repo)
 
 
 
 
19
 
20
+ def process_image(input_image):
21
+ # Convert Gradio input to PIL Image
22
+ if isinstance(input_image, np.ndarray):
23
+ input_image = Image.fromarray(input_image)
24
+
25
  masks = []
26
  weights = []
27
 
28
  for model in MODELS:
29
+ try:
30
+ pipe = load_model(model["repo"])
31
+ result = pipe(np.array(input_image))
32
+ mask = result[0]['mask'] if isinstance(result, list) else result['mask']
33
+ masks.append(mask)
34
+ weights.append(model["weight"])
35
+ print(f"{model['name']} completed successfully") # Debug print
36
+ except Exception as e:
37
+ print(f"{model['name']} failed: {str(e)}") # Debug print
38
+ continue
39
 
40
  if not masks:
41
+ return None
42
 
43
  # Weighted average of masks
 
44
  combined = np.zeros_like(masks[0], dtype=np.float32)
45
  for mask, weight in zip(masks, weights):
46
  combined += mask.astype(np.float32) * weight
47
+ final_mask = (combined / sum(weights)).astype(np.uint8)
48
 
49
+ # Create transparent background
50
+ result = input_image.copy()
51
+ result.putalpha(Image.fromarray(final_mask))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
+ return result
54
+
55
+ # Gradio interface
56
+ demo = gr.Interface(
57
+ fn=process_image,
58
+ inputs=gr.Image(label="Input Image"),
59
+ outputs=gr.Image(label="Result (PNG with Transparency)"),
60
+ title="🎨 Advanced Background Remover",
61
+ description="Combines 6 AI models for perfect background removal",
62
+ examples=[
63
+ ["example1.jpg"],
64
+ ["example2.jpg"],
65
+ ["example3.png"]
66
+ ]
67
+ )
68
 
69
+ demo.launch(server_name="0.0.0.0", server_port=7860)