# Source: MagicFaceTG / generate.py
# Author: rastof9
# Commit e6c006b (verified): "Update Gradio app with multiple files"
# generate.py
# --- VERSION 13 (Correct Upscaler Loading) ---
# Startup banner so the running revision is visible in the console output.
print("--- RUNNING GENERATE.PY VERSION 13 (Correct Upscaler Loading) ---")

# --- TORCHVISION COMPATIBILITY SHIM ---
# Registers `torchvision.transforms.functional` in `sys.modules` under the
# name `torchvision.transforms.functional_tensor`, so any later import of that
# name resolves to the `functional` module. This is best-effort: a failure is
# reported on stdout and the module keeps loading.
# NOTE(review): presumably one of the packages imported further down still
# imports `functional_tensor` -- confirm which one requires this alias.
try:
    import sys
    import torchvision.transforms.functional as F
except Exception as e:
    print(f"--- Could not apply torchvision monkey-patch: {e} ---")
else:
    sys.modules['torchvision.transforms.functional_tensor'] = F
    print("--- Successfully applied torchvision monkey-patch. ---")
# --- END OF SHIM ---
import torch
import cv2
import os
import logging
import uuid
import traceback
from diffusers import StableDiffusionPipeline, DDIMScheduler, AutoencoderKL
from diffusers.pipelines.stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
from transformers import CLIPVisionModelWithProjection
from insightface.app import FaceAnalysis
from insightface.utils import face_align
from huggingface_hub import hf_hub_download
from storage3.utils import StorageException
from realesrgan import RealESRGANer # <-- IMPORT THE CORRECT CLASS
from basicsr.archs.rrdbnet_arch import RRDBNet
from gfpgan import GFPGANer
import config
import utils
from database import supabase
# --- Logging configuration ---
# Configure the root logger once (INFO level, timestamped records) and create
# the module-level logger used throughout this file.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
# --- Main Generation Service ---
class GenerationService:
    """Generate a face-conditioned portrait and publish it to Supabase storage.

    Construction loads three models: an InsightFace ``FaceAnalysis`` detector,
    a Stable Diffusion pipeline (``SG161222/Realistic_Vision_V4.0_noVAE`` with
    the ``stabilityai/sd-vae-ft-mse`` VAE) and a Real-ESRGAN x4 upscaler.
    ``generate_magic_image`` is the public entry point.
    """

    def __init__(self):
        """Load every model the service needs.

        Raises:
            RuntimeError: if any model fails to load. The underlying exception
                is chained as ``__cause__``.
        """
        logger.info("Initializing Generation Service...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision only on GPU; CPU runs in float32.
        self.torch_dtype = torch.float16 if self.device == "cuda" else torch.float32
        logger.info("Using device: %s with dtype: %s", self.device, self.torch_dtype)

        base_model_path = "SG161222/Realistic_Vision_V4.0_noVAE"
        vae_model_path = "stabilityai/sd-vae-ft-mse"
        try:
            # --- Face detector / embedder ---
            self.face_app = FaceAnalysis(
                name="buffalo_l",
                providers=['CUDAExecutionProvider' if self.device == "cuda" else 'CPUExecutionProvider'],
            )
            self.face_app.prepare(ctx_id=0, det_size=(640, 640))
            # Limit OpenCV to a single worker thread.
            cv2.setNumThreads(1)

            # --- Diffusion pipeline ---
            vae = AutoencoderKL.from_pretrained(vae_model_path).to(dtype=self.torch_dtype)
            self.pipe = StableDiffusionPipeline.from_pretrained(
                base_model_path,
                torch_dtype=self.torch_dtype,
                scheduler=DDIMScheduler(
                    num_train_timesteps=1000, beta_start=0.00085, beta_end=0.012,
                    beta_schedule="scaled_linear", clip_sample=False,
                    set_alpha_to_one=False, steps_offset=1,
                ),
                vae=vae, feature_extractor=None, safety_checker=None,
            ).to(self.device)
            # NOTE(review): no IP-Adapter weights are loaded in this class, yet
            # generate_magic_image() passes `ip_adapter_image_embeds` to the
            # pipeline. Confirm the adapter is loaded elsewhere, otherwise the
            # face embedding may have no effect on the output.

            # --- Real-ESRGAN upscaler ---
            logger.info("Loading Real-ESRGAN upscaler model...")
            model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
            self.upsampler = RealESRGANer(
                scale=4,
                model_path='https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
                dni_weight=None,
                model=model,
                tile=0,
                tile_pad=10,
                pre_pad=0,
                # Run the upscaler in half precision exactly when the pipeline does.
                half=self.torch_dtype == torch.float16,
                gpu_id=0 if self.device == "cuda" else None,
            )
            logger.info("Upscaler model loaded.")
            logger.info("All models loaded successfully.")
        except Exception as e:
            logger.error("Fatal error during model loading: %s", e)
            raise RuntimeError(f"Could not initialize GenerationService: {e}") from e

    def _upscale_image(self, image_path: str) -> str:
        """Upscale the image at *image_path* in place using Real-ESRGAN (4x).

        This is best-effort: any failure is logged and the file is left as it
        was, so the caller can still deliver the non-upscaled image.

        Args:
            image_path: path of the image file to overwrite with the result.

        Returns:
            ``image_path``, whether or not upscaling succeeded.
        """
        try:
            img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
            if img is None:
                # cv2.imread signals an unreadable file by returning None.
                raise ValueError("image could not be read")
            # enhance() returns a pair; only the upscaled image is needed here.
            upscaled, _ = self.upsampler.enhance(img, outscale=4)
            if not cv2.imwrite(image_path, upscaled):
                raise OSError("upscaled image could not be written")
            logger.info("Successfully upscaled image: %s", image_path)
        except Exception as e:
            logger.error("Failed to upscale image %s: %s", image_path, e)
        return image_path

    def _average_face_embedding(self, face_images: list) -> "torch.Tensor | None":
        """Return the mean face embedding over *face_images*, or ``None``.

        For each path the first face reported by the detector is used. Images
        that cannot be read, contain no detected face, or raise during
        processing are logged and skipped.
        """
        embeddings = []
        for image_path in face_images:
            try:
                face = cv2.imread(image_path)
                if face is None:
                    logger.warning("Could not read face image %s; skipping.", image_path)
                    continue
                faces = self.face_app.get(face)
                if not faces:
                    logger.warning("No face detected in %s; skipping.", image_path)
                    continue
                embeddings.append(torch.from_numpy(faces[0].normed_embedding).unsqueeze(0))
            except Exception as e:
                logger.error("Error processing face image %s: %s", image_path, e)
        if not embeddings:
            return None
        return torch.mean(torch.stack(embeddings, dim=0), dim=0)

    def _upload_image(self, local_path: str) -> str:
        """Upload *local_path* as ``public/<file name>`` and return its public URL."""
        storage_path = f"public/{os.path.basename(local_path)}"
        with open(local_path, 'rb') as f:
            supabase.storage.from_(config.SUPABASE_BUCKET_NAME).upload(
                path=storage_path, file=f, file_options={"content-type": "image/png"}
            )
        return supabase.storage.from_(config.SUPABASE_BUCKET_NAME).get_public_url(storage_path)

    def generate_magic_image(self, face_images: list, gender: str, prompt: str, plan: str = 'free') -> str | None:
        """Generate a portrait from reference face images and upload it.

        Args:
            face_images: paths of reference images; their face embeddings are
                averaged.
            gender: currently unused; kept so existing callers keep working.
            prompt: user prompt; fixed quality/composition keywords are appended.
            plan: ``'free'`` adds an ``@MagicFaceBot`` watermark; any other
                value upscales the image 4x instead.

        Returns:
            The public URL of the uploaded PNG, or ``None`` when no face was
            detected or any step of generation/upload failed (the error is
            logged, never raised).
        """
        logger.info("Starting image generation process for a user on the '%s' plan.", plan)
        full_prompt = f"{prompt}, 4k, high-resolution, photorealistic, masterpiece, single person, solo portrait, centered composition"
        negative_prompt = "multiple people, group photo, crowd, two faces, three faces, multiple faces, collage, ugly, deformed, blurry, low quality"

        average_embedding = self._average_face_embedding(face_images)
        if average_embedding is None:
            logger.error("No faces were detected in any of the provided images.")
            return None

        logger.info("Calling the generation pipeline...")
        local_path = None
        try:
            # Pair an all-zero (negative) embedding with the face embedding,
            # negative first, concatenated along dim 0.
            positive_embedding = average_embedding.unsqueeze(0)
            negative_embedding = torch.zeros_like(positive_embedding)
            final_embedding = torch.cat([negative_embedding, positive_embedding], dim=0)

            output = self.pipe(
                prompt=full_prompt, negative_prompt=negative_prompt,
                ip_adapter_image_embeds=[final_embedding], num_inference_steps=40,
                guidance_scale=7.5, width=512, height=768,
            )
            image = output.images[0] if isinstance(output, StableDiffusionPipelineOutput) else output[0][0]

            temp_dir = "temp_images"
            os.makedirs(temp_dir, exist_ok=True)
            local_path = os.path.join(temp_dir, f"{uuid.uuid4()}.png")
            image.save(local_path)

            if plan == 'free':
                utils.add_watermark(local_path, "@MagicFaceBot")
            else:
                self._upscale_image(local_path)

            return self._upload_image(local_path)
        except Exception as e:
            # logger.exception sends the traceback through the logging
            # configuration instead of printing it directly to stderr.
            logger.exception("An unexpected error occurred. Full traceback below:")
            logger.error("Error summary: %s", e)
            return None
        finally:
            # Single cleanup point for both the success and the failure path.
            if local_path is not None and os.path.exists(local_path):
                try:
                    os.remove(local_path)
                except OSError as cleanup_error:
                    logger.warning("Could not remove temporary file %s: %s", local_path, cleanup_error)
# --- Manual smoke test (runs only when this file is executed directly) ---
# Requires a 'test_face.jpg' in the working directory; generates one image on
# the 'paid' plan, uploads it and prints the resulting URL.
if __name__ == '__main__':
    if not os.path.exists("test_face.jpg"):
        print("To run a test, place an image named 'test_face.jpg' in the root directory.")
    else:
        logger.info("Running a test generation and upload...")
        smoke_request = {
            "face_images": ["test_face.jpg"],
            "gender": "Female",
            "prompt": "A beautiful portrait of a princess in a magical forest, fantasy art",
            "plan": 'paid',
        }
        service = GenerationService()
        result_url = service.generate_magic_image(**smoke_request)
        if not result_url:
            print(f"\n❌ Test failed. Please check the logs for details.")
        else:
            print(f"\n✅ Test successful! Image URL: {result_url}")
            print("Check the image at the URL. It should be high-resolution and have no watermark.")