# MagicQuilledit / app.py
# K1Z3M1112's picture
# Update app.py
# 4f05a63 verified
import os

# Hide every CUDA device from this process. The assignments run before torch
# (or anything that imports it) is loaded, so they are already in place when
# those libraries initialise.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
# NOTE(review): NO_CUDA / USE_CPU are presumably read by downstream code that
# is not visible in this file - confirm they are still needed.
os.environ["NO_CUDA"] = "1"
os.environ["USE_CPU"] = "1"

import subprocess
import shlex
import sys

# Install the bundled MagicQuill Gradio component at start-up. pip is invoked
# as "<this interpreter> -m pip" so the wheel lands in the environment that is
# running this script rather than whichever "pip" happens to be first on PATH.
_pip_install = subprocess.run(
    [
        sys.executable,
        "-m",
        *shlex.split(
            "pip install ./gradio_magicquill-0.0.1-py3-none-any.whl"
        ),
    ]
)
if _pip_install.returncode != 0:
    # Keep going (the component may already be installed), but make the
    # failure visible instead of discarding it.
    print(
        "pip install of gradio_magicquill failed with exit code "
        f"{_pip_install.returncode}"
    )
# Patch MagicQuill's model_management.py on disk so that it always uses the CPU.
model_management_path = "/usr/local/lib/python3.10/site-packages/MagicQuill/comfy/model_management.py"

_GET_TORCH_DEVICE_HEADER = 'def get_torch_device():'
# Header plus the two statements injected directly beneath it. The target
# function's original body stays in the file after the injected ``return`` and
# simply becomes unreachable.
# NOTE(review): the 4-space indent assumes get_torch_device is defined at
# module level in the target file - confirm.
_CPU_ONLY_PATCH = (
    _GET_TORCH_DEVICE_HEADER
    + '\n    import torch\n    return torch.device("cpu")'
)


def force_cpu_torch_device(path):
    """Rewrite ``get_torch_device()`` in the file at *path* to return the CPU.

    The function header is replaced by the header followed by an immediate
    ``return torch.device("cpu")``. The rewrite is idempotent: a file that
    already contains the injected statements is left untouched, so repeated
    start-ups do not stack additional copies of the patch.

    Args:
        path: Path of an existing Python source file.

    Returns:
        True when the file was modified; False when it was already patched or
        does not define ``get_torch_device()``.

    Raises:
        OSError: If the file cannot be read or written.
    """
    with open(path, 'r', encoding='utf-8') as f:
        content = f.read()
    if _CPU_ONLY_PATCH in content or _GET_TORCH_DEVICE_HEADER not in content:
        return False
    with open(path, 'w', encoding='utf-8') as f:
        f.write(content.replace(_GET_TORCH_DEVICE_HEADER, _CPU_ONLY_PATCH))
    return True


if os.path.exists(model_management_path):
    if force_cpu_torch_device(model_management_path):
        print("Fixed model_management.py to use CPU only")
else:
    # The hard-coded location is missing: look for the module on sys.path.
    import sys
    for path in sys.path:
        test_path = os.path.join(path, 'MagicQuill', 'comfy', 'model_management.py')
        if os.path.exists(test_path):
            if force_cpu_torch_device(test_path):
                print(f"Fixed model_management.py at {test_path} to use CPU only")
            # Stop at the first installation found.
            break
import gradio as gr
from gradio_magicquill import MagicQuill
import random
import torch
# Make the CPU the default device for new tensors and switch cuDNN off; this
# runs before the MagicQuill modules are imported further down.
torch.set_default_device('cpu')
torch.backends.cudnn.enabled = False
import numpy as np
from PIL import Image, ImageOps
import base64
import io
from fastapi import FastAPI, Request
import uvicorn
# Configure the environment before importing the MagicQuill modules.
# NOTE(review): both variables are PyTorch MPS (Apple GPU) settings - confirm
# they have any effect on this CPU-only deployment.
os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = '1'
os.environ['PYTORCH_MPS_HIGH_WATERMARK_RATIO'] = '0.0'
try:
    from MagicQuill.scribble_color_edit import ScribbleColorEditModel
except ImportError as import_error:
    import traceback

    # Report why the real model could not be imported, with the full trace.
    print(f"Error importing ScribbleColorEditModel: {import_error}")
    traceback.print_exc()

    # Define a stand-in so the rest of the module can still be loaded.
    class ScribbleColorEditModel:
        """Placeholder used when the real model cannot be imported."""

        def process(self, *args, **kwargs):
            """Always raise: no editing backend is available."""
            raise NotImplementedError("ScribbleColorEditModel is not available in CPU mode")
