# habulaj's picture
# Update app.py
# 7ccc668 verified
import gradio as gr
import spaces
from transformers import AutoModelForImageSegmentation
import torch
from torchvision import transforms
from typing import Union, Tuple
from PIL import Image
import io
import base64
import os
import uuid
from typing import Optional
import uvicorn
import requests
# FastAPI imports para endpoints customizados
from fastapi import UploadFile, File, Form
from fastapi.responses import JSONResponse, FileResponse
# Image loader used throughout this app (replacement for the `loadimg` package).
def load_img(image_input: Union[str, Image.Image], output_type: str = "pil") -> Image.Image:
    """
    Load an image from a URL, a file path or a NumPy array, or return a PIL image as-is.

    Replacement for ``loadimg.load_img``. Only PIL output is produced;
    ``output_type`` is accepted so existing call sites keep working, but it is
    otherwise ignored.

    Args:
        image_input: One of
            - a ``PIL.Image.Image`` (returned unchanged),
            - a NumPy array (e.g. what ``gr.Image`` passes with its default
              ``type="numpy"``), converted with ``Image.fromarray``,
            - an ``http://`` / ``https://`` URL, downloaded with a 30 s timeout,
            - a path (``str`` or ``os.PathLike``) to an existing file.
        output_type: Ignored; kept for call-site compatibility.

    Returns:
        PIL.Image.Image: The loaded image. Pillow opens URL/file images lazily;
        pixel data is decoded on first use (e.g. ``.convert``).

    Raises:
        ValueError: If a string/path is neither a URL nor an existing file, or
            the input type is not supported.
        requests.RequestException: If the download fails (``raise_for_status``
            raises ``HTTPError`` on error status codes).
    """
    import numpy as np  # numpy is always present alongside torch/gradio

    if isinstance(image_input, Image.Image):
        return image_input

    # gr.Image hands the function a NumPy array unless type="pil"/"filepath" is set.
    if isinstance(image_input, np.ndarray):
        return Image.fromarray(image_input)

    # Accept pathlib.Path and other path-like objects as well as plain strings.
    if isinstance(image_input, os.PathLike):
        image_input = os.fspath(image_input)

    if isinstance(image_input, str):
        if image_input.startswith(("http://", "https://")):
            # NOTE(review): the URL can come straight from API callers, so the
            # server will fetch arbitrary addresses (SSRF risk on a public Space).
            response = requests.get(image_input, timeout=30)
            response.raise_for_status()
            return Image.open(io.BytesIO(response.content))
        # isfile (not exists): a directory should produce the ValueError below,
        # not an IsADirectoryError from Image.open.
        if os.path.isfile(image_input):
            return Image.open(image_input)
        raise ValueError(f"Não foi possível carregar a imagem: {image_input}")

    raise ValueError(f"Tipo de entrada não suportado: {type(image_input)}")
# Allow PyTorch's "high" float32 matmul precision mode, which permits faster
# reduced-precision kernels (e.g. TF32) where the hardware supports them.
torch.set_float32_matmul_precision("high")

# Load the BiRefNet segmentation model once at import time (as in the original
# app this file is based on) and move it to the GPU.
birefnet = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True)
birefnet.to("cuda")

