# vandalizer-backend / tasks.py
# Uploaded by zeyadcode: "Sync from GitHub via hub-sync" (commit 9fab1c1, verified), 6.31 kB.
# %%
from celery import Celery
import config as config
import torch
from utils import get_prompt_list, plot_groundingdino_boxes, save_visual_mask, save_visual_mask_from_binary, blur_img_with_mask
from PIL import Image
from transformers import BatchFeature
import numpy as np
import json
from typing import Any
from services import model_manager
# Shared registry of model objects, one slot per role. Every slot starts out
# as None; the registry is passed to the services.model_manager getters by the
# task helpers in this module (presumably they populate a slot on first use;
# confirm in model_manager).
MODELS: dict[str, Any] = dict.fromkeys(
    (
        "detector",
        "detector_processor",
        "segmentor",
        "inpaintor",
        "remover",
    ),
    None,
)
# Celery application that owns every task defined in this module. The broker
# and result-backend URLs come from the project config module.
celery_app = Celery(
    "worker",
    backend=config.CELERY_RESULT_BACKEND,
    broker=config.CELERY_BROKER_URL,
)

# Worker settings, applied as a single mapping right after construction.
celery_app.conf.update(
    {
        # Keep retrying the broker connection while the worker starts up.
        "broker_connection_retry_on_startup": True,
        # Stored task results expire after the job TTL from config.
        "result_expires": config.JOB_TTL_SECONDS,
        # Report a "started" state once a worker picks a task up.
        "task_track_started": True,
    }
)
def _detect_objects(job_id: str, prompt: str) -> dict:
    """Run text-prompted object detection on a job's input image.

    The prompt is turned into a list of labels by ``get_prompt_list``. The
    detections are written as JSON to ``config.DETECTOR_OUT_PATH`` inside the
    job directory and are also returned to the caller.

    Args:
        job_id: Name of the job directory under ``config.UPLOAD_DIR``.
        prompt: Text prompt describing what to look for.

    Returns:
        The post-processed result for the image, with ``scores``, ``labels``
        and ``boxes`` converted to plain lists (box coordinates cast to int)
        and ``text_labels`` set to the prompt entry for each detected label.
    """
    prompt_list = get_prompt_list(prompt)
    job_path = config.UPLOAD_DIR / job_id
    detection_path = job_path / config.DETECTOR_OUT_PATH
    # Remove any previous result before doing work, so a failed run cannot
    # leave stale detections behind. missing_ok avoids the exists()/unlink()
    # check-then-act race.
    detection_path.unlink(missing_ok=True)

    # The context manager closes the underlying file as soon as the processor
    # has consumed the pixels, instead of leaving the handle to the GC.
    with Image.open(job_path / config.INPUT_IMG_NAME) as img:  # size is W x H
        target_size = img.size[::-1]  # reversed to H x W for post-processing
        model = model_manager.get_detector_model(MODELS)
        processor = model_manager.get_detector_processor(MODELS)
        inputs = processor(images=img, text=[prompt_list], return_tensors="pt")

    # The model is called with a plain dict rather than the processor's own
    # return object.
    outputs = model(dict(inputs.items()))
    # Re-wrap the raw outputs as tensors in the container the processor's
    # post-processing step reads from.
    outputs = BatchFeature(
        {
            "logits": torch.tensor(outputs["logits"]),
            "pred_boxes": torch.tensor(outputs["pred_boxes"]),
        }
    )
    results = processor.post_process_grounded_object_detection(
        outputs,
        threshold=0.1,
        target_sizes=[target_size],
    )

    # A single image was submitted, so only the first entry is relevant.
    result = results[0]
    # Convert tensors to plain lists so the result is JSON-serialisable.
    result["scores"] = result["scores"].tolist()
    result["labels"] = result["labels"].tolist()
    result["boxes"] = result["boxes"].int().tolist()
    # Labels are indices into the prompt list; map them back to their text.
    result["text_labels"] = [prompt_list[label] for label in result["labels"]]

    with open(detection_path, "w", encoding="utf-8") as f:
        json.dump(result, f, indent=2)
    return result
