| import torch
|
| import numpy as np
|
| import os
|
| import sys
|
| from concurrent.futures import ThreadPoolExecutor
|
| from diffusers.utils import load_image
|
| from diffusers import EulerDiscreteScheduler
|
| from huggingface_hub import hf_hub_download
|
| from photomaker import PhotoMakerStableDiffusionXLPipeline
|
| from rembg import remove
|
| from PIL import Image
|
|
|
|
|
# Prompt style templates: name -> (positive template with "{prompt}" slot,
# negative-prompt terms appended to the caller's negative prompt).
styles = {
    "Cinematic HD": ("cinematic HD {prompt}", "low quality"),
}
|
|
|
|
|
# Hugging Face model id for the SDXL base checkpoint PhotoMaker runs on.
base_model_path = 'SG161222/RealVisXL_V3.0'
# Local directory holding the input person photos to cut out and composite.
person_image_folder = 'in'
# Background/environment image the person cut-outs are pasted onto.
environment_image_path = 'environment1024.jpg'
|
|
|
# Select the best available compute device: CUDA first, then Apple-Silicon
# MPS (macOS only), otherwise CPU.
try:
    if torch.cuda.is_available():
        device = "cuda"
    elif sys.platform == "darwin" and torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
except Exception as exc:
    # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt still
    # propagate; report the failure instead of swallowing it silently.
    print(f"Device detection failed ({exc}); falling back to CPU")
    device = "cpu"
|
|
|
# Largest value usable as a 32-bit RNG seed.
MAX_SEED = np.iinfo(np.int32).max
# Style applied to every prompt; must be a key of `styles`.
DEFAULT_STYLE_NAME = "Cinematic HD"
|
|
|
|
|
# Download the PhotoMaker ID-adapter checkpoint from the Hub (cached locally).
photomaker_ckpt = hf_hub_download(repo_id="TencentARC/PhotoMaker", filename="photomaker-v1.bin", repo_type="model")

# Pick the model dtype per backend: float16 on MPS, bfloat16 everywhere else.
# NOTE(review): the non-MPS branch also covers plain CPU, where bfloat16 can
# be slow or unsupported on older hardware — confirm this is intended.
if device == "mps":
    torch_dtype = torch.float16
else:
    torch_dtype = torch.bfloat16

# Load the SDXL pipeline with PhotoMaker support and move it to the device.
# NOTE(review): variant="fp16" loads fp16 weight files even when torch_dtype
# is bfloat16; weights are cast on load — presumably intentional, verify.
pipe = PhotoMakerStableDiffusionXLPipeline.from_pretrained(
    base_model_path,
    torch_dtype=torch_dtype,
    use_safetensors=True,
    variant="fp16",
).to(device)

# Attach the PhotoMaker adapter. "img" is the trigger word that must appear
# exactly once in every prompt (enforced in generate_combined_image).
pipe.load_photomaker_adapter(
    os.path.dirname(photomaker_ckpt),
    subfolder="",
    weight_name=os.path.basename(photomaker_ckpt),
    trigger_word="img"
)
# Ensure the ID encoder lives on the same device as the rest of the pipeline.
pipe.id_encoder.to(device)

# Swap in an Euler scheduler and fuse the adapter's LoRA weights into the
# base weights so inference skips the separate LoRA forward pass.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
pipe.fuse_lora()
|
|
|
def remove_background(image_path):
    """Open the image at *image_path* and strip its background via rembg.

    Returns the cut-out PIL image, or None (after printing the error)
    if opening or background removal fails.
    """
    try:
        return remove(Image.open(image_path))
    except Exception as e:
        print(f"Error in remove_background: {e}")
        return None
|
|
|
def overlay_images(background, person_images):
    """Paste every image in *person_images*, centered, onto *background*.

    Each cut-out's own alpha channel is used as the paste mask so removed
    backgrounds stay transparent. Returns an RGB composite; on failure the
    error is printed and the background (as converted so far) is returned.
    """
    try:
        background = background.convert("RGBA")
        for cutout in person_images:
            layer = cutout.convert("RGBA")
            offset = (
                (background.width - layer.width) // 2,
                (background.height - layer.height) // 2,
            )
            background.paste(layer, offset, layer)
        return background.convert("RGB")
    except Exception as e:
        print(f"Error in overlay_images: {e}")
        return background
