|
|
import torch |
|
|
from diffusers import StableDiffusionPipeline |
|
|
from torch import autocast |
|
|
from pathlib import Path |
|
|
import traceback |
|
|
|
|
|
class StyleTransfer:
    """Singleton wrapper around a Stable Diffusion pipeline that applies
    textual-inversion style embeddings to generate stylized artwork.

    Usage: obtain the shared instance via :meth:`get_instance`, call
    :meth:`initialize_pipeline` once (slow: downloads/loads the model and
    all style ``.bin`` embeddings), then call :meth:`generate_artwork`.
    """

    # Process-wide singleton instance (see get_instance()).
    _instance = None

    @classmethod
    def get_instance(cls):
        """Return the shared StyleTransfer instance, creating it lazily.

        NOTE(review): not thread-safe; fine for single-threaded callers.
        """
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def __init__(self):
        # The StableDiffusionPipeline; None until initialize_pipeline() runs.
        self.pipeline = None
        # Placeholder tokens registered with the tokenizer, parallel to self.styles.
        self.style_tokens = []
        # Basenames of the textual-inversion embedding files (<style>.bin),
        # expected two directories above this module (see initialize_pipeline).
        self.styles = [
            "dhoni",
            "mickey_mouse",
            "balloon",
            "lion_king",
            "rose_flower"
        ]
        # Human-readable labels, parallel to self.styles; generate_artwork()
        # looks styles up by these names.
        self.style_names = [
            "Dhoni Style",
            "Mickey Mouse Style",
            "Balloon Style",
            "Lion King Style",
            "Rose Flower Style"
        ]
        self.is_initialized = False
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        if self.device == "cpu":
            print("NVIDIA GPU not found. Running on CPU (this will be slower)")

    def initialize_pipeline(self):
        """Load the Stable Diffusion model and register every style embedding.

        Idempotent: returns immediately if initialization already succeeded.

        Raises:
            FileNotFoundError: if any ``<style>.bin`` embedding is missing.
            Exception: any other initialization failure is logged with its
                traceback and re-raised.
        """
        if self.is_initialized:
            return

        try:
            print("Initializing Stable Diffusion model...")
            model_id = "runwayml/stable-diffusion-v1-5"
            self.pipeline = StableDiffusionPipeline.from_pretrained(
                model_id,
                # fp16 halves GPU memory; CPU kernels require fp32.
                torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
                safety_checker=None
            )
            self.pipeline = self.pipeline.to(self.device)

            # Style embeddings live two levels above this file, one .bin each.
            current_dir = Path(__file__).parent.parent

            for style, style_name in zip(self.styles, self.style_names):
                style_path = current_dir / f"{style}.bin"
                if not style_path.exists():
                    raise FileNotFoundError(f"Style embedding not found: {style_path}")

                print(f"Loading style: {style_name}")
                token = self._load_style_embedding(str(style_path))
                self.style_tokens.append(token)
                print(f"✓ Loaded style: {style_name}")

            self.is_initialized = True
            print(f"Model initialization complete! Using device: {self.device}")

        except Exception as e:
            print(f"Error during initialization: {str(e)}")
            print(traceback.format_exc())
            raise

    def _load_style_embedding(self, embedding_path, token=None):
        """Register one textual-inversion embedding with the tokenizer/encoder.

        Args:
            embedding_path: path to a ``.bin`` file mapping one placeholder
                token to its embedding vector.
            token: optional override for the placeholder token; defaults to
                the token stored in the file.

        Returns:
            The placeholder token string that was registered.
        """
        # SECURITY: torch.load unpickles arbitrary data — only ever load
        # embedding files from a trusted source (consider weights_only=True
        # on torch >= 2.0).
        loaded_embeds = torch.load(embedding_path, map_location="cpu")
        trained_token = next(iter(loaded_embeds))
        embeds = loaded_embeds[trained_token]

        # Match the text encoder's embedding width: truncate or zero-pad.
        expected_dim = self.pipeline.text_encoder.get_input_embeddings().weight.shape[1]
        current_dim = embeds.shape[0]

        if current_dim != expected_dim:
            print(f"Resizing embedding from {current_dim} to {expected_dim}")
            if current_dim > expected_dim:
                embeds = embeds[:expected_dim]
            else:
                padding = torch.zeros(expected_dim - current_dim, device=embeds.device, dtype=embeds.dtype)
                embeds = torch.cat([embeds, padding], dim=0)

        # Cast to the encoder's dtype (fp16 on GPU, fp32 on CPU).
        dtype = self.pipeline.text_encoder.get_input_embeddings().weight.dtype
        embeds = embeds.to(dtype)

        token = token if token is not None else trained_token
        num_added_tokens = self.pipeline.tokenizer.add_tokens(token)

        if num_added_tokens > 0:
            # New token -> grow the embedding matrix to cover it.
            self.pipeline.text_encoder.resize_token_embeddings(len(self.pipeline.tokenizer))

        token_id = self.pipeline.tokenizer.convert_tokens_to_ids(token)
        if token_id < self.pipeline.text_encoder.get_input_embeddings().weight.shape[0]:
            # Write the (1-D, [expected_dim]) vector straight into its row.
            self.pipeline.text_encoder.get_input_embeddings().weight.data[token_id] = embeds
        else:
            print(f"Warning: Token ID {token_id} is out of bounds. Skipping embedding assignment.")

        return token

    def generate_artwork(self, prompt, selected_style):
        """Generate a (base_image, enhanced_image) pair of PIL images.

        Args:
            prompt: free-text description of the desired image.
            selected_style: one of self.style_names.

        Returns:
            Tuple of two PIL images rendered from the same seed: a plain
            pass and a pass run with the color-enhancement callback.

        Raises:
            RuntimeError: if initialize_pipeline() has not completed.
            ValueError: if selected_style is not a known style name.
        """
        try:
            if not self.is_initialized:
                raise RuntimeError("Call initialize_pipeline() before generate_artwork()")

            style_idx = self.style_names.index(selected_style)
            styled_prompt = f"{prompt}, {self.style_tokens[style_idx]}"

            # Fixed seed so both passes are reproducible and comparable.
            generator_seed = 42
            torch.manual_seed(generator_seed)
            if self.device == "cuda":
                torch.cuda.manual_seed(generator_seed)

            base_image = self._run_pipeline(styled_prompt, generator_seed)

            # NOTE(review): the classic diffusers `callback` API ignores the
            # callback's return value, so the latents _enhance_colors returns
            # are likely never applied — port to `callback_on_step_end` to
            # make the enhancement take effect. TODO: confirm against the
            # installed diffusers version.
            enhanced_image = self._run_pipeline(
                styled_prompt,
                generator_seed,
                callback=self._enhance_colors,
                callback_steps=5,
            )

            return base_image, enhanced_image

        except Exception as e:
            print(f"Error in generate_artwork: {e}")
            raise

    def _run_pipeline(self, styled_prompt, seed, **extra):
        """Run one text-to-image pass and return the first PIL image.

        Shared by both generate_artwork() passes; `extra` forwards optional
        pipeline kwargs such as callback/callback_steps.
        """
        with autocast(self.device):
            return self.pipeline(
                styled_prompt,
                num_inference_steps=50,
                guidance_scale=7.5,
                generator=torch.Generator(self.device).manual_seed(seed),
                **extra
            ).images[0]

    def _enhance_colors(self, i, t, latents):
        """Diffusers step callback: every 5th step, take one gradient step on
        the inter-channel color-distance loss and return adjusted latents.

        On any failure (or on non-multiple-of-5 steps) the original latents
        are returned unchanged.

        NOTE(review): `latents - 0.1 * grad` *descends* the distance loss,
        i.e. it reduces channel separation despite the "enhance" name; if the
        intent is to boost contrast the sign should be `+`. Preserved as-is.
        """
        if i % 5 == 0:
            try:
                # Detached leaf copy so autograd never touches the pipeline's
                # own tensors.
                latents_copy = latents.detach().clone()
                latents_copy.requires_grad_(True)

                loss = self._calculate_color_distance(latents_copy)

                if loss is not None and loss.requires_grad:
                    grads = torch.autograd.grad(
                        outputs=loss,
                        inputs=latents_copy,
                        allow_unused=True,
                        retain_graph=False
                    )

                    if grads and grads[0] is not None:
                        grad_tensor = grads[0].detach()
                        if grad_tensor.shape == latents.shape:
                            return latents - 0.1 * grad_tensor

            except Exception as e:
                print(f"Error in color enhancement: {e}")

        return latents

    def _calculate_color_distance(self, images):
        """Differentiable scalar: sum of pairwise mean squared distances
        between the first three channels, scaled by 100.

        NOTE(review): when invoked from _enhance_colors this receives
        4-channel *latents*, not decoded RGB pixels, so the red/green/blue
        names are only approximate — TODO confirm intent.
        """
        if not images.requires_grad:
            # Make the loss differentiable for non-tracking inputs (wrt this
            # detached copy, not the caller's tensor).
            images = images.detach().requires_grad_(True)

        # Map from the [-1, 1] convention to [0, 1].
        images = images.float() / 2 + 0.5

        red = images[:,0:1]
        green = images[:,1:2]
        blue = images[:,2:3]

        rg_distance = ((red - green) ** 2).mean()
        rb_distance = ((red - blue) ** 2).mean()
        gb_distance = ((green - blue) ** 2).mean()

        return (rg_distance + rb_distance + gb_distance) * 100