Janeka commited on
Commit
494a7a7
·
verified ·
1 Parent(s): 742f506

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -87
app.py CHANGED
@@ -1,26 +1,18 @@
1
- import gradio as gr
 
2
  import numpy as np
3
  from PIL import Image
4
- import torch
5
- from transformers import pipeline
6
- from functools import lru_cache
7
  import cv2
 
8
  import logging
9
 
 
 
10
  # Configure logging
11
  logging.basicConfig(level=logging.INFO)
12
  logger = logging.getLogger(__name__)
13
 
14
- # Cache models to avoid reloading on every request
15
- @lru_cache(maxsize=1)
16
- def load_model(model_name):
17
- try:
18
- return pipeline("image-segmentation", model_name)
19
- except Exception as e:
20
- logger.error(f"Failed to load {model_name}: {e}")
21
- return None
22
-
23
- # Model sequence configuration
24
  MODELS = [
25
  {"name": "BRIA", "repo": "BRIA-AI/bria-rmbg", "weight": 1.0},
26
  {"name": "INSPyReNet", "repo": "mattmdjaga/INSPyReNet", "weight": 0.9},
@@ -30,91 +22,64 @@ MODELS = [
30
  {"name": "ISNet-Anime", "repo": "skytnt/anime-seg", "weight": 0.5}
31
  ]
32
 
33
- def process_single_model(image, model):
34
- """Process image with a single model"""
35
  try:
36
- pipe = load_model(model["repo"])
37
- if pipe is None:
38
- return None
39
-
40
- # Convert image to numpy array if needed
41
- if isinstance(image, Image.Image):
42
- image_np = np.array(image)
43
- else:
44
- image_np = image
45
-
46
- result = pipe(image_np)
47
- return result['mask'] if isinstance(result, dict) else result[0]['mask']
48
  except Exception as e:
49
- logger.warning(f"{model['name']} failed: {e}")
50
  return None
51
 
52
- def combine_masks(masks, weights):
53
- """Combine masks with weighted averaging"""
54
- valid_masks = [m for m in masks if m is not None]
55
- if not valid_masks:
56
- return None
57
-
58
- total_weight = sum(w for w, m in zip(weights, masks) if m is not None)
59
- combined = np.zeros_like(valid_masks[0], dtype=np.float32)
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  for mask, weight in zip(masks, weights):
62
- if mask is not None:
63
- combined += (mask.astype(np.float32) * weight
64
 
65
- return (combined / total_weight).astype(np.uint8)
66
 
67
- def remove_background(image):
68
- """Main processing pipeline"""
69
  try:
70
- # Convert input to PIL Image
71
- if isinstance(image, np.ndarray):
72
- image = Image.fromarray(image)
73
-
74
- # Process through all models
75
- masks = []
76
- for model in MODELS:
77
- mask = process_single_model(image, model)
78
- masks.append(mask)
79
-
80
- # Combine results
81
- weights = [m["weight"] for m in MODELS]
82
- final_mask = combine_masks(masks, weights)
83
 
84
- if final_mask is None:
85
- raise ValueError("All models failed")
86
-
87
  # Apply mask
88
  background = Image.new('RGB', image.size, (0, 0, 0))
89
- final_image = Image.composite(image, background, Image.fromarray(final_mask))
90
 
91
- return final_image
92
- except Exception as e:
93
- logger.error(f"Processing failed: {e}")
94
- return None
95
-
96
- # Gradio interface with API endpoint
97
- with gr.Blocks() as app:
98
- gr.Markdown("## 🖼️ Advanced Background Remover")
99
- with gr.Row():
100
- with gr.Column():
101
- input_image = gr.Image(label="Upload Image")
102
- submit_btn = gr.Button("Remove Background")
103
- with gr.Column():
104
- output_image = gr.Image(label="Result")
105
 
106
- submit_btn.click(
107
- fn=remove_background,
108
- inputs=input_image,
109
- outputs=output_image
110
- )
111
-
112
- # API endpoint for mobile apps
113
- app.api_app = gr.routes.App.create_app(
114
- fn=remove_background,
115
- inputs=gr.Image(),
116
- outputs=gr.Image()
117
- )
118
 
119
- if __name__ == "__main__":
120
- app.launch(server_name="0.0.0.0", server_port=7860)
 
 
1
+ from fastapi import FastAPI, UploadFile, File
2
+ from fastapi.responses import Response
3
  import numpy as np
4
  from PIL import Image
5
+ import io
 
 
6
  import cv2
7
+ from transformers import pipeline
8
  import logging
9
 
10
+ app = FastAPI(title="Advanced Background Remover")
11
+
12
  # Configure logging
13
  logging.basicConfig(level=logging.INFO)
14
  logger = logging.getLogger(__name__)
15
 
 
 
 
 
 
 
 
 
 
 
16
  MODELS = [
17
  {"name": "BRIA", "repo": "BRIA-AI/bria-rmbg", "weight": 1.0},
18
  {"name": "INSPyReNet", "repo": "mattmdjaga/INSPyReNet", "weight": 0.9},
 
22
  {"name": "ISNet-Anime", "repo": "skytnt/anime-seg", "weight": 0.5}
23
  ]
24
 
25
+ def load_model(model_repo):
 
26
  try:
27
+ return pipeline("image-segmentation", model_repo)
 
 
 
 
 
 
 
 
 
 
 
28
  except Exception as e:
29
+ logger.error(f"Failed to load {model_repo}: {e}")
30
  return None
31
 
32
+ def process_image(image: np.ndarray):
33
+ masks = []
34
+ weights = []
 
 
 
 
 
35
 
36
+ for model in MODELS:
37
+ pipe = load_model(model["repo"])
38
+ if pipe:
39
+ try:
40
+ result = pipe(image)
41
+ mask = result[0]['mask'] if isinstance(result, list) else result['mask']
42
+ masks.append(mask)
43
+ weights.append(model["weight"])
44
+ except Exception as e:
45
+ logger.warning(f"{model['name']} failed: {e}")
46
+
47
+ if not masks:
48
+ raise ValueError("All models failed")
49
+
50
+ # Weighted average of masks
51
+ total_weight = sum(weights)
52
+ combined = np.zeros_like(masks[0], dtype=np.float32)
53
  for mask, weight in zip(masks, weights):
54
+ combined += mask.astype(np.float32) * weight
55
+ final_mask = (combined / total_weight).astype(np.uint8)
56
 
57
+ return final_mask
58
 
59
+ @app.post("/remove-background")
60
+ async def remove_background(file: UploadFile = File(...)):
61
  try:
62
+ # Read and convert image
63
+ contents = await file.read()
64
+ image = Image.open(io.BytesIO(contents)).convert("RGB")
65
+ image_np = np.array(image)
66
+
67
+ # Process image
68
+ mask = process_image(image_np)
 
 
 
 
 
 
69
 
 
 
 
70
  # Apply mask
71
  background = Image.new('RGB', image.size, (0, 0, 0))
72
+ result = Image.composite(image, background, Image.fromarray(mask))
73
 
74
+ # Return result
75
+ img_byte_arr = io.BytesIO()
76
+ result.save(img_byte_arr, format='PNG')
77
+ return Response(content=img_byte_arr.getvalue(), media_type="image/png")
 
 
 
 
 
 
 
 
 
 
78
 
79
+ except Exception as e:
80
+ logger.error(f"Error: {e}")
81
+ return {"error": str(e)}, 500
 
 
 
 
 
 
 
 
 
82
 
83
+ @app.get("/")
84
+ def health_check():
85
+ return {"status": "healthy", "models": [m["name"] for m in MODELS]}