# Preprocessing applied before inference: resize to a fixed 1024x1024, convert
# to a float tensor, then normalize each channel with the ImageNet mean/std.
transform_image = transforms.Compose([
    transforms.Resize((1024, 1024)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])
def fn(image: Union[Image.Image, str]) -> Tuple[Image.Image, Image.Image]:
    """
    Remove the background from an image and return (original, result) for a before/after view.

    The input is loaded with ``load_img`` and converted to RGB. An untouched
    copy is kept, and the RGB image is passed to ``process``, which attaches
    the predicted mask as an alpha channel.

    Args:
        image (PIL.Image or str): A PIL image, or a file path / URL string.

    Returns:
        tuple:
            - the original RGB image, unchanged;
            - the same image with the background made transparent.
    """
    rgb = load_img(image, output_type="pil").convert("RGB")
    # `process` writes the alpha channel into its argument, so snapshot the
    # RGB pixels first for the "before" side of the image slider.
    before = rgb.copy()
    after = process(rgb)
    return before, after
@spaces.GPU
def process(image: Image.Image) -> Image.Image:
    """
    Make the background of ``image`` transparent using the BiRefNet model.

    Steps: preprocess with ``transform_image``, run ``birefnet`` on the GPU
    without gradient tracking, take the last output through a sigmoid to get a
    mask, resize that mask back to the input size, and install it as the
    image's alpha channel.

    Note: the input image object is modified in place (``putalpha``) and then
    returned.

    Args:
        image (PIL.Image): RGB input image.

    Returns:
        PIL.Image: ``image`` with the predicted mask as its alpha channel.
    """
    original_size = image.size
    batch = transform_image(image).unsqueeze(0).to("cuda")

    with torch.no_grad():
        outputs = birefnet(batch)
        probabilities = outputs[-1].sigmoid().cpu()

    # First (only) item of the batch; drop singleton dims before converting.
    mask_tensor = probabilities[0].squeeze()
    mask = transforms.ToPILImage()(mask_tensor).resize(original_size)
    image.putalpha(mask)
    return image
def process_file(f: str) -> str:
    """
    Remove the background from an image file and save the result as a transparent PNG.

    The output is written next to the input, with the file extension replaced
    by ``.png``. If the input already is a ``.png``, it is overwritten.

    Args:
        f (str): Filepath of the image to process.

    Returns:
        str: Path of the saved PNG with the background removed.
    """
    # os.path.splitext only inspects the final path component. The former
    # f.rsplit(".", 1) cut into a dotted directory name when the file had no
    # extension ("/tmp/gradio.x/img" -> "/tmp/gradio.png") and turned a dotfile
    # like "dir/.hidden" into "dir/.png".
    name_path = os.path.splitext(f)[0] + ".png"
    im = load_img(f, output_type="pil").convert("RGB")
    transparent = process(im)
    transparent.save(name_path)
    return name_path
# Directory for processed images. It lives inside Gradio's temporary-file
# directory so results can be linked through Gradio's file route. Gradio takes
# its temp dir from the GRADIO_TEMP_DIR environment variable, so honor it here
# too; the fallback is the previous fixed location "/tmp/gradio".
GRADIO_TMP_DIR = os.environ.get("GRADIO_TEMP_DIR") or "/tmp/gradio"
os.makedirs(GRADIO_TMP_DIR, exist_ok=True)
OUTPUT_DIR = os.path.join(GRADIO_TMP_DIR, "output_images")
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Public base URL used to build links to processed images.
# Priority: explicit BASE_URL env var > SPACE_HOST (provided by Hugging Face
# Spaces) > a host derived from SPACE_ID ("owner/name").
SPACE_ID = os.getenv("SPACE_ID", "habulaj-background-removal")
# Make SPACE_ID usable in a host name: no slashes, no surrounding whitespace.
SPACE_ID = SPACE_ID.replace("/", "-").strip()
# HF Space subdomains are lowercase with "_" and "." mapped to "-"; replacing
# only "/" would produce a wrong host for names containing those characters.
_space_host = (os.getenv("SPACE_HOST") or "").strip() or (
    SPACE_ID.lower().replace("_", "-").replace(".", "-") + ".hf.space"
)
BASE_URL = os.getenv("BASE_URL", f"https://{_space_host}")
# Drop trailing slashes so paths can be appended as f"{BASE_URL}/...".
BASE_URL = BASE_URL.rstrip("/")
# ========== Gradio Setup ==========
# Before/after sliders fed by `fn`, which returns (original, processed) PIL images.
slider1 = gr.ImageSlider(label="Processed Image", type="pil", format="png")
slider2 = gr.ImageSlider(label="Processed Image from URL", type="pil", format="png")
# `fn` is documented to take a PIL image (or a path/URL string). gr.Image
# defaults to type="numpy", so request PIL explicitly to make the upload
# reach `load_img` as a PIL image.
image_upload = gr.Image(label="Upload an image", type="pil")
# `process_file` works on a path on disk.
image_file_upload = gr.Image(label="Upload an image", type="filepath")
url_input = gr.Textbox(label="Paste an image URL")
output_file = gr.File(label="Output PNG File")
# Example inputs for the demo tabs. The bundled "butterfly.jpg" is optional:
# if it is missing or cannot be opened, the image tab simply has no examples.
# (`chameleon` is a legacy name; it holds butterfly.jpg.)
url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
try:
    chameleon = load_img("butterfly.jpg", output_type="pil")
except (OSError, ValueError):
    # ValueError: load_img found no such file; OSError: Pillow could not open
    # or identify it. Anything else (including Ctrl-C) should surface.
    chameleon = None

# Stand-alone Interfaces, one per input mode.
# NOTE(review): the Blocks app below rebuilds these same Interfaces inside its
# tabs, and only that Blocks app is launched; tab1/tab2/tab3 look unused in
# this file and are kept only in case something imports them.
tab1 = gr.Interface(fn, inputs=image_upload, outputs=slider1, examples=[chameleon] if chameleon is not None else None, api_name="image")
tab2 = gr.Interface(fn, inputs=url_input, outputs=slider2, examples=[url_example], api_name="text")
tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=["butterfly.jpg"] if os.path.exists("butterfly.jpg") else None, api_name="png")
# Gradio-callable API: remove the background and return a URL to the result.
# Accepts a local file path (or file-like object), Base64 data, or an image URL.
def remove_background_api(image_url: Optional[str] = None, image_base64: Optional[str] = None, image_file_path: Optional[str] = None):
    """Remove the background from one input image and return a link to the PNG.

    The first non-empty source is used, in this order: ``image_file_path``,
    ``image_base64``, ``image_url``. Empty strings (what an empty Gradio
    Textbox sends) count as "not provided".

    Args:
        image_url: ``http(s)://`` URL of the image, fetched with ``load_img``.
        image_base64: Base64-encoded image bytes, optionally wrapped in a
            ``data:image/...;base64,`` data URI.
        image_file_path: Path to an existing image file, or a file-like object
            with a ``name`` attribute that ``Image.open`` can read.

    Returns:
        dict: ``{"success": True, "image_url": <url>, "message": <text>}`` on
        success, otherwise ``{"success": False, "error": <text>}``. Exceptions
        are printed to stdout and never propagate to the caller.
    """
    try:
        if image_file_path:
            is_existing_path = isinstance(image_file_path, str) and os.path.isfile(image_file_path)
            is_file_like = hasattr(image_file_path, "name")  # e.g. an uploaded-file wrapper
            if not (is_existing_path or is_file_like):
                # Without this check `image` stayed None and the call failed
                # later inside `process` with an unrelated AttributeError.
                return {"success": False, "error": "Arquivo de imagem não encontrado."}
            image = Image.open(image_file_path).convert("RGB")
        elif image_base64:
            if image_base64.startswith("data:image"):
                # Drop the "data:image/...;base64," prefix of a data URI.
                image_base64 = image_base64.split(",", 1)[1]
            image_data = base64.b64decode(image_base64)
            image = Image.open(io.BytesIO(image_data)).convert("RGB")
        elif image_url:
            # NOTE(review): fetches an arbitrary caller-supplied URL from the
            # server; consider restricting hosts on a public Space (SSRF).
            image = load_img(image_url, output_type="pil").convert("RGB")
        else:
            return {"success": False, "error": "Nenhuma imagem fornecida."}

        processed_image = process(image)

        # Save under OUTPUT_DIR with a random, collision-free name.
        # NOTE(review): nothing in this file deletes these files, so the
        # directory grows without bound.
        image_id = str(uuid.uuid4())
        output_path = os.path.join(OUTPUT_DIR, f"{image_id}.png")
        processed_image.save(output_path, "PNG")

        # Assumes Gradio serves files from its temp directory via
        # /gradio_api/file=<absolute path>.
        image_url_result = f"{BASE_URL}/gradio_api/file={output_path}"
        return {
            "success": True,
            "image_url": image_url_result,
            "message": "Background removido com sucesso"
        }
    except Exception as e:
        # API boundary: report every failure as a JSON error instead of raising.
        print(f"Erro ao processar imagem: {str(e)}")
        import traceback
        traceback.print_exc()
        return {"success": False, "error": str(e)}