from gradio_client import Client, handle_file
from huggingface_hub import snapshot_download
import tempfile
import cv2
import requests
import gc
# Start-up side effects, executed once at import time:
# download the "LiuZichen/MagicQuill-models" model repository into ./models,
snapshot_download(repo_id="LiuZichen/MagicQuill-models", repo_type="model", local_dir="models")
# connect to the "LiuZichen/MagicQuillHelper" Space (used by guess_prompt_handler),
client = Client("LiuZichen/MagicQuillHelper")
# and build the single editing model instance shared by every request.
scribbleColorEditModel = ScribbleColorEditModel()
def tensor_to_numpy(tensor):
    """Convert a torch tensor to a ``uint8`` NumPy array scaled by 255.

    Values are multiplied by 255 and cast to ``uint8`` (the cast truncates the
    fractional part). Anything that is not a ``torch.Tensor`` is returned
    unchanged.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor
    scaled = tensor.detach().cpu().numpy() * 255
    return scaled.astype(np.uint8)
def tensor_to_base64(tensor):
    """Encode an image tensor as a base64 PNG string (no ``data:`` prefix).

    The leading dimension is squeezed away, values are scaled by 255 and cast
    to bytes, and the resulting array is written out as a PNG.
    """
    scaled = tensor.squeeze(0) * 255.
    pixels = scaled.cpu().byte().numpy()
    with io.BytesIO() as png_buffer:
        Image.fromarray(pixels).save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode("utf-8")
def read_base64_image(base64_image):
    """Decode a PNG/JPEG/WebP base64 data URL into a PIL image.

    The image is returned after applying its EXIF orientation.

    Raises:
        ValueError: If the string does not start with one of the supported
            ``data:image/...;base64,`` prefixes.
    """
    supported_prefixes = (
        "data:image/png;base64,",
        "data:image/jpeg;base64,",
        "data:image/webp;base64,",
    )
    if not base64_image.startswith(supported_prefixes):
        raise ValueError("Unsupported image format.")
    # Everything after the prefix's comma is the base64 payload.
    payload = base64_image.split(",")[1]
    raw_bytes = base64.b64decode(payload)
    decoded = Image.open(io.BytesIO(raw_bytes))
    return ImageOps.exif_transpose(decoded)
def create_alpha_mask(base64_image):
    """Build a ``(1, height, width)`` float mask from an image's alpha channel.

    The mask holds ``1 - alpha`` (alpha scaled to 0..1). When the image has no
    alpha band the mask is all zeros.
    """
    image = read_base64_image(base64_image)
    mask = torch.zeros((1, image.height, image.width), dtype=torch.float32, device="cpu")
    if 'A' not in image.getbands():
        return mask
    alpha = np.array(image.getchannel('A')).astype(np.float32) / 255.0
    mask[0] = 1.0 - torch.from_numpy(alpha)
    return mask
def load_and_preprocess_image(base64_image, convert_to='RGB', has_alpha=False):
    """Decode a base64 data-URL image into a float tensor with a batch axis.

    The image is converted to the ``convert_to`` mode, scaled to 0..1 and
    given a leading dimension of size one. ``has_alpha`` is accepted but not
    used.
    """
    decoded = read_base64_image(base64_image).convert(convert_to)
    pixels = np.array(decoded).astype(np.float32) / 255.0
    return torch.from_numpy(pixels)[None,]
def load_and_resize_image(base64_image, convert_to='RGB', max_size=512):
    """Decode a base64 data-URL image and rescale it by ``max_size / shorter side``.

    The rescale is always applied (enlarging as well as shrinking), uses
    Lanczos resampling and truncates the new dimensions to integers. The
    result is a 0..1 float tensor with a leading dimension of size one.
    """
    picture = read_base64_image(base64_image).convert(convert_to)
    src_width, src_height = picture.size
    scale = max_size / min(src_width, src_height)
    target_size = (int(src_width * scale), int(src_height * scale))
    resized = picture.resize(target_size, Image.LANCZOS)
    pixels = np.array(resized).astype(np.float32) / 255.0
    return torch.from_numpy(pixels)[None,]
def prepare_images_and_masks(total_mask, original_image, add_color_image, add_edge_image, remove_edge_image):
    """Turn the base64 layers of an edit into image tensors and masks.

    Missing optional layers fall back as follows: an absent colour layer
    reuses the original image tensor, and an absent edge layer yields an
    all-zero mask shaped like the total mask.

    Returns:
        ``(add_color_image_tensor, original_image_tensor, total_mask,
        add_edge_mask, remove_edge_mask)``.
    """
    total_mask = create_alpha_mask(total_mask)
    original_image_tensor = load_and_preprocess_image(original_image)

    add_color_image_tensor = (
        load_and_preprocess_image(add_color_image)
        if add_color_image
        else original_image_tensor
    )

    if add_edge_image:
        add_edge_mask = create_alpha_mask(add_edge_image)
    else:
        add_edge_mask = torch.zeros_like(total_mask)

    if remove_edge_image:
        remove_edge_mask = create_alpha_mask(remove_edge_image)
    else:
        remove_edge_mask = torch.zeros_like(total_mask)

    return add_color_image_tensor, original_image_tensor, total_mask, add_edge_mask, remove_edge_mask
def guess_prompt_handler(original_image, add_color_image, add_edge_image):
    """Ask the remote helper to guess a prompt for the current edit.

    The original image, the colour layer and the add-edge mask are written to
    temporary PNG files and sent to the ``/guess_prompt`` endpoint of the
    module-level ``client``.

    Args:
        original_image: Base64 data URL of the original image.
        add_color_image: Base64 data URL of the colour layer; when falsy the
            original image is used in its place.
        add_edge_image: Base64 data URL whose alpha channel defines the
            add-edge mask; when falsy an all-zero mask is used.

    Returns:
        Whatever ``client.predict`` returns for ``/guess_prompt``.
    """
    original_image_tensor = load_and_preprocess_image(original_image)
    if add_color_image:
        add_color_image_tensor = load_and_preprocess_image(add_color_image)
    else:
        add_color_image_tensor = original_image_tensor

    if add_edge_image:
        add_edge_mask = create_alpha_mask(add_edge_image)
    else:
        # The image tensor is laid out (1, height, width, channels), so axis 1
        # is the height and axis 2 the width. Reading them the other way round
        # produced a transposed mask for non-square images.
        height, width = original_image_tensor.shape[1], original_image_tensor.shape[2]
        add_edge_mask = torch.zeros((1, height, width), dtype=torch.float32, device="cpu")

    original_image_numpy = tensor_to_numpy(original_image_tensor.squeeze(0))
    add_color_image_numpy = tensor_to_numpy(add_color_image_tensor.squeeze(0))
    add_edge_mask_numpy = tensor_to_numpy(add_edge_mask.squeeze(0).unsqueeze(-1))
    # OpenCV writes channels in BGR order.
    original_image_numpy = cv2.cvtColor(original_image_numpy, cv2.COLOR_RGB2BGR)
    add_color_image_numpy = cv2.cvtColor(add_color_image_numpy, cv2.COLOR_RGB2BGR)

    # A temporary directory guarantees the PNG files are removed even when
    # writing or the remote call raises, and no file is written while another
    # handle to it is still open.
    with tempfile.TemporaryDirectory() as temp_dir:
        original_image_path = os.path.join(temp_dir, "original_image.png")
        add_color_image_path = os.path.join(temp_dir, "add_color_image.png")
        add_edge_mask_path = os.path.join(temp_dir, "add_edge_mask.png")
        cv2.imwrite(original_image_path, original_image_numpy)
        cv2.imwrite(add_color_image_path, add_color_image_numpy)
        cv2.imwrite(add_edge_mask_path, add_edge_mask_numpy)
        res = client.predict(
            handle_file(original_image_path),
            handle_file(add_color_image_path),
            handle_file(add_edge_mask_path),
            api_name="/guess_prompt"
        )
    return res
def generate(ckpt_name, total_mask, original_image, add_color_image, add_edge_image, remove_edge_image, positive_prompt, negative_prompt, grow_size, stroke_as_edge, fine_edge, edge_strength, color_strength, inpaint_strength, seed, steps, cfg, sampler_name, scheduler):
    """Run one edit through the shared model and return the result as base64.

    The base64 layers are converted to tensors and masks, handed to
    ``scribbleColorEditModel.process`` together with the sampling settings,
    and the final image is encoded as a base64 PNG string.
    """
    (
        add_color_image,
        original_image,
        total_mask,
        add_edge_mask,
        remove_edge_mask,
    ) = prepare_images_and_masks(total_mask, original_image, add_color_image, add_edge_image, remove_edge_image)
    progress = None

    # With fine edges disabled and only "remove" strokes present, supply a
    # default prompt when none was given and weaken the edge guidance.
    # NOTE(review): this nesting was reconstructed from a listing without
    # indentation - confirm edge_strength is reduced only in this case.
    if (
        fine_edge == 'disable'
        and torch.sum(remove_edge_mask).item() > 0
        and torch.sum(add_edge_mask).item() == 0
    ):
        if positive_prompt == "":
            positive_prompt = "empty scene"
        edge_strength /= 3.

    latent_samples, final_image, lineart_output, color_output = scribbleColorEditModel.process(
        ckpt_name,
        original_image,
        add_color_image,
        positive_prompt,
        negative_prompt,
        total_mask,
        add_edge_mask,
        remove_edge_mask,
        grow_size,
        stroke_as_edge,
        fine_edge,
        edge_strength,
        color_strength,
        inpaint_strength,
        seed,
        steps,
        cfg,
        sampler_name,
        scheduler,
        progress,
    )

    encoded_result = tensor_to_base64(final_image)
    # Drop the model outputs and collect garbage before returning.
    del latent_samples, final_image, lineart_output, color_output
    gc.collect()
    return encoded_result
def generate_image_handler(x, ckpt_name, negative_prompt, fine_edge, grow_size, edge_strength, color_strength, inpaint_strength, seed, steps, cfg, sampler_name, scheduler):
    """Handle the "Run" button: generate an image and store it in ``x``.

    ``x`` is the MagicQuill component value. Its ``from_frontend`` entry
    supplies the base64 layers and ``from_backend['prompt']`` the positive
    prompt. A seed of ``-1`` is replaced by a random 32-bit value. The
    generated image is written to ``x['from_backend']['generated_image']`` and
    ``x`` is returned.
    """
    if seed == -1:
        seed = random.randint(0, 2**32 - 1)

    frontend = x['from_frontend']
    positive_prompt = x['from_backend']['prompt']
    stroke_as_edge = "enable"

    layer_keys = ('total_mask', 'original_image', 'add_color_image', 'add_edge_image', 'remove_edge_image')
    layers = [frontend[key] for key in layer_keys]

    x["from_backend"]["generated_image"] = generate(
        ckpt_name,
        *layers,
        positive_prompt,
        negative_prompt,
        grow_size,
        stroke_as_edge,
        fine_edge,
        edge_strength,
        color_strength,
        inpaint_strength,
        seed,
        steps,
        cfg,
        sampler_name,
        scheduler,
    )
    return x
# Extra stylesheet passed to gr.Blocks: elements carrying the "row" class are
# limited to 90% width and centred horizontally.
css = '''
.row {
width: 90%;
margin: auto;
}
'''
# Extra markup passed to gr.Blocks through its ``head`` argument.
# NOTE(review): browsers are documented to ignore the CSP ``frame-ancestors``
# directive when it is delivered via a <meta> element; if blocking embedding
# is the goal it presumably has to be sent as an HTTP response header -
# confirm.
head = """
<meta http-equiv="Content-Security-Policy" content="frame-ancestors 'none'">
"""
# Gradio UI: a banner, the MagicQuill canvas, a "Run" button next to a
# collapsible parameter panel, and a licence note.
with gr.Blocks(css=css, head=head) as demo:
    with gr.Row(elem_classes="row"):
        text = gr.Markdown(
            """
