"""Gradio demo: cut the foreground object out of an image with a BiRefNet model."""

import io
import tempfile

import gradio as gr
import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms

from models.birefnet import BiRefNet

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Используем устройство:", device)

# Model configuration: for every selectable model, the identifier passed to
# ``BiRefNet.from_pretrained`` and the size the input is resized to before
# inference.
MODEL_CONFIG = {
    "BiRefNet_HR": {
        "repo": "ZhengPeng7/BiRefNet_HR",
        "image_size": (2048, 2048),
    },
    "BiRefNet": {
        "repo": "ZhengPeng7/BiRefNet",
        "image_size": (1024, 1024),
    },
    "BiRefNet-matting": {
        "repo": "ZhengPeng7/BiRefNet-matting",
        "image_size": (1024, 1024),
    },
    "BiRefNet-portrait": {
        "repo": "ZhengPeng7/BiRefNet-portrait",
        "image_size": (1024, 1024),
    },
    "BiRefNet-HRSOD": {
        "repo": "ZhengPeng7/BiRefNet-HRSOD",
        "image_size": (1024, 1024),
    },
}

# Cache of already-loaded models, keyed by model name.
loaded_models = {}


def load_model(model_name):
    """Return the model registered under *model_name*, loading it on first use.

    The model is moved to the module-level ``device``, switched to eval mode
    and, on CUDA, converted to half precision. Loaded models are kept in
    ``loaded_models`` so later requests reuse them.

    Args:
        model_name: A key of ``MODEL_CONFIG``.

    Returns:
        The loaded model instance.

    Raises:
        KeyError: If *model_name* is not a key of ``MODEL_CONFIG``.
    """
    if model_name not in loaded_models:
        print(f"Загрузка модели {model_name}...")
        model = BiRefNet.from_pretrained(MODEL_CONFIG[model_name]["repo"])
        model.to(device).eval()
        if device == 'cuda':
            model.half()  # FP16 on CUDA
        loaded_models[model_name] = model
    return loaded_models[model_name]


def _describe_models():
    """Return a "name (HxW)" list of every entry in ``MODEL_CONFIG``."""
    return ", ".join(
        f"{name} ({cfg['image_size'][0]}x{cfg['image_size'][1]})"
        for name, cfg in MODEL_CONFIG.items()
    )


def extract_object(image, model_name):
    """Predict a foreground mask for *image* and save the cut-out as a PNG.

    Args:
        image: Input ``PIL.Image.Image`` (any mode), or ``None`` when the
            user submitted the form without uploading an image.
        model_name: A key of ``MODEL_CONFIG`` selecting the model to run.

    Returns:
        Path of a temporary PNG file holding the RGB image with the predicted
        mask as its alpha channel. The file is created with ``delete=False``
        and is not removed by this function.

    Raises:
        gr.Error: If *image* is ``None``.
        KeyError: If *model_name* is not a key of ``MODEL_CONFIG``.
    """
    if image is None:
        # Without this guard the request dies on ``None.size`` further down
        # and the user only sees a generic error.
        raise gr.Error("Сначала загрузите изображение.")

    model = load_model(model_name)
    config = MODEL_CONFIG[model_name]

    # Normalize below is given three-channel statistics, so RGBA / grayscale /
    # palette inputs must be brought to RGB first. convert() returns a new
    # image, leaving the caller's object untouched.
    rgb_image = image.convert("RGB")

    transform = transforms.Compose([
        transforms.Resize(config["image_size"]),
        transforms.ToTensor(),
        # Standard ImageNet channel means / standard deviations.
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    input_tensor = transform(rgb_image).unsqueeze(0).to(device)
    if device == 'cuda':
        # The model was converted to FP16 in load_model; the input must match.
        input_tensor = input_tensor.half()

    with torch.no_grad():
        # Take the last model output; cast to FP32 so the CPU-side PIL
        # conversion never operates on half-precision tensors.
        preds = model(input_tensor)[-1].sigmoid().float().cpu()

    mask = transforms.ToPILImage()(preds[0].squeeze())
    mask = mask.resize(rgb_image.size)  # back to the original resolution

    # rgb_image is already a private copy, so it can be modified in place.
    rgb_image.putalpha(mask)

    # delete=False keeps the file on disk after the ``with`` block so that
    # Gradio can read it from the returned path.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_file:
        # Pass the format explicitly instead of relying on Pillow guessing it
        # from the file object's name.
        rgb_image.save(tmp_file, format="PNG")
        temp_filepath = tmp_file.name

    return temp_filepath


# Build the interface with a model selector.
iface = gr.Interface(
    fn=extract_object,
    inputs=[
        gr.Image(type="pil", label="Входное изображение"),
        gr.Dropdown(
            choices=list(MODEL_CONFIG.keys()),
            value="BiRefNet_HR",
            label="Выбор модели",
        ),
    ],
    # extract_object returns a path to the PNG it wrote.
    outputs=gr.Image(type="filepath", label="Результат"),
    title="BiRefNet - Интерактивная сегментация",
    description=(
        "Выберите модель и загрузите изображение для сегментации. "
        # Generated from MODEL_CONFIG so the text cannot drift from the
        # dropdown choices.
        f"Доступные модели: {_describe_models()}"
    ),
    allow_flagging="never",  # flagging is not needed and was causing issues
)

if __name__ == "__main__":
    iface.launch(share=True)