# Web file-viewer residue (not part of the program): file size 3,550 bytes; id 5c33c53; line-number gutter 1-104.
import torch
from PIL import Image
import matplotlib.pyplot as plt
from torchvision import transforms
import gradio as gr
from models.birefnet import BiRefNet
import io
import tempfile # Импортируем tempfile
# Run on the GPU when PyTorch reports CUDA as available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Используем устройство:", device)

# Model registry: maps the name shown in the UI to the repository id passed
# to BiRefNet.from_pretrained() and the size images are resized to before
# inference.
MODEL_CONFIG = {
    "BiRefNet_HR": {"repo": "ZhengPeng7/BiRefNet_HR", "image_size": (2048, 2048)},
    "BiRefNet": {"repo": "ZhengPeng7/BiRefNet", "image_size": (1024, 1024)},
    "BiRefNet-matting": {"repo": "ZhengPeng7/BiRefNet-matting", "image_size": (1024, 1024)},
    "BiRefNet-portrait": {"repo": "ZhengPeng7/BiRefNet-portrait", "image_size": (1024, 1024)},
    "BiRefNet-HRSOD": {"repo": "ZhengPeng7/BiRefNet-HRSOD", "image_size": (1024, 1024)},
}

# Cache of already-loaded models, keyed by model name.
loaded_models = {}
def load_model(model_name):
    """Return the BiRefNet model registered under ``model_name``.

    The first request for a name loads the weights from the repository
    listed in ``MODEL_CONFIG``, moves the model to ``device``, switches it
    to eval mode and, on CUDA, converts it to half precision.  The model is
    then stored in ``loaded_models`` so later requests reuse it.

    Args:
        model_name: Key of ``MODEL_CONFIG``.

    Returns:
        The cached model instance.

    Raises:
        KeyError: If the model is not cached yet and ``model_name`` is not
            a key of ``MODEL_CONFIG``.
    """
    # Fast path: this model has already been prepared.
    if model_name in loaded_models:
        return loaded_models[model_name]

    print(f"Загрузка модели {model_name}...")
    repo_id = MODEL_CONFIG[model_name]["repo"]
    net = BiRefNet.from_pretrained(repo_id)
    net.to(device)
    net.eval()
    if device == 'cuda':
        net.half()  # FP16 on CUDA

    loaded_models[model_name] = net
    return loaded_models[model_name]
def extract_object(image, model_name):
    """Cut the foreground out of ``image`` using the selected BiRefNet model.

    The image is converted to RGB, resized to the model's configured input
    size, normalised with the fixed per-channel mean/std below and run
    through the model.  The last element of the model output is passed
    through a sigmoid, resized back to the original image size and used as
    the alpha channel of the result.

    Args:
        image: Input ``PIL.Image.Image`` in any mode, or ``None`` when no
            image was supplied.
        model_name: Key of ``MODEL_CONFIG`` selecting the model and its
            input resolution.

    Returns:
        Path to a temporary PNG file containing the RGBA result, or
        ``None`` when ``image`` is ``None``.

    Raises:
        KeyError: If an image was supplied and ``model_name`` is not a key
            of ``MODEL_CONFIG``.
    """
    if image is None:
        # Nothing to process; bail out before loading a model instead of
        # failing on None further down.
        return None

    model = load_model(model_name)
    config = MODEL_CONFIG[model_name]

    # Normalize below is given three channel statistics, so the tensor must
    # have exactly three channels: force RGB (RGBA, L, P, ... would fail).
    # convert() returns a new image, so the caller's image is left untouched
    # by putalpha() below.
    rgb_image = image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize(config["image_size"]),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_tensor = transform(rgb_image).unsqueeze(0).to(device)
    if device == 'cuda':
        # load_model() converts the model to FP16 on CUDA; match its dtype.
        input_tensor = input_tensor.half()

    with torch.no_grad():
        # Cast to float32 on the CPU so the mask conversion below does not
        # depend on half-precision CPU arithmetic.
        preds = model(input_tensor)[-1].sigmoid().cpu().float()

    mask = transforms.ToPILImage()(preds[0].squeeze())
    mask = mask.resize(rgb_image.size)
    rgb_image.putalpha(mask)

    # delete=False: the file has to outlive this function because only its
    # path is returned.  NOTE(review): nothing here removes the file later.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
        rgb_image.save(tmp_file, format="PNG")
        temp_filepath = tmp_file.name
    return temp_filepath
# Build the Gradio interface: an image input plus a model selector, and an
# image output that is fed the file path returned by extract_object().
iface = gr.Interface(
    fn=extract_object,
    inputs=[
        gr.Image(type="pil", label="Входное изображение"),
        gr.Dropdown(
            choices=list(MODEL_CONFIG.keys()),
            value="BiRefNet_HR",
            label="Выбор модели",
        ),
    ],
    outputs=gr.Image(type="filepath", label="Результат"),  # Output type is filepath
    title="BiRefNet - Интерактивная сегментация",
    description=(
        "Выберите модель и загрузите изображение для сегментации. "
        "Доступные модели: "
        # Derive the list from MODEL_CONFIG so the text always matches the
        # dropdown choices (the previous hard-coded text named a model that
        # is not in MODEL_CONFIG and omitted several that are).
        + ", ".join(
            f"{name} ({cfg['image_size'][0]}x{cfg['image_size'][1]})"
            for name, cfg in MODEL_CONFIG.items()
        )
    ),
    allow_flagging="never",  # Disable flagging, as it is not needed and causing issues
)
if __name__ == "__main__":
    # Start the web UI only when run as a script.  share=True asks Gradio to
    # also create a public share link.
    iface.launch(share=True)