Umargenerating commited on
Commit
5c33c53
·
verified ·
1 Parent(s): ae8e25e

Upload run_gui.py

Browse files
Files changed (1) hide show
  1. run_gui.py +104 -0
run_gui.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ import matplotlib.pyplot as plt
4
+ from torchvision import transforms
5
+ import gradio as gr
6
+ from models.birefnet import BiRefNet
7
+ import io
8
+ import tempfile # Импортируем tempfile
9
+
10
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
11
+ print("Используем устройство:", device)
12
+
13
+ # Конфигурация моделей
14
+ MODEL_CONFIG = {
15
+ "BiRefNet_HR": {
16
+ "repo": "ZhengPeng7/BiRefNet_HR",
17
+ "image_size": (2048, 2048)
18
+ },
19
+ "BiRefNet": {
20
+ "repo": "ZhengPeng7/BiRefNet",
21
+ "image_size": (1024, 1024)
22
+ },
23
+ "BiRefNet-matting": {
24
+ "repo": "ZhengPeng7/BiRefNet-matting",
25
+ "image_size": (1024, 1024)
26
+ },
27
+ "BiRefNet-portrait": {
28
+ "repo": "ZhengPeng7/BiRefNet-portrait",
29
+ "image_size": (1024, 1024)
30
+ },
31
+ "BiRefNet-HRSOD": {
32
+ "repo": "ZhengPeng7/BiRefNet-HRSOD",
33
+ "image_size": (1024, 1024)
34
+ },
35
+ }
36
+
37
+ # Кэш для загруженных моделей
38
+ loaded_models = {}
39
+
40
+ def load_model(model_name):
41
+ if model_name not in loaded_models:
42
+ print(f"Загрузка модели {model_name}...")
43
+ model = BiRefNet.from_pretrained(MODEL_CONFIG[model_name]["repo"])
44
+ model.to(device).eval()
45
+ if device == 'cuda':
46
+ model.half() # FP16 для CUDA
47
+ loaded_models[model_name] = model
48
+ return loaded_models[model_name]
49
+
50
+ def extract_object(image, model_name):
51
+ # Загрузка выбранной модели
52
+ model = load_model(model_name)
53
+ config = MODEL_CONFIG[model_name]
54
+
55
+ # Преобразование изображения
56
+ transform = transforms.Compose([
57
+ transforms.Resize(config["image_size"]),
58
+ transforms.ToTensor(),
59
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
60
+ ])
61
+
62
+ input_tensor = transform(image).unsqueeze(0)
63
+ input_tensor = input_tensor.to(device)
64
+ if device == 'cuda':
65
+ input_tensor = input_tensor.half()
66
+
67
+ with torch.no_grad():
68
+ preds = model(input_tensor)[-1].sigmoid().cpu()
69
+
70
+ mask = transforms.ToPILImage()(preds[0].squeeze())
71
+ mask = mask.resize(image.size)
72
+
73
+ result = image.copy()
74
+ result.putalpha(mask)
75
+
76
+ # Save as PNG to a temporary file
77
+ with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
78
+ result.save(tmp_file)
79
+ temp_filepath = tmp_file.name # Get the temporary file path
80
+
81
+ return temp_filepath # Return the filepath to Gradio
82
+
83
+ # Создаем интерфейс с выбором модели
84
+ iface = gr.Interface(
85
+ fn=extract_object,
86
+ inputs=[
87
+ gr.Image(type="pil", label="Входное изображение"),
88
+ gr.Dropdown(
89
+ choices=list(MODEL_CONFIG.keys()),
90
+ value="BiRefNet_HR",
91
+ label="Выбор модели"
92
+ )
93
+ ],
94
+ outputs=gr.Image(type="filepath", label="Результат"), # Output type is filepath
95
+ title="BiRefNet - Интерактивная сегментация",
96
+ description=(
97
+ "Выберите модель и загрузите изображение для сегментации. "
98
+ "Доступные модели: BiRefNet_HR (2048x2048), BiRefNet (1024x1024), BiRefNet-lite-2K (2048x2048)"
99
+ ),
100
+ allow_flagging="never" # Disable flagging, as it is not needed and causing issues
101
+ )
102
+
103
+ if __name__ == "__main__":
104
+ iface.launch(share=True)