# generate.py
# --- VERSION 13 (Correct Upscaler Loading) ---
"""Face-conditioned image generation service.

Pipeline overview:
1. Detect a face in each user photo with InsightFace and average the identity embeddings.
2. Condition a Stable Diffusion pipeline on that embedding through an IP-Adapter FaceID adapter.
3. Watermark (free plan) or 4x-upscale with Real-ESRGAN (other plans).
4. Upload the PNG to Supabase storage and return its public URL.
"""
print("--- RUNNING GENERATE.PY VERSION 13 (Correct Upscaler Loading) ---")

import sys

# --- TORCHVISION COMPATIBILITY SHIM ---
# basicsr (imported by realesrgan / gfpgan below) imports the private module
# `torchvision.transforms.functional_tensor`, which recent torchvision releases
# no longer ship. Alias it to the public `functional` module, but only when the
# real module is missing so installs that still have it keep their own copy.
# This must run before basicsr is imported.
try:
    import torchvision.transforms.functional_tensor  # noqa: F401
except ImportError:
    try:
        import torchvision.transforms.functional as _tv_functional

        sys.modules['torchvision.transforms.functional_tensor'] = _tv_functional
        print("--- Successfully applied torchvision monkey-patch. ---")
    except Exception as e:
        print(f"--- Could not apply torchvision monkey-patch: {e} ---")
# --- END OF SHIM ---

import logging
import os
import traceback
import uuid

import cv2
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
from diffusers import AutoencoderKL, DDIMScheduler, StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from gfpgan import GFPGANer
from huggingface_hub import hf_hub_download
from insightface.app import FaceAnalysis
from insightface.utils import face_align
from realesrgan import RealESRGANer
from storage3.utils import StorageException
from transformers import CLIPVisionModelWithProjection

import config
import utils
from database import supabase

# --- Setup Logging ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# --- Model / output configuration ---
BASE_MODEL_PATH = "SG161222/Realistic_Vision_V4.0_noVAE"
VAE_MODEL_PATH = "stabilityai/sd-vae-ft-mse"
# FaceID adapter: consumes InsightFace identity embeddings directly, so it
# needs no CLIP image encoder.
IP_ADAPTER_REPO = "h94/IP-Adapter-FaceID"
IP_ADAPTER_WEIGHT_NAME = "ip-adapter-faceid_sd15.bin"
UPSCALER_WEIGHTS_URL = 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth'
UPSCALE_FACTOR = 4
WATERMARK_TEXT = "@MagicFaceBot"
TEMP_DIR = "temp_images"