|
|
|
def generate_combined_image(combined_image, prompt, negative_prompt, num_steps, style_strength_ratio, num_outputs, guidance_scale, seed):
    """Run the PhotoMaker pipeline using *combined_image* as the identity reference.

    Args:
        combined_image: PIL image fed to the ID encoder.
        prompt: text prompt; must contain the trigger word "img" exactly once.
        negative_prompt: text to steer generation away from.
        num_steps: number of denoising steps.
        style_strength_ratio: percentage (0-100) of steps before ID merging starts.
        num_outputs: images generated per prompt (only the first is returned).
        guidance_scale: classifier-free guidance strength.
        seed: RNG seed for reproducible generation.

    Returns:
        The first generated PIL image, or None (after printing the error)
        if validation or generation fails.
    """
    try:
        # The adapter requires the trigger word exactly once in the prompt.
        person_token_id = pipe.tokenizer.convert_tokens_to_ids("img")
        input_ids = pipe.tokenizer.encode(prompt)
        if person_token_id not in input_ids:
            # Fixed: these messages had a pointless `f` prefix (no placeholders).
            raise ValueError("Cannot find the trigger word 'img' in text prompt!")
        if input_ids.count(person_token_id) > 1:
            raise ValueError("Cannot use multiple trigger words 'img' in text prompt!")

        # Wrap prompt/negative prompt in the default style template.
        prompt, negative_prompt = apply_style(DEFAULT_STYLE_NAME, prompt, negative_prompt)

        input_id_images = [combined_image]
        generator = torch.Generator(device=device).manual_seed(seed)

        # Begin merging identity features part-way through denoising,
        # capped at step 30.
        start_merge_step = min(int(float(style_strength_ratio) / 100 * num_steps), 30)

        images = pipe(
            prompt=prompt,
            width=1280,
            height=720,
            input_id_images=input_id_images,
            negative_prompt=negative_prompt,
            num_images_per_prompt=num_outputs,
            num_inference_steps=num_steps,
            start_merge_step=start_merge_step,
            generator=generator,
            guidance_scale=guidance_scale,
        ).images
        return images[0]
    except Exception as e:
        print(f"Error in generate_combined_image: {e}")
        return None
|
|
|
def apply_style(style_name: str, positive: str, negative: str = ""):
    """Expand the prompts through the named style template.

    Unknown style names fall back to the "Cinematic HD" entry. The style's
    own negative terms are prefixed (space-separated) to *negative*.
    """
    template, style_negative = styles.get(style_name, styles["Cinematic HD"])
    styled_positive = template.replace("{prompt}", positive)
    styled_negative = style_negative + ' ' + negative
    return styled_positive, styled_negative
|
|
|
def process_image(image_path):
    # Thin wrapper so ThreadPoolExecutor.map can dispatch background removal
    # per image path; returns the cut-out image or None on failure.
    return remove_background(image_path)
|
|
|
def main():
    """End-to-end driver: cut out each person photo, composite the cut-outs
    onto the environment image, then run PhotoMaker and save the result."""
    prompt = "cinematic person img vicking, 35mm photograph, film, bokeh, professional, 4k, highly detailed,"
    negative_prompt = "nsfw, lowres, bad anatomy, bad hands, text, error, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality, normal quality, jpeg artifacts, signature, watermark, username, blurry"
    num_steps = 12
    style_strength_ratio = 48
    num_outputs = 1
    guidance_scale = 7.8
    seed = 52

    # Load the backdrop; abort on failure.
    try:
        environment_image = load_image(environment_image_path)
    except Exception as e:
        print(f"Error loading environment image: {e}")
        return

    # Strip backgrounds from every person photo in parallel, dropping failures.
    try:
        person_image_paths = [
            os.path.join(person_image_folder, filename)
            for filename in os.listdir(person_image_folder)
            if filename.lower().endswith(('.png', '.jpg', '.jpeg'))
        ]
        with ThreadPoolExecutor() as executor:
            cutouts = executor.map(process_image, person_image_paths)
            person_images = [img for img in cutouts if img is not None]
    except Exception as e:
        print(f"Error processing person images: {e}")
        return

    # Composite, then round-trip through disk before feeding the pipeline.
    combined_image = overlay_images(environment_image, person_images)
    combined_image_path = "combined_image.jpg"
    try:
        combined_image.save(combined_image_path)
    except Exception as e:
        print(f"Error saving combined image: {e}")
        return
    try:
        combined_image = load_image(combined_image_path)
    except Exception as e:
        print(f"Error loading combined image: {e}")
        return

    result_image = generate_combined_image(
        combined_image,
        prompt,
        negative_prompt,
        num_steps,
        style_strength_ratio,
        num_outputs,
        guidance_scale,
        seed,
    )

    if result_image is None:
        print("No image generated")
        return
    try:
        result_image.save("result.jpg")
        print("Image saved as result.jpg")
    except Exception as e:
        print(f"Error saving result image: {e}")


if __name__ == "__main__":
    main()
|
|
|