AkashKumarave committed on
Commit
9de5dbc
·
verified ·
1 Parent(s): 8cf7a41

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -0
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import torch
3
+ import numpy as np
4
+ import gradio as gr
5
+ from diffusers import StableDiffusionXLPipeline, ControlNetModel
6
+ from insightface.app import FaceAnalysis
7
+ from accelerate import init_empty_weights, load_checkpoint_and_dispatch
8
+ import os
9
+
10
# Force offline mode to avoid runtime Hub connections.
# NOTE(review): huggingface_hub appears to read this variable when it is first
# imported, and the diffusers import above already ran -- confirm the flag
# still takes effect here, or set it before the third-party imports.
os.environ["HF_HUB_OFFLINE"] = "1"

# Set device to CPU (free tier has no GPU)
device = "cpu"
dtype = torch.float32

# Load face encoder
face_app = FaceAnalysis(providers=["CPUExecutionProvider"])
face_app.prepare(ctx_id=0, det_size=(480, 480))

# Define paths for preloaded weights
controlnet_path = "./ControlNetModel"
face_adapter_path = "./ip-adapter.bin"

# Fail fast with a clear message when the preloaded weights are missing.
# Checking for config.json also covers the case of a missing directory.
if not os.path.exists(os.path.join(controlnet_path, "config.json")):
    raise FileNotFoundError(f"ControlNetModel directory or config.json not found at {controlnet_path}")
if not os.path.exists(face_adapter_path):
    raise FileNotFoundError(f"ip-adapter.bin not found at {face_adapter_path}")

# Load the weights directly with from_pretrained.
# The previous init_empty_weights() + load_checkpoint_and_dispatch() sequence
# could not work: accelerate only accepts "auto", "balanced", "balanced_low_0"
# or "sequential" as a string device_map (not "cpu"), and it operates on a
# torch module rather than on a diffusers pipeline object.
controlnet = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=dtype)

# NOTE(review): `controlnet=` and `load_ip_adapter_instantid` look like they
# belong to InstantID's custom SDXL pipeline class rather than the stock
# StableDiffusionXLPipeline -- confirm which pipeline class is intended.
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=dtype,
    safety_checker=None,
)
pipe = pipe.to(device)
pipe.load_ip_adapter_instantid(face_adapter_path)
45
+
46
def generate_image(uploaded_image, prompt):
    """Generate an image from *prompt* using the face found in *uploaded_image*.

    Args:
        uploaded_image: RGB reference image (the Gradio input is configured
            with ``type="pil"``), or ``None`` when nothing was uploaded.
        prompt: Text prompt forwarded to the diffusion pipeline.

    Returns:
        A ``(status, image)`` tuple. ``image`` is ``None`` whenever nothing
        was generated, in which case ``status`` says why.
    """
    # Submitting the form without an upload yields None; report it instead
    # of letting the colour conversion below raise.
    if uploaded_image is None:
        return "No image provided!", None

    # Convert the RGB input to OpenCV's BGR channel order for face detection.
    bgr = cv2.cvtColor(np.array(uploaded_image), cv2.COLOR_RGB2BGR)
    faces = face_app.get(bgr)
    if not faces:
        return "No face detected!", None

    def bbox_area(face):
        # bbox is (x1, y1, x2, y2) in pixel coordinates.
        x1, y1, x2, y2 = face["bbox"][:4]
        return (x2 - x1) * (y2 - y1)

    # The detection list is not ordered by size, so taking the last entry
    # does not give the largest face; select by bounding-box area instead.
    face_info = max(faces, key=bbox_area)
    face_emb = face_info["embedding"]

    # UI boundary: surface any pipeline failure as a status message rather
    # than crashing the request.
    try:
        result = pipe(
            prompt=prompt,
            image_embeds=face_emb,
            num_inference_steps=20,
            guidance_scale=7.5,
            height=512,
            width=512,
            controlnet_conditioning_scale=1.0,
        )
    except Exception as e:
        return f"Generation failed: {e}", None
    return "Image generated successfully!", result.images[0]
69
+
70
# Gradio interface: a reference image and a prompt go in; a status message
# and the generated image come out.
_reference_input = gr.Image(type="pil", label="Upload Reference Image")
_prompt_input = gr.Textbox(
    label="Enter Prompt",
    placeholder="e.g., A photorealistic astronaut in space",
)
_status_output = gr.Textbox(label="Status")
_image_output = gr.Image(label="Generated Image")

interface = gr.Interface(
    fn=generate_image,
    inputs=[_reference_input, _prompt_input],
    outputs=[_status_output, _image_output],
    title="Face Reference Image Generator",
    description="Upload an image with a face, enter a prompt, and generate a new image preserving the reference face.",
)

interface.launch()