# --- Main Generation Service ---
class GenerationService:
    """Load every model once, then turn face photos into uploaded images.

    Attributes:
        device: "cuda" when a GPU is available, otherwise "cpu".
        torch_dtype: float16 on CUDA, float32 on CPU.
        face_app: InsightFace analyzer used for detection and embeddings.
        pipe: Stable Diffusion pipeline with the FaceID IP-Adapter loaded.
        upsampler: Real-ESRGAN x4 upscaler.
    """

    def __init__(self):
        """Load all models.

        Raises:
            RuntimeError: if any model fails to load. The original exception
                is chained as ``__cause__``.
        """
        logger.info("Initializing Generation Service...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        logger.info(f"Using device: {self.device} with dtype: {self.torch_dtype}")
        use_cuda = self.device == "cuda"

        try:
            # --- Face analysis (detection + identity embedding) ---
            self.face_app = FaceAnalysis(
                name="buffalo_l",
                providers=['CUDAExecutionProvider' if use_cuda else 'CPUExecutionProvider'],
            )
            # InsightFace interprets a negative ctx_id as "run on CPU".
            self.face_app.prepare(ctx_id=0 if use_cuda else -1, det_size=(640, 640))
            cv2.setNumThreads(1)

            # --- Diffusion pipeline ---
            vae = AutoencoderKL.from_pretrained(VAE_MODEL_PATH).to(dtype=self.torch_dtype)
            self.pipe = StableDiffusionPipeline.from_pretrained(
                BASE_MODEL_PATH,
                torch_dtype=self.torch_dtype,
                scheduler=DDIMScheduler(
                    num_train_timesteps=1000,
                    beta_start=0.00085,
                    beta_end=0.012,
                    beta_schedule="scaled_linear",
                    clip_sample=False,
                    set_alpha_to_one=False,
                    steps_offset=1,
                ),
                vae=vae,
                feature_extractor=None,
                safety_checker=None,
            ).to(self.device)

            # diffusers only uses `ip_adapter_image_embeds` through an
            # IP-Adapter loaded into the UNet. Without this call the face
            # embedding has no effect on the generated image. Loaded after
            # `.to(device)` so the adapter weights land on the same device.
            logger.info("Loading IP-Adapter FaceID weights...")
            self.pipe.load_ip_adapter(
                IP_ADAPTER_REPO,
                subfolder=None,
                weight_name=IP_ADAPTER_WEIGHT_NAME,
                image_encoder_folder=None,
            )

            # --- Upscaler ---
            logger.info("Loading Real-ESRGAN upscaler model...")
            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=UPSCALE_FACTOR)
            self.upsampler = RealESRGANer(
                scale=UPSCALE_FACTOR,
                model_path=UPSCALER_WEIGHTS_URL,
                dni_weight=None,
                model=model,
                tile=0,
                tile_pad=10,
                pre_pad=0,
                half=self.torch_dtype == torch.float16,
                gpu_id=0 if use_cuda else None,
            )
            logger.info("Upscaler model loaded.")
            logger.info("All models loaded successfully.")
        except Exception as e:
            logger.exception(f"Fatal error during model loading: {e}")
            raise RuntimeError(f"Could not initialize GenerationService: {e}") from e

    def _upscale_image(self, image_path: str) -> str:
        """Upscale the image at *image_path* with Real-ESRGAN, overwriting it.

        This is best-effort: any failure is logged, not raised.

        Returns:
            *image_path* unchanged, whether or not upscaling succeeded.
        """
        try:
            img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
            if img is None:
                raise ValueError("cv2.imread could not read the file")
            # enhance() returns (upscaled_image, image_mode); only the image is needed.
            output, _ = self.upsampler.enhance(img, outscale=UPSCALE_FACTOR)
            if not cv2.imwrite(image_path, output):
                raise OSError("cv2.imwrite failed")
            logger.info(f"Successfully upscaled image: {image_path}")
        except Exception as e:
            logger.error(f"Failed to upscale image {image_path}: {e}")
        return image_path

    def _extract_face_embedding(self, face_images: list[str]) -> torch.Tensor | None:
        """Average the identity embeddings of the faces found in *face_images*.

        For each readable image, the first face InsightFace returns is used.
        Unreadable images and images with no detected face are skipped with a
        warning. Per-image errors are logged and skipped.

        Returns:
            A float32 CPU tensor of shape (1, D), where D is the length of
            InsightFace's ``normed_embedding``, renormalized to unit length.
            Returns None if no face was found in any image.
        """
        embeddings = []
        for image_path in face_images:
            try:
                image = cv2.imread(image_path)
                if image is None:
                    logger.warning(f"Could not read face image {image_path}; skipping.")
                    continue
                faces = self.face_app.get(image)
                if not faces:
                    logger.warning(f"No face detected in {image_path}; skipping.")
                    continue
                embeddings.append(torch.from_numpy(faces[0].normed_embedding).unsqueeze(0))
            except Exception as e:
                logger.error(f"Error processing face image {image_path}: {e}")

        if not embeddings:
            return None

        mean_embedding = torch.mean(torch.stack(embeddings, dim=0), dim=0)
        # Each input is unit-length but their mean generally is not. Renormalize
        # so the result has the same scale as a single `normed_embedding`.
        return torch.nn.functional.normalize(mean_embedding, dim=-1)

    def _upload_image(self, local_path: str) -> str:
        """Upload the PNG at *local_path* to the configured bucket.

        Returns:
            The public URL of the uploaded file.

        Raises:
            Whatever the storage client raises on failure.
        """
        storage_path = f"public/{os.path.basename(local_path)}"
        bucket = supabase.storage.from_(config.SUPABASE_BUCKET_NAME)
        with open(local_path, 'rb') as f:
            bucket.upload(
                path=storage_path,
                file=f,
                file_options={"content-type": "image/png"},
            )
        return bucket.get_public_url(storage_path)

    def generate_magic_image(self, face_images: list[str], gender: str, prompt: str, plan: str = 'free') -> str | None:
        """Generate an image of the person in *face_images* and upload it.

        Args:
            face_images: Paths to photos of the same person.
            gender: Currently unused; accepted for API compatibility.
            prompt: Scene description. Quality/composition tags are appended.
            plan: 'free' adds a watermark; any other value upscales instead.

        Returns:
            The public URL of the uploaded image, or None if no face was
            detected or any step failed. The temporary local file is always
            removed.
        """
        logger.info(f"Starting image generation process for a user on the '{plan}' plan.")
        full_prompt = f"{prompt}, 4k, high-resolution, photorealistic, masterpiece, single person, solo portrait, centered composition"
        negative_prompt = "multiple people, group photo, crowd, two faces, three faces, multiple faces, collage, ugly, deformed, blurry, low quality"

        average_embedding = self._extract_face_embedding(face_images)
        if average_embedding is None:
            logger.error("No faces were detected in any of the provided images.")
            return None

        logger.info("Calling the generation pipeline...")
        local_path = None
        try:
            # Classifier-free guidance: diffusers splits these embeds with
            # chunk(2), so the negative (all-zero) embedding must come first.
            positive_embedding = average_embedding.unsqueeze(0)
            negative_embedding = torch.zeros_like(positive_embedding)
            final_embedding = torch.cat([negative_embedding, positive_embedding], dim=0).to(
                device=self.device, dtype=self.torch_dtype
            )

            output = self.pipe(
                prompt=full_prompt,
                negative_prompt=negative_prompt,
                ip_adapter_image_embeds=[final_embedding],
                num_inference_steps=40,
                guidance_scale=7.5,
                width=512,
                height=768,
            )
            image = output.images[0] if isinstance(output, StableDiffusionPipelineOutput) else output[0][0]

            os.makedirs(TEMP_DIR, exist_ok=True)
            local_path = os.path.join(TEMP_DIR, f"{uuid.uuid4()}.png")
            image.save(local_path)

            if plan == 'free':
                utils.add_watermark(local_path, WATERMARK_TEXT)
            else:
                self._upscale_image(local_path)

            return self._upload_image(local_path)
        except Exception as e:
            logger.exception("An unexpected error occurred. Full traceback below:")
            logger.error(f"Error summary: {e}")
            return None
        finally:
            # Always remove the temp file, on success or failure, without
            # letting a cleanup error mask the real result.
            if local_path is not None and os.path.exists(local_path):
                try:
                    os.remove(local_path)
                except OSError as cleanup_error:
                    logger.warning(f"Could not remove temp file {local_path}: {cleanup_error}")


# --- Example Usage (for testing) ---
if __name__ == '__main__':
    if os.path.exists("test_face.jpg"):
        logger.info("Running a test generation and upload...")
        service = GenerationService()
        result_url = service.generate_magic_image(
            face_images=["test_face.jpg"],
            gender="Female",
            prompt="A beautiful portrait of a princess in a magical forest, fantasy art",
            plan='paid',
        )
        if result_url:
            print(f"\n✅ Test successful! Image URL: {result_url}")
            print("Check the image at the URL. It should be high-resolution and have no watermark.")
        else:
            print("\n❌ Test failed. Please check the logs for details.")
    else:
        print("To run a test, place an image named 'test_face.jpg' in the root directory.")