vandalizer-backend / services /model_manager.py
zeyadcode's picture
Sync from GitHub via hub-sync
e9c3e3f verified
Raw
History Blame Contribute Delete
2.2 kB
from transformers import AutoProcessor
from ultralytics import SAM
import openvino as ov
import config as config
from optimum.intel import OVStableDiffusionXLInpaintPipeline
from simple_lama_inpainting import SimpleLama
def get_detector_model(MODELS: dict):
    """Return the OpenVINO-compiled detector, building it on first use.

    The model graph is read from ``config.DETECTOR_MODEL_PATH`` and compiled for
    the ``"CPU"`` device. The compiled model is stored in ``MODELS["detector"]``
    so later calls reuse it.

    Args:
        MODELS: Shared model cache. The ``"detector"`` key must be present
            (a missing key raises ``KeyError``). It holds ``None`` until the
            detector has been compiled.

    Returns:
        The compiled detector held in ``MODELS["detector"]``.
    """
    compiled = MODELS["detector"]
    if compiled is not None:
        return compiled

    ov_core = ov.Core()
    compiled = ov_core.compile_model(
        ov_core.read_model(config.DETECTOR_MODEL_PATH),
        "CPU",
    )
    MODELS["detector"] = compiled
    return compiled
def get_detector_processor(MODELS: dict):
    """Return the detector's pre/post-processor, creating it lazily.

    Args:
        MODELS: Shared model cache. The ``"detector_processor"`` key must be
            present (a missing key raises ``KeyError``). It holds ``None``
            until the processor is first built.

    Returns:
        The processor created by ``AutoProcessor.from_pretrained`` for
        ``config.DETECTOR_MODEL_NAME``, with the fast implementation requested
        (``use_fast=True``).
    """
    processor = MODELS["detector_processor"]
    if processor is not None:
        return processor

    processor = AutoProcessor.from_pretrained(config.DETECTOR_MODEL_NAME, use_fast=True)
    MODELS["detector_processor"] = processor
    return processor
def get_segmentor_model(MODELS: dict):
    """Return the Ultralytics SAM segmentor, instantiating it lazily.

    Args:
        MODELS: Shared model cache. The ``"segmentor"`` key must be present
            (a missing key raises ``KeyError``). It holds ``None`` until the
            segmentor is first built.

    Returns:
        The ``SAM`` instance constructed from ``config.SEGMENTOR_MODEL_NAME``
        and cached in ``MODELS["segmentor"]``.
    """
    segmentor = MODELS["segmentor"]
    if segmentor is not None:
        return segmentor

    segmentor = SAM(config.SEGMENTOR_MODEL_NAME)
    MODELS["segmentor"] = segmentor
    return segmentor
def get_inpaintor_model(MODELS: dict):
    """Return the OpenVINO SDXL inpainting pipeline, loading it on first use.

    When ``config.INPAINTOR_MODEL_PATH`` exists, the pipeline is loaded from
    that local directory as-is. Otherwise it is loaded from
    ``config.INPAINTOR_MODEL_NAME`` (presumably a Hugging Face Hub id) with
    ``export=True``, so optimum-intel converts it to OpenVINO while loading.

    Args:
        MODELS: Shared model cache. The ``"inpaintor"`` key must be present
            (a missing key raises ``KeyError``). It holds ``None`` until the
            pipeline has been loaded.

    Returns:
        The compiled ``OVStableDiffusionXLInpaintPipeline`` cached in
        ``MODELS["inpaintor"]``.
    """
    if MODELS["inpaintor"] is None:
        # Decide the source and the export flag from the same check. Comparing
        # the chosen source back against INPAINTOR_MODEL_NAME (a Path vs a
        # name) was fragile and obscured the intent.
        has_local_model = config.INPAINTOR_MODEL_PATH.exists()
        model_source = (
            config.INPAINTOR_MODEL_PATH if has_local_model else config.INPAINTOR_MODEL_NAME
        )

        kwargs = {"device": config.INPAINTOR_DEVICE}
        if not has_local_model:
            # Only the non-local checkpoint needs converting to OpenVINO.
            kwargs["export"] = True
            # NOTE(review): this block does not save the converted model to
            # INPAINTOR_MODEL_PATH, so the export presumably repeats on every
            # cold start. Consider pipe.save_pretrained(INPAINTOR_MODEL_PATH)
            # after loading; confirm first.

        pipe = OVStableDiffusionXLInpaintPipeline.from_pretrained(
            str(model_source),
            **kwargs,
        )

        # The pipeline is intentionally NOT reshaped to a fixed
        # INPAINTOR_IMAGE_SIZE x INPAINTOR_IMAGE_SIZE input. This is presumably
        # so that images of other sizes are accepted; confirm before
        # re-introducing pipe.reshape(...).
        # Compile explicitly for the target device. This may be a no-op if
        # from_pretrained already compiled the models; confirm for the
        # installed optimum-intel version.
        pipe.compile()
        MODELS["inpaintor"] = pipe
    return MODELS["inpaintor"]
def get_removing_model(MODELS: dict):
    """Return the SimpleLama object remover, loading it on first use.

    While ``SimpleLama()`` is constructed, ``torch.jit.load`` is temporarily
    replaced with a wrapper that forces ``map_location="cpu"``. This assumes
    SimpleLama loads its weights through ``torch.jit.load`` looked up at call
    time; confirm against the installed simple_lama_inpainting version.
    The original function is always restored afterwards, even if construction
    fails.

    Args:
        MODELS: Shared model cache. The ``"remover"`` key must be present
            (a missing key raises ``KeyError``). It holds ``None`` until the
            remover has been built.

    Returns:
        The ``SimpleLama`` instance cached in ``MODELS["remover"]``.
    """
    if MODELS["remover"] is None:
        # Imported here so torch is only loaded when the remover must be built.
        import torch

        original_load = torch.jit.load

        def _cpu_load(*args, **kwargs):
            # Override any caller-supplied map_location so tensors load on CPU.
            return original_load(*args, **{**kwargs, "map_location": "cpu"})

        # NOTE(review): this patch is process-global. Any other torch.jit.load
        # call made while it is active (e.g. from another thread) is also
        # forced to CPU.
        torch.jit.load = _cpu_load
        try:
            MODELS["remover"] = SimpleLama()
        finally:
            # Restore unconditionally. Previously an exception in SimpleLama()
            # left torch.jit.load patched for the rest of the process.
            torch.jit.load = original_load
    return MODELS["remover"]