# Build the UI as a Blocks app so custom endpoints can be attached to it.
with gr.Blocks(title="Background Removal Tool") as blocks:
    # One tab per input mode. Each tab wraps an Interface around `fn` or
    # `process_file`, reusing the module-level components that were also passed
    # to tab1/tab2/tab3 above.
    # NOTE(review): reusing the same component objects in two Interfaces —
    # confirm this renders correctly with the installed Gradio version.
    with gr.Tabs():
        with gr.Tab("Image Upload"):
            gr.Interface(fn, inputs=image_upload, outputs=slider1, examples=[chameleon] if chameleon else None, api_name="image")
        with gr.Tab("URL Input"):
            gr.Interface(fn, inputs=url_input, outputs=slider2, examples=[url_example], api_name="text")
        with gr.Tab("File Output"):
            gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=["butterfly.jpg"] if os.path.exists("butterfly.jpg") else None, api_name="png")
        with gr.Tab("API"):
            # Manual test form for remove_background_api. Only the URL and
            # Base64 inputs are wired; its image_file_path argument keeps the
            # default None here.
            api_url_input = gr.Textbox(label="Image URL", placeholder="https://example.com/image.jpg")
            api_base64_input = gr.Textbox(label="Image Base64", placeholder="data:image/png;base64,...", lines=3)
            api_output = gr.JSON(label="Resultado")
            api_btn = gr.Button("Processar")
            api_btn.click(
                fn=remove_background_api,
                inputs=[api_url_input, api_base64_input],
                outputs=api_output
            )
            # Interface exposing remove_background_api (URL, Base64 and file
            # inputs) as the named API endpoint "remove_background".
            # NOTE(review): the exact HTTP route for a named endpoint depends on
            # the Gradio version — verify before documenting it for clients.
            gr.Interface(
                fn=remove_background_api,
                inputs=[
                    gr.Textbox(label="Image URL", placeholder="https://example.com/image.jpg"),
                    gr.Textbox(label="Image Base64", placeholder="data:image/png;base64,...", lines=3),
                    gr.File(label="Image File", type="filepath")
                ],
                outputs=gr.JSON(),
                api_name="remove_background",
                title="Remove Background API"
            )
