Spaces:
Sleeping
Sleeping
| import cv2 | |
| import torch | |
| from transformers import Swin2SRImageProcessor, Swin2SRForImageSuperResolution | |
| from diffusers import StableDiffusionPipeline | |
| import numpy as np | |
| from PIL import Image, ImageEnhance, ImageOps | |
| import random | |
| # from safetensors.torch import load_file | |
| from stable_diffusion import MiniDiffusionPipeline | |
| # --- Cấu hình --- | |
| #PROMPT = "beautiful woman with long braided hair, wearing a scarf, soft smile, looking down, detailed shading" #725562173 | |
| #PROMPT = "attractive woman, big lips, mouth slightly open, heavy makeup" #v5 | |
| #PROMPT = "The man is young and has sharp jawline, narrow eyes, thick eyebrows, and short black hair." #10, 11 | |
| #PROMPT = "She is elderly with deep smile lines, small eyes, and short curly gray hair." #13 | |
| #PROMPT = "This man is old and smiling, with gray beard and big nose" | |
| PROMPT = "a baby" | |
| SAVE_IMAGE_PATH = "./15.png" | |
| UNET_SAFE_PATH = "./unet-mini.safetensors" | |
| VAE_SAFE_PATH = "./vae-finetuned.safetensors" | |
| BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5" | |
| DEVICE = "cuda" if torch.cuda.is_available() else "cpu" | |
| TINY_UNET_CONFIG = { | |
| "unet_block_out_channels": (128, 256, 512), | |
| } | |
| MODEL_ID = "caidas/swin2SR-classical-sr-x4-64" | |
| print(f"Đang load model {MODEL_ID} từ Hugging Face...") | |
| processor = Swin2SRImageProcessor.from_pretrained(MODEL_ID) | |
| model = Swin2SRForImageSuperResolution.from_pretrained(MODEL_ID) | |
| print("Load model thành công!") | |
| model = model.to(DEVICE) | |
| def upscale_image_pipeline(pil_image, contrast=1.3, sharpen=1.5, target_size=(512, 512)): | |
| """ | |
| Chiến thuật "Canvas Isolation" (Cách ly khung tranh). | |
| Đặt ảnh vào giữa một vùng trắng cực rộng để đẩy lỗi biên ra xa. | |
| """ | |
| if model is None or processor is None: | |
| return pil_image.resize(target_size) | |
| # 1. Chuẩn bị ảnh | |
| img_np = np.array(pil_image) | |
| if len(img_np.shape) == 2: | |
| img_np = cv2.cvtColor(img_np, cv2.COLOR_GRAY2RGB) | |
| h_orig, w_orig = img_np.shape[:2] | |
| # 2. TẠO CANVAS (Khung tranh) LỚN | |
| # Tạo một nền trắng to gấp đôi ảnh gốc (256x256) | |
| # Mục đích: Đưa ảnh thật vào "vùng an toàn" ở trung tâm tuyệt đối | |
| canvas_size = 256 | |
| canvas = np.ones((canvas_size, canvas_size, 3), dtype=np.uint8) * 255 | |
| # Tính tọa độ để dán ảnh vào giữa | |
| y_offset = (canvas_size - h_orig) // 2 # (256-128)/2 = 64 | |
| x_offset = (canvas_size - w_orig) // 2 # 64 | |
| # Dán ảnh vào canvas | |
| canvas[y_offset:y_offset+h_orig, x_offset:x_offset+w_orig] = img_np | |
| # 3. Upscale toàn bộ Canvas | |
| # Lúc này model sẽ xử lý biên của ảnh 256x256 -> Lỗi phản chiếu sẽ nằm ở rìa canvas (cách ảnh thật rất xa) | |
| pil_canvas = Image.fromarray(canvas) | |
| inputs = processor(pil_canvas, return_tensors="pt").to(DEVICE) | |
| with torch.no_grad(): | |
| outputs = model(**inputs) | |
| output_tensor = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy() | |
| output_tensor = np.moveaxis(output_tensor, 0, -1) | |
| output_canvas = (output_tensor * 255.0).round().astype(np.uint8) | |
| # 4. TRÍCH XUẤT ẢNH THẬT (CROP) | |
| # Canvas input 256 -> Upscale x4 -> Canvas output 1024 | |
| # Ảnh thật nằm ở vị trí offset * 4 | |
| scale_factor = 4 | |
| y_start = y_offset * scale_factor # 64 * 4 = 256 | |
| x_start = x_offset * scale_factor # 256 | |
| # Kích thước ảnh thật sau khi upscale (128 * 4 = 512) | |
| h_real = h_orig * scale_factor | |
| w_real = w_orig * scale_factor | |
| # Cắt lấy đúng phần ảnh thật nằm giữa canvas | |
| final_img = output_canvas[y_start : y_start + h_real, x_start : x_start + w_real] | |
| # 5. BIỆN PHÁP CƯỠNG BỨC (HARD FIX) | |
| # Nếu model vẫn "lì lợm" tạo ra 1-2 pixel mờ ở đáy, ta sẽ tô trắng 3 dòng pixel cuối cùng. | |
| # Vì đây là tranh vẽ trên nền trắng, việc này không ảnh hưởng nội dung nhưng xóa sạch mọi lỗi. | |
| final_img[-1:, :, :] = 255 | |
| final_img[:, -1:, :] = 255 | |
| # 6. Đảm bảo kích thước cuối cùng | |
| if final_img.shape[:2] != target_size: | |
| final_img = cv2.resize(final_img, (target_size[1], target_size[0]), interpolation=cv2.INTER_LANCZOS4) | |
| # 7. Hậu xử lý | |
| final_pil = Image.fromarray(final_img) | |
| enhancer = ImageEnhance.Contrast(final_pil) | |
| final_pil = enhancer.enhance(contrast) | |
| enhancer = ImageEnhance.Sharpness(final_pil) | |
| final_pil = enhancer.enhance(sharpen) | |
| return final_pil | |
| def main(): | |
| print("--- Bắt đầu quá trình Inference (từ Safetensors) ---") | |
| # --- Khởi tạo MiniDiffusionPipeline --- | |
| print(f"Đang tải pipeline gốc từ {BASE_MODEL_ID}...") | |
| container = MiniDiffusionPipeline( | |
| base_model_id=BASE_MODEL_ID, | |
| device=DEVICE, | |
| config_overrides=TINY_UNET_CONFIG | |
| ) | |
| # --- Tải trọng số đã huấn luyện --- | |
| # Tải UNet | |
| print(f"Đang tải trọng số UNet từ {UNET_SAFE_PATH}...") | |
| try: | |
| unet_weights = torch.load(UNET_SAFE_PATH, map_location=DEVICE) | |
| container.unet.load_state_dict(unet_weights) | |
| except Exception as e: | |
| print(f"LỖI: Không thể tải UNet state dict: {e}") | |
| print("Kiểm tra xem bạn đã bỏ chú thích 'config_overrides=TINY_UNET_CONFIG' chưa?") | |
| return | |
| # Tải VAE | |
| print(f"Đang tải trọng số VAE từ {VAE_SAFE_PATH}...") | |
| try: | |
| vae_weights = torch.load(VAE_SAFE_PATH, map_location=DEVICE) | |
| container.vae.load_state_dict(vae_weights) | |
| except Exception as e: | |
| print(f"LỖI: Không thể tải VAE state dict: {e}") | |
| return | |
| # --- Khởi tạo StableDiffusionPipeline --- | |
| torch_dtype = torch.float16 if DEVICE.startswith("cuda") else torch.float32 | |
| print("Đang tạo pipeline inference...") | |
| inference_pipeline = StableDiffusionPipeline( | |
| unet=container.unet, | |
| vae=container.vae, | |
| text_encoder=container.text_encoder, | |
| tokenizer=container.tokenizer, | |
| scheduler=container.noise_scheduler, | |
| safety_checker=None, | |
| feature_extractor=None, | |
| ).to(DEVICE) | |
| if DEVICE.startswith("cuda"): | |
| inference_pipeline.to(dtype=torch_dtype) | |
| inference_pipeline.set_progress_bar_config(disable=False) | |
| # --- Tạo ảnh --- | |
| print(f"\nĐang tạo ảnh cho prompt: '{PROMPT}'") | |
| current_seed = random.randint(0, 2**32 - 1) | |
| print(f"Seed hiện tại: {current_seed}") | |
| generator = torch.Generator(device=DEVICE).manual_seed(current_seed) #725562173, 4169604779, 725562172, 3884820838, 1794046812, 1379970385 | |
| image = inference_pipeline( | |
| prompt=PROMPT, | |
| num_inference_steps=50, | |
| generator=generator, | |
| guidance_scale=7.5 | |
| ).images[0] | |
| final_image = upscale_image_pipeline(image) | |
| final_image.save(SAVE_IMAGE_PATH) | |
| # --- Lưu ảnh --- | |
| image.save(SAVE_IMAGE_PATH.replace(".png", "_original.png")) | |
| # # --- Lưu ảnh --- | |
| # image.save(SAVE_IMAGE_PATH) | |
| print(f"\n--- Hoàn thành! ---") | |
| print(f"Đã lưu ảnh tại: {SAVE_IMAGE_PATH}") | |
| try: | |
| image.show() | |
| except Exception: | |
| pass | |
| if __name__ == "__main__": | |
| main() |