def _save_empty_mask(job_path, image_size):
    """Write an all-zero mask for a job and return the array.

    Args:
        job_path: Job directory the mask files are written into.
        image_size: ``(width, height)`` of the image the mask belongs to.

    Returns:
        The ``uint8`` zero array of shape ``(height, width)`` that was saved
        to ``config.SEGMENTOR_OUT_BIN_PATH`` and rendered to
        ``config.SEGMENTOR_OUT_VISUAL_PATH``.
    """
    w, h = image_size
    # NumPy arrays are row-major, hence (height, width).
    blank = np.zeros((h, w), dtype=np.uint8)

    binary_image = Image.fromarray(blank, mode="L")
    binary_image.save(job_path / config.SEGMENTOR_OUT_BIN_PATH)

    save_visual_mask_from_binary(blank, job_path / config.SEGMENTOR_OUT_VISUAL_PATH)
    return blank
def _segment_objects(job_id: str, bboxes=None, points=None, point_labels=None):
    """Segment the prompted regions of a job's input image into one mask.

    All masks returned by the segmentor are merged (logical OR) into a single
    0/255 ``uint8`` mask, which is saved to ``config.SEGMENTOR_OUT_BIN_PATH``
    and rendered to ``config.SEGMENTOR_OUT_VISUAL_PATH`` in the job directory.

    Args:
        job_id: Name of the job directory under ``config.UPLOAD_DIR``.
        bboxes: Box prompts forwarded to the segmentor, or ``None``.
        points: Point prompts forwarded to the segmentor, or ``None``.
        point_labels: Labels for ``points``, forwarded as ``labels``.

    Returns:
        ``False`` when neither boxes nor points were supplied (an empty mask
        is written and the model is not run); otherwise the first result
        object returned by the segmentor.
    """
    job_path = config.UPLOAD_DIR / job_id
    # The context manager closes the image file once inference is done.
    with Image.open(job_path / config.INPUT_IMG_NAME) as img:
        image_size = img.size
        if not bboxes and not points:
            # Nothing to segment: still produce mask files for later steps.
            _save_empty_mask(job_path, image_size)
            return False

        model = model_manager.get_segmentor_model(MODELS)
        results = model(img, bboxes=bboxes, points=points, labels=point_labels)

    # The segmentor may report "no masks" as None rather than as an empty
    # tensor; treat both the same instead of failing on `.data`.
    masks = results[0].masks
    mask_data = None if masks is None else masks.data.cpu().numpy()  # n_masks x H x W
    if mask_data is None or mask_data.size == 0:
        combined_mask = _save_empty_mask(job_path, image_size)
    else:
        # A pixel is set when any individual mask covers it.
        combined_mask = np.any(mask_data, axis=0).astype(np.uint8) * 255

    save_bin_path = job_path / config.SEGMENTOR_OUT_BIN_PATH
    save_visual_path = job_path / config.SEGMENTOR_OUT_VISUAL_PATH
    Image.fromarray(combined_mask).save(save_bin_path)
    save_visual_mask(combined_mask, save_visual_path)
    return results[0]
@celery_app.task
def detect_objects(job_id: str, prompt: str) -> dict:
    """Celery task: detect the prompted objects in a job's input image.

    Forwards to ``_detect_objects`` and returns its result unchanged.
    """
    detections = _detect_objects(job_id, prompt)
    return detections
@celery_app.task(bind=True)
def segment_objects(self, job_id: str, bboxes=None, points=None, point_labels=None):
    """Celery task: build the segmentation mask for a job.

    Forwards the prompts to ``_segment_objects``. When the task has no request
    id (presumably a direct, in-process call such as the debug cells at the
    bottom of this file), the helper's return value is passed through;
    otherwise the task simply returns ``True``.
    """
    outcome = _segment_objects(
        job_id,
        bboxes=bboxes,
        points=points,
        point_labels=point_labels,
    )
    has_request_id = self.request.id is not None
    return True if has_request_id else outcome
@celery_app.task(bind=True)
def generate_mask(self, job_id: str, prompt: str) -> dict:
    """Celery task: detect the prompted objects, then segment their boxes.

    Runs ``_detect_objects`` followed by ``_segment_objects`` on the detected
    boxes. When the task has no request id, the full detection result is
    returned; otherwise a compact summary with the number of boxes and their
    text labels is returned.
    """
    detections = _detect_objects(job_id, prompt)
    found_boxes = detections.get("boxes", [])
    _segment_objects(job_id, bboxes=found_boxes)

    if self.request.id is not None:
        summary = {
            "boxes": len(found_boxes),
            "labels": detections.get("text_labels", []),
        }
        return summary
    return detections
