EnginDev commited on
Commit
cb3707d
·
verified ·
1 Parent(s): acbf156

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -61
app.py CHANGED
@@ -1,63 +1,52 @@
1
  import gradio as gr
2
- from transformers import SamProcessor, SamModel
3
- from PIL import Image
4
- import torch
5
  import numpy as np
6
- import random
7
- import traceback
8
-
9
- # Modell laden
10
- model_id = "facebook/sam-vit-base"
11
- processor = SamProcessor.from_pretrained(model_id)
12
- model = SamModel.from_pretrained(model_id)
13
-
14
- def random_color():
15
- """Zufällige RGB-Farbe"""
16
- return [random.randint(0, 255) for _ in range(3)]
17
-
18
- def segment_image(image):
19
- try:
20
- device = torch.device("cpu")
21
- model.to(device)
22
-
23
- inputs = processor(images=image, return_tensors="pt").to(device)
24
-
25
- with torch.no_grad():
26
- outputs = model(**inputs)
27
-
28
- masks = processor.post_process_masks(
29
- outputs.pred_masks.cpu(),
30
- inputs["original_sizes"].cpu(),
31
- inputs["reshaped_input_sizes"].cpu()
32
- )
33
-
34
- mask_arrays = masks[0].numpy()
35
- img_array = np.array(image)
36
- overlay = np.zeros_like(img_array, dtype=np.uint8)
37
-
38
- # Jede Maske farbig einfärben
39
- for mask in mask_arrays:
40
- mask = mask[0]
41
- color = random_color()
42
- for c in range(3):
43
- overlay[:, :, c] = np.where(mask > 0.5, color[c], overlay[:, :, c])
44
-
45
- # Stärkere Farbmischung (80 % Maske / 20 % Original)
46
- blended = Image.fromarray(
47
- (0.2 * img_array + 0.8 * overlay).astype(np.uint8)
48
- )
49
-
50
- return blended
51
-
52
- except Exception:
53
- return f"Fehler:\n{traceback.format_exc()}"
54
-
55
- demo = gr.Interface(
56
- fn=segment_image,
57
- inputs=gr.Image(type="pil", label="Upload your fish image"),
58
- outputs=gr.Image(type="pil", label="Segmented Output"),
59
- title="FishBoost – Colorful SAM Segmentation (Enhanced Colors)",
60
- description="Erzeugt kräftige, farbige Masken mit Meta SAM (CPU-Version)."
61
- )
62
-
63
- demo.launch()
 
1
  import gradio as gr
 
 
 
2
  import numpy as np
3
+ import torch
4
+ import cv2
5
+ from segment_anything import sam_model_registry, SamAutomaticMaskGenerator
6
+ from PIL import Image
7
+
8
+
9
+ import os
10
+ import urllib.request
11
+
12
+ MODEL_URL = "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth"
13
+ MODEL_PATH = "sam_vit_b.pth"
14
+
15
+ # Eğer model yoksa indir
16
+ if not os.path.exists(MODEL_PATH):
17
+ print("Model indiriliyor...")
18
+ urllib.request.urlretrieve(MODEL_URL, MODEL_PATH)
19
+ print("Model indirildi.")
20
+
21
+ # Model yükle
22
+ model_type = "vit_b"
23
+ device = "cuda" if torch.cuda.is_available() else "cpu"
24
+
25
+ sam = sam_model_registry[model_type](checkpoint=MODEL_PATH)
26
+ sam.to(device=device)
27
+
28
+ mask_generator = SamAutomaticMaskGenerator(sam)
29
+
30
+ def segment_all_objects(image):
31
+ image_np = np.array(image)
32
+ masks = mask_generator.generate(image_np)
33
+
34
+ # Maske üzerine çiz
35
+ overlay = image_np.copy()
36
+ for i, mask in enumerate(masks):
37
+ m = mask["segmentation"]
38
+ color = np.random.randint(0, 255, size=(3,))
39
+ overlay[m] = overlay[m] * 0.3 + color * 0.7
40
+ # Maske üstüne label yaz
41
+ y, x = np.where(m)
42
+ if len(x) > 0 and len(y) > 0:
43
+ cx, cy = int(np.mean(x)), int(np.mean(y))
44
+ cv2.putText(overlay, f"Obj {i+1}", (cx, cy), cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255,255,255), 2)
45
+
46
+ return Image.fromarray(overlay.astype(np.uint8))
47
+
48
+ gr.Interface(
49
+ fn=segment_all_objects,
50
+ inputs=gr.Image(type="pil"),
51
+ outputs=gr.Image()
52
+ ).launch()