File size: 3,550 Bytes
5c33c53
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
import gradio as gr
from models.birefnet import BiRefNet
import io
import tempfile  # Импортируем tempfile

# Pick the inference device once at import time: CUDA when available, else CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Используем устройство:", device)

# Конфигурация моделей
MODEL_CONFIG = {
    "BiRefNet_HR": {
        "repo": "ZhengPeng7/BiRefNet_HR",
        "image_size": (2048, 2048)
    },
    "BiRefNet": {
        "repo": "ZhengPeng7/BiRefNet",
        "image_size": (1024, 1024)
    },
    "BiRefNet-matting": {
        "repo": "ZhengPeng7/BiRefNet-matting",
        "image_size": (1024, 1024)
    },
        "BiRefNet-portrait": {
        "repo": "ZhengPeng7/BiRefNet-portrait",
        "image_size": (1024, 1024)
    },
        "BiRefNet-HRSOD": {
        "repo": "ZhengPeng7/BiRefNet-HRSOD",
        "image_size": (1024, 1024)
    },
}

# Кэш для загруженных моделей
loaded_models = {}

def load_model(model_name):
    """Return the BiRefNet model registered as ``model_name``, loading it on first use.

    The model is moved to ``device`` and put in eval mode; on CUDA its weights
    are converted to float16. It is stored in ``loaded_models``, so later calls
    with the same name return the same object.

    Raises:
        KeyError: if ``model_name`` is not a key of MODEL_CONFIG.
    """
    # Fast path: the model was already loaded by an earlier call.
    if model_name in loaded_models:
        return loaded_models[model_name]

    print(f"Загрузка модели {model_name}...")
    repo_id = MODEL_CONFIG[model_name]["repo"]
    net = BiRefNet.from_pretrained(repo_id)
    net.to(device).eval()
    if device == 'cuda':
        # Half precision on GPU; extract_object casts its input to match.
        net.half()
    loaded_models[model_name] = net
    return net

def extract_object(image, model_name):
    """Cut the foreground object out of ``image`` using the selected BiRefNet model.

    Args:
        image: PIL image from the Gradio input, or None if nothing was uploaded.
        model_name: Key of MODEL_CONFIG; selects the model and its input size.

    Returns:
        Path of a temporary PNG file containing the image as RGBA, with the
        predicted mask (resized back to the original size) as its alpha channel.

    Raises:
        gr.Error: if no image was provided.
        KeyError: if ``model_name`` is not a key of MODEL_CONFIG.
    """
    if image is None:
        # Gradio passes None when the form is submitted without an image.
        raise gr.Error("Пожалуйста, загрузите изображение.")

    model = load_model(model_name)
    config = MODEL_CONFIG[model_name]

    # Normalize() below uses 3-channel mean/std, so RGBA, grayscale (L) or
    # palette (P) uploads must be converted to RGB first; otherwise the tensor
    # shapes do not match and inference fails. convert() returns a new image,
    # so the caller's image is not modified.
    result = image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize(config["image_size"]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    input_tensor = transform(result).unsqueeze(0).to(device)  # add a batch dimension
    if device == 'cuda':
        # load_model() converts the weights to fp16 on CUDA; the input must match.
        input_tensor = input_tensor.half()

    with torch.no_grad():
        # The last element of the model output is used as the final prediction.
        preds = model(input_tensor)[-1].sigmoid().cpu()

    # Cast to float32 so ToPILImage treats fp16 (CUDA) and fp32 (CPU) output alike.
    # NOTE(review): assumes preds is (batch, 1, H, W) so squeeze() yields (H, W) — confirm.
    mask = transforms.ToPILImage()(preds[0].squeeze().float())
    mask = mask.resize(result.size)

    result.putalpha(mask)

    # delete=False keeps the file after the handle is closed, because only the
    # path is returned to Gradio. NOTE(review): these files are never removed here.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
        result.save(tmp_file, format="PNG")
        temp_filepath = tmp_file.name

    return temp_filepath

# Build the Gradio UI: an image upload and a model selector in, a PNG file path out.
# The model list in the description is generated from MODEL_CONFIG so it always
# matches the models actually offered in the dropdown.
_MODEL_LIST = ", ".join(
    f"{name} ({cfg['image_size'][0]}x{cfg['image_size'][1]})"
    for name, cfg in MODEL_CONFIG.items()
)

iface = gr.Interface(
    fn=extract_object,
    inputs=[
        gr.Image(type="pil", label="Входное изображение"),
        gr.Dropdown(
            choices=list(MODEL_CONFIG.keys()),
            value="BiRefNet_HR",
            label="Выбор модели",
        ),
    ],
    outputs=gr.Image(type="filepath", label="Результат"),  # extract_object returns a file path
    title="BiRefNet - Интерактивная сегментация",
    description=(
        "Выберите модель и загрузите изображение для сегментации. "
        f"Доступные модели: {_MODEL_LIST}"
    ),
    allow_flagging="never",  # Flagging is not needed here and was causing issues
)

if __name__ == "__main__":
    # Start the Gradio app when run as a script; share=True is Gradio's
    # option for also requesting a public share link.
    iface.launch(
        share=True,
    )