@celery_app.task(bind=True)
def inpaint(
    self,
    job_id: str,
    mode: str = "blur",
    positive_prompt: str = "",
    strength: float = 0.7,
    num_inference_steps: int = 4,
) -> bool:
    """Celery task: edit the masked region of a job's input image.

    Reads the input image and the binary mask from the job directory, applies
    the selected mode and saves the RGB result to ``config.INPAINTOR_OUT_PATH``.
    Any previous output is deleted first.

    Args:
        job_id: Name of the job directory under ``config.UPLOAD_DIR``.
        mode: ``"diffusion"`` (inpainting model), ``"remove"`` (removal
            model) or ``"blur"`` (``blur_img_with_mask``).
        positive_prompt: Prompt passed to the model in ``"diffusion"`` mode.
        strength: Strength passed to the model in ``"diffusion"`` mode.
        num_inference_steps: Step count passed to the model in
            ``"diffusion"`` mode.

    Returns:
        ``True`` when the task has a request id; the result image itself when
        it has none (as in the debug cells at the bottom of this file).

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    job_path = config.UPLOAD_DIR / job_id
    save_path = job_path / config.INPAINTOR_OUT_PATH
    # Drop stale output up front; missing_ok avoids the exists()/unlink() race.
    save_path.unlink(missing_ok=True)

    # convert() returns a fully loaded copy, so the source files can be closed
    # straight away instead of staying open until garbage collection.
    with Image.open(job_path / config.INPUT_IMG_NAME) as src:
        orig_img = src.convert("RGB")
    with Image.open(job_path / config.SEGMENTOR_OUT_BIN_PATH) as src:
        mask_img = src.convert("L")

    if mode == "diffusion":
        model = model_manager.get_inpaintor_model(MODELS)
        # The model is fed a fixed square resolution; the mask is resized with
        # nearest-neighbour so it stays strictly binary.
        model_size = (config.INPAINTOR_IMAGE_SIZE, config.INPAINTOR_IMAGE_SIZE)
        generated = model(
            prompt=positive_prompt,
            image=orig_img.resize(model_size),
            mask_image=mask_img.resize(model_size, Image.Resampling.NEAREST),
            num_inference_steps=num_inference_steps,
            strength=strength,
            guidance_scale=0,
        ).images[0]
        # Scale the generated image back to the original dimensions.
        result = generated.resize(orig_img.size)
    elif mode == "remove":
        model = model_manager.get_removing_model(MODELS)
        result = model(orig_img, mask_img)
    elif mode == "blur":
        result = blur_img_with_mask(orig_img, mask_img)
    else:
        raise ValueError(f"Unsupported inpaint mode: {mode}")

    result.convert("RGB").save(save_path)
    if self.request.id is None:
        return result
    return True
# %%
# Manual debug run of the whole pipeline against the debug job from config.
# The tasks are invoked in-process here, not through a Celery worker.
if __name__ == "__main__":
    import matplotlib.pyplot as plt

    prompt = "head"
    img = Image.open(config.UPLOAD_DIR / config.DEBUG_JOB_ID / config.INPUT_IMG_NAME)
    # %%
    detection_res = detect_objects(job_id=config.DEBUG_JOB_ID, prompt=prompt)
    plot_groundingdino_boxes(img, detection_res)
    # %%
    segment_res = segment_objects.run(job_id=config.DEBUG_JOB_ID, bboxes=detection_res["boxes"])
    # %%
    inpainted_res = inpaint(
        job_id=config.DEBUG_JOB_ID,
        mode="diffusion",
        positive_prompt="red hat high quality",
        num_inference_steps=4,
    )  # type: ignore
    # %%
    # segment_objects returns False when there was nothing to segment (e.g. no
    # boxes were detected); there is no result object to plot in that case.
    if segment_res is False:
        print("No segmentation result to plot (no boxes were detected).")
    else:
        # Use a separate name so `img` keeps pointing at the input image when
        # earlier cells are re-run.
        overlay = segment_res.plot()
        plt.subplot(121)
        # Reverse the channel axis for display (presumably BGR -> RGB; confirm).
        plt.imshow(overlay[:, :, ::-1])
        plt.axis("off")
        plt.show()