# Welcome to MagicQuill! The paper has been accepted to CVPR 2025.
Click the [link](https://magicquill.art) to view our demo and tutorial. Give us a [GitHub star](https://github.com/magic-quill/magicquill) if you are interested.
MagicQuillV2 is available!!! Check our [demo](https://magicquill.art/v2/).
""")
    with gr.Row(elem_classes="row"):
        # Custom component; its value is both the input and the output of the
        # click handler registered below.
        ms = MagicQuill()
    with gr.Row(elem_classes="row"):
        with gr.Column():
            btn = gr.Button("Run", variant="primary")
        with gr.Column():
            with gr.Accordion("parameters", open=False):
                # Single, non-interactive checkpoint choice.
                ckpt_value = os.path.join('SD1.5', 'realisticVisionV60B1_v51VAE.safetensors')
                ckpt_name = gr.Dropdown(
                    label="Base Model (fixed for demo)",
                    choices=[ckpt_value],
                    value=ckpt_value,
                    interactive=False
                )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value="",
                    interactive=True
                )
                fine_edge = gr.Radio(
                    label="Fine Edge",
                    choices=['enable', 'disable'],
                    value='disable',
                    interactive=True
                )
                grow_size = gr.Slider(
                    label="Grow Size",
                    minimum=0,
                    maximum=100,
                    value=15,
                    step=1,
                    interactive=True
                )
                edge_strength = gr.Slider(
                    label="Edge Strength",
                    minimum=0.0,
                    maximum=5.0,
                    value=0.55,
                    step=0.01,
                    interactive=True
                )
                color_strength = gr.Slider(
                    label="Color Strength",
                    minimum=0.0,
                    maximum=5.0,
                    value=0.55,
                    step=0.01,
                    interactive=True
                )
                inpaint_strength = gr.Slider(
                    label="Inpaint Strength",
                    minimum=0.0,
                    maximum=5.0,
                    value=1.0,
                    step=0.01,
                    interactive=True
                )
                # -1 makes generate_image_handler pick a random seed.
                seed = gr.Number(
                    label="Seed",
                    value=-1,
                    precision=0,
                    interactive=True
                )
                steps = gr.Slider(
                    label="Steps",
                    minimum=1,
                    maximum=50,
                    value=20,
                    step=1,
                    interactive=True
                )
                cfg = gr.Slider(
                    label="CFG",
                    minimum=0.0,
                    maximum=20.0,
                    value=5.0,
                    step=0.1,
                    interactive=True
                )
                sampler_name = gr.Dropdown(
                    label="Sampler Name",
                    choices=["euler", "euler_ancestral", "heun", "heunpp2","dpm_2", "dpm_2_ancestral", "lms", "dpm_fast", "dpm_adaptive", "dpmpp_2s_ancestral", "dpmpp_sde", "dpmpp_sde_gpu", "dpmpp_2m", "dpmpp_2m_sde", "dpmpp_2m_sde_gpu", "dpmpp_3m_sde", "dpmpp_3m_sde_gpu", "ddpm", "lcm", "ddim", "uni_pc", "uni_pc_bh2"],
                    value='euler_ancestral',
                    interactive=True
                )
                scheduler = gr.Dropdown(
                    label="Scheduler",
                    choices=["normal", "karras", "exponential", "sgm_uniform", "simple", "ddim_uniform"],
                    value='karras',
                    interactive=True
                )
        # The input order matches the parameter order of
        # generate_image_handler; the handler's return value replaces the
        # canvas value.
        btn.click(generate_image_handler, inputs=[ms, ckpt_name, negative_prompt, fine_edge, grow_size, edge_strength, color_strength, inpaint_strength, seed, steps, cfg, sampler_name, scheduler], outputs=ms, concurrency_limit=1)
    with gr.Row(elem_classes="row"):
        text = gr.Markdown(
            """
Note: This demo is governed by the license of CC BY-NC 4.0. We strongly advise users not to knowingly generate or allow others to knowingly generate harmful content, including hate speech, violence, pornography, deception, etc. (注:本演示受CC BY-NC的许可协议限制。我们强烈建议,用户不应传播及不应允许他人传播以下内容,包括但不限于仇恨言论、暴力、色情、欺诈相关的有害信息。)
""")
# Enable request queuing with the given queue size and status update rate.
demo.queue(max_size=20, status_update_rate=0.1)
# FastAPI application that hosts the extra JSON endpoints defined below; the
# Gradio UI is mounted onto it further down.
app = FastAPI()
@app.post("/magic_quill/guess_prompt")
async def guess_prompt(request: Request):
    """Return a guessed prompt for the images posted as JSON.

    The request body must be a JSON object with the keys ``original_image``,
    ``add_color_image`` and ``add_edge_image``; the response is whatever
    ``guess_prompt_handler`` returns for them.
    """
    import asyncio

    data = await request.json()
    # guess_prompt_handler blocks on file I/O and a remote call. Running it in
    # a worker thread keeps the event loop (shared with the Gradio app) free
    # to serve other requests meanwhile.
    res = await asyncio.to_thread(
        guess_prompt_handler,
        data['original_image'],
        data['add_color_image'],
        data['add_edge_image'],
    )
    return res
@app.post("/magic_quill/process_background_img")
async def process_background_img(request: Request):
    """Resize a posted background image and return it as a PNG data URL.

    The request body is a JSON string holding a base64 image data URL. The
    image is resized by ``load_and_resize_image`` with its default settings
    and re-encoded as ``data:image/png;base64,...``.
    """
    import asyncio

    img = await request.json()

    def _resize_and_encode():
        resized_img_tensor = load_and_resize_image(img)
        return "data:image/png;base64," + tensor_to_base64(resized_img_tensor)

    # Decoding, resampling and PNG encoding are synchronous CPU work; do them
    # in a worker thread so the event loop is not blocked.
    resized_img_base64 = await asyncio.to_thread(_resize_and_encode)
    return resized_img_base64
# Mount the Gradio UI at the root path of the FastAPI app, so the UI and the
# /magic_quill/* endpoints are served by the same application object.
app = gr.mount_gradio_app(app, demo, "/")
if __name__ == "__main__":
    # When run as a script, listen on all interfaces on port 7860.
    uvicorn.run(app, host="0.0.0.0", port=7860)