# Alias used below to register the custom FastAPI routes and to launch the app.
demo = blocks
# Custom FastAPI routes attached directly to the Gradio app object (outside the
# `with gr.Blocks()` context), intended to be registered before launch().
# NOTE(review): this assumes `demo.app` exists before `demo.launch()` and that
# launch() keeps serving that same app instance — confirm for the installed
# Gradio version; `gr.mount_gradio_app` is the documented alternative.
@demo.app.post("/remove-background")
async def remove_background_fastapi(
    image_url: Optional[str] = Form(None),
    image_base64: Optional[str] = Form(None),
    image_file: Optional[UploadFile] = File(None)
):
    """Remove the background from an uploaded, Base64-encoded or linked image.

    Multipart form fields, first non-empty one wins: ``image_file``,
    ``image_base64`` (optionally a ``data:image/...;base64,`` URI), ``image_url``.

    Responses:
        200: ``{"success": True, "image_url": ..., "message": ...}``.
        400: no image given, or the given image could not be decoded/fetched.
        500: any other failure while processing or saving.

    Blocking work (image decoding, the URL download, model inference and PNG
    encoding) runs in a worker thread so it does not stall the event loop that
    also serves the Gradio UI.
    """
    # fastapi re-exports Starlette's helper; fastapi is already used by this file.
    from fastapi.concurrency import run_in_threadpool

    def _decode_image(file_bytes, b64_text, url):
        """Return the request's image as an RGB PIL image, or None if none was given."""
        if file_bytes is not None:
            return Image.open(io.BytesIO(file_bytes)).convert("RGB")
        if b64_text:
            if b64_text.startswith("data:image"):
                # Drop the "data:image/...;base64," prefix of a data URI.
                b64_text = b64_text.split(",", 1)[1]
            return Image.open(io.BytesIO(base64.b64decode(b64_text))).convert("RGB")
        if url:
            # NOTE(review): fetches an arbitrary caller-supplied URL (SSRF risk).
            return load_img(url, output_type="pil").convert("RGB")
        return None

    def _process_and_save(image, output_path):
        """Run background removal and write the result as a PNG to output_path."""
        process(image).save(output_path, "PNG")

    try:
        file_bytes = await image_file.read() if image_file else None
        try:
            image = await run_in_threadpool(_decode_image, file_bytes, image_base64, image_url)
        except (ValueError, OSError, IndexError) as e:
            # Problems with the client's input: invalid Base64 (binascii.Error
            # is a ValueError), a malformed data URI (IndexError), unreadable
            # image data (Pillow raises OSError subclasses), load_img's
            # ValueError, or a failed download (requests' exceptions derive
            # from OSError).
            print(f"Erro ao processar imagem: {str(e)}")
            return JSONResponse(
                status_code=400,
                content={"success": False, "error": str(e)}
            )
        if image is None:
            return JSONResponse(
                status_code=400,
                content={"success": False, "error": "Nenhuma imagem fornecida."}
            )

        # Save under OUTPUT_DIR with a random, collision-free name.
        image_id = str(uuid.uuid4())
        output_path = os.path.join(OUTPUT_DIR, f"{image_id}.png")
        await run_in_threadpool(_process_and_save, image, output_path)

        # Assumes Gradio serves files from its temp directory via
        # /gradio_api/file=<absolute path>.
        image_url_result = f"{BASE_URL}/gradio_api/file={output_path}"
        return JSONResponse(content={
            "success": True,
            "image_url": image_url_result,
            "message": "Background removido com sucesso"
        })
    except Exception as e:
        # Unexpected server-side failure: log it and report a 500.
        print(f"Erro ao processar imagem: {str(e)}")
        import traceback
        traceback.print_exc()
        return JSONResponse(
            status_code=500,
            content={"success": False, "error": str(e)}
        )
@demo.app.get("/images/{image_id}")
async def get_image_fastapi(image_id: str):
    """
    Return a previously processed PNG from OUTPUT_DIR by its ID.

    This is a fallback route for fetching results directly. When no
    ``<image_id>.png`` exists, it responds with 404 and a JSON error body.
    """
    candidate = os.path.join(OUTPUT_DIR, f"{image_id}.png")
    if os.path.exists(candidate):
        return FileResponse(candidate, media_type="image/png")
    return JSONResponse(status_code=404, content={"error": "Imagem não encontrada"})
if __name__ == "__main__":
    # Launch with server-side rendering turned off (it caused problems with
    # the custom FastAPI routes), and with show_error enabled.
    demo.launch(ssr_mode=False, show_error=True)