NguyenThanh1405's picture
Deploy CQL Chatbot (without large files)
4cfe4fa
import cv2
import torch
from transformers import Swin2SRImageProcessor, Swin2SRForImageSuperResolution
from diffusers import StableDiffusionPipeline
import numpy as np
from PIL import Image, ImageEnhance, ImageOps
import random
# from safetensors.torch import load_file
from stable_diffusion import MiniDiffusionPipeline
# --- Cấu hình ---
#PROMPT = "beautiful woman with long braided hair, wearing a scarf, soft smile, looking down, detailed shading" #725562173
#PROMPT = "attractive woman, big lips, mouth slightly open, heavy makeup" #v5
#PROMPT = "The man is young and has sharp jawline, narrow eyes, thick eyebrows, and short black hair." #10, 11
#PROMPT = "She is elderly with deep smile lines, small eyes, and short curly gray hair." #13
#PROMPT = "This man is old and smiling, with gray beard and big nose"
PROMPT = "a baby"
SAVE_IMAGE_PATH = "./15.png"
UNET_SAFE_PATH = "./unet-mini.safetensors"
VAE_SAFE_PATH = "./vae-finetuned.safetensors"
BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TINY_UNET_CONFIG = {
"unet_block_out_channels": (128, 256, 512),
}
MODEL_ID = "caidas/swin2SR-classical-sr-x4-64"
print(f"Đang load model {MODEL_ID} từ Hugging Face...")
processor = Swin2SRImageProcessor.from_pretrained(MODEL_ID)
model = Swin2SRForImageSuperResolution.from_pretrained(MODEL_ID)
print("Load model thành công!")
model = model.to(DEVICE)
def upscale_image_pipeline(pil_image, contrast=1.3, sharpen=1.5, target_size=(512, 512)):
"""
Chiến thuật "Canvas Isolation" (Cách ly khung tranh).
Đặt ảnh vào giữa một vùng trắng cực rộng để đẩy lỗi biên ra xa.
"""
if model is None or processor is None:
return pil_image.resize(target_size)
# 1. Chuẩn bị ảnh
img_np = np.array(pil_image)
if len(img_np.shape) == 2:
img_np = cv2.cvtColor(img_np, cv2.COLOR_GRAY2RGB)
h_orig, w_orig = img_np.shape[:2]
# 2. TẠO CANVAS (Khung tranh) LỚN
# Tạo một nền trắng to gấp đôi ảnh gốc (256x256)
# Mục đích: Đưa ảnh thật vào "vùng an toàn" ở trung tâm tuyệt đối
canvas_size = 256
canvas = np.ones((canvas_size, canvas_size, 3), dtype=np.uint8) * 255
# Tính tọa độ để dán ảnh vào giữa
y_offset = (canvas_size - h_orig) // 2 # (256-128)/2 = 64
x_offset = (canvas_size - w_orig) // 2 # 64
# Dán ảnh vào canvas
canvas[y_offset:y_offset+h_orig, x_offset:x_offset+w_orig] = img_np
# 3. Upscale toàn bộ Canvas
# Lúc này model sẽ xử lý biên của ảnh 256x256 -> Lỗi phản chiếu sẽ nằm ở rìa canvas (cách ảnh thật rất xa)
pil_canvas = Image.fromarray(canvas)
inputs = processor(pil_canvas, return_tensors="pt").to(DEVICE)
with torch.no_grad():
outputs = model(**inputs)
output_tensor = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
output_tensor = np.moveaxis(output_tensor, 0, -1)
output_canvas = (output_tensor * 255.0).round().astype(np.uint8)
# 4. TRÍCH XUẤT ẢNH THẬT (CROP)
# Canvas input 256 -> Upscale x4 -> Canvas output 1024
# Ảnh thật nằm ở vị trí offset * 4
scale_factor = 4
y_start = y_offset * scale_factor # 64 * 4 = 256
x_start = x_offset * scale_factor # 256
# Kích thước ảnh thật sau khi upscale (128 * 4 = 512)
h_real = h_orig * scale_factor
w_real = w_orig * scale_factor
# Cắt lấy đúng phần ảnh thật nằm giữa canvas
final_img = output_canvas[y_start : y_start + h_real, x_start : x_start + w_real]
# 5. BIỆN PHÁP CƯỠNG BỨC (HARD FIX)
# Nếu model vẫn "lì lợm" tạo ra 1-2 pixel mờ ở đáy, ta sẽ tô trắng 3 dòng pixel cuối cùng.
# Vì đây là tranh vẽ trên nền trắng, việc này không ảnh hưởng nội dung nhưng xóa sạch mọi lỗi.
final_img[-1:, :, :] = 255
final_img[:, -1:, :] = 255
# 6. Đảm bảo kích thước cuối cùng
if final_img.shape[:2] != target_size:
final_img = cv2.resize(final_img, (target_size[1], target_size[0]), interpolation=cv2.INTER_LANCZOS4)
# 7. Hậu xử lý
final_pil = Image.fromarray(final_img)
enhancer = ImageEnhance.Contrast(final_pil)
final_pil = enhancer.enhance(contrast)
enhancer = ImageEnhance.Sharpness(final_pil)
final_pil = enhancer.enhance(sharpen)
return final_pil
@torch.no_grad()
def main():
print("--- Bắt đầu quá trình Inference (từ Safetensors) ---")
# --- Khởi tạo MiniDiffusionPipeline ---
print(f"Đang tải pipeline gốc từ {BASE_MODEL_ID}...")
container = MiniDiffusionPipeline(
base_model_id=BASE_MODEL_ID,
device=DEVICE,
config_overrides=TINY_UNET_CONFIG
)
# --- Tải trọng số đã huấn luyện ---
# Tải UNet
print(f"Đang tải trọng số UNet từ {UNET_SAFE_PATH}...")
try:
unet_weights = torch.load(UNET_SAFE_PATH, map_location=DEVICE)
container.unet.load_state_dict(unet_weights)
except Exception as e:
print(f"LỖI: Không thể tải UNet state dict: {e}")
print("Kiểm tra xem bạn đã bỏ chú thích 'config_overrides=TINY_UNET_CONFIG' chưa?")
return
# Tải VAE
print(f"Đang tải trọng số VAE từ {VAE_SAFE_PATH}...")
try:
vae_weights = torch.load(VAE_SAFE_PATH, map_location=DEVICE)
container.vae.load_state_dict(vae_weights)
except Exception as e:
print(f"LỖI: Không thể tải VAE state dict: {e}")
return
# --- Khởi tạo StableDiffusionPipeline ---
torch_dtype = torch.float16 if DEVICE.startswith("cuda") else torch.float32
print("Đang tạo pipeline inference...")
inference_pipeline = StableDiffusionPipeline(
unet=container.unet,
vae=container.vae,
text_encoder=container.text_encoder,
tokenizer=container.tokenizer,
scheduler=container.noise_scheduler,
safety_checker=None,
feature_extractor=None,
).to(DEVICE)
if DEVICE.startswith("cuda"):
inference_pipeline.to(dtype=torch_dtype)
inference_pipeline.set_progress_bar_config(disable=False)
# --- Tạo ảnh ---
print(f"\nĐang tạo ảnh cho prompt: '{PROMPT}'")
current_seed = random.randint(0, 2**32 - 1)
print(f"Seed hiện tại: {current_seed}")
generator = torch.Generator(device=DEVICE).manual_seed(current_seed) #725562173, 4169604779, 725562172, 3884820838, 1794046812, 1379970385
image = inference_pipeline(
prompt=PROMPT,
num_inference_steps=50,
generator=generator,
guidance_scale=7.5
).images[0]
final_image = upscale_image_pipeline(image)
final_image.save(SAVE_IMAGE_PATH)
# --- Lưu ảnh ---
image.save(SAVE_IMAGE_PATH.replace(".png", "_original.png"))
# # --- Lưu ảnh ---
# image.save(SAVE_IMAGE_PATH)
print(f"\n--- Hoàn thành! ---")
print(f"Đã lưu ảnh tại: {SAVE_IMAGE_PATH}")
try:
image.show()
except Exception:
pass
if __name__ == "__main__":
main()