AkashKumarave committed on
Commit
9740894
·
verified ·
1 Parent(s): 8ffd41f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -154
app.py CHANGED
@@ -1,162 +1,51 @@
1
- import cv2
2
- import torch
3
- import numpy as np
4
  import gradio as gr
5
- from diffusers import StableDiffusionPipeline # Use SD 2.1 instead of SDXL
6
- from insightface.app import FaceAnalysis
7
- from huggingface_hub import hf_hub_download
8
- import os
9
- import logging
10
- import time
11
-
12
# Emit INFO-level (and above) records so start-up progress is visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Explicitly switch the Hugging Face Hub offline flag off.
os.environ["HF_HUB_OFFLINE"] = "0"

# Everything in this script runs on the CPU in float32.
device = "cpu"
dtype = torch.float32

# Directory handed to the Hub/diffusers loaders as ``cache_dir``.
cache_dir = "./cache"
os.makedirs(cache_dir, exist_ok=True)

# Build the InsightFace analyser on the CPU execution provider. A failure is
# logged and re-raised, so the app does not start without a face encoder.
logger.info("Starting InsightFace initialization...")
try:
    face_app = FaceAnalysis(providers=["CPUExecutionProvider"])
    face_app.prepare(ctx_id=0, det_size=(480, 480))
except Exception as e:
    logger.error(f"Failed to load InsightFace model: {e}")
    raise
else:
    logger.info("InsightFace model loaded successfully.")
36
-
37
# Download function with retry logic
def download_file(repo_id, filename, local_dir, max_retries=3):
    """Return a local path for ``filename``, downloading it from the Hub if needed.

    If ``local_dir/filename`` already exists, that path is returned without
    any network access. Otherwise ``hf_hub_download`` is attempted up to
    ``max_retries`` times, sleeping five seconds between failed attempts.

    Args:
        repo_id: Hub repository that hosts the file.
        filename: Name of the file inside the repository.
        local_dir: Directory the file is looked up in and downloaded to.
        max_retries: Maximum number of download attempts; must be at least 1
            when a download is actually required.

    Returns:
        The path of the existing file, or the path returned by
        ``hf_hub_download`` for a fresh download.

    Raises:
        ValueError: If a download is required and ``max_retries`` is below 1.
        RuntimeError: If every attempt fails; the last error is chained as
            the cause.
    """
    file_path = os.path.join(local_dir, filename)
    if os.path.exists(file_path):
        logger.info(f"Using cached file at {file_path}")
        return file_path

    # Without this guard the retry loop would not run at all and the function
    # would fall through and return None instead of a path.
    if max_retries < 1:
        raise ValueError(f"max_retries must be >= 1, got {max_retries}")

    for attempt in range(max_retries):
        logger.info(f"Attempt {attempt + 1}/{max_retries}: Downloading {filename} from {repo_id} to {local_dir}...")
        try:
            downloaded_path = hf_hub_download(
                repo_id=repo_id,
                filename=filename,
                local_dir=local_dir,
                cache_dir=cache_dir,
                local_files_only=False,
            )
        except Exception as e:
            logger.error(f"Download attempt {attempt + 1} failed: {e}")
            if attempt == max_retries - 1:
                # Out of attempts: surface one error that keeps the root cause.
                raise RuntimeError(f"Failed to download {filename} after {max_retries} attempts: {e}") from e
            logger.info("Retrying in 5 seconds...")
            time.sleep(5)
        else:
            logger.info(f"Downloaded to {downloaded_path}")
            return downloaded_path
63
-
64
# The IP-Adapter checkpoint is stored in the current working directory.
ip_adapter_path = "./"
os.makedirs(ip_adapter_path, exist_ok=True)

# Fetch the IP-Adapter weights (download_file performs its own retries).
logger.info("Starting weights download...")
ip_adapter_weights = download_file(
    repo_id="Kwai-Kolors/Kolors-IP-Adapter-FaceID-Plus",
    filename="ipa-faceid-plus.bin",
    local_dir=ip_adapter_path,
)

# Load the Stable Diffusion 2.1 pipeline, retrying up to three times with a
# five-second pause between attempts. The outer handler logs the final
# failure once more before letting it propagate.
logger.info("Loading Stable Diffusion 2.1 base model...")
try:
    max_retries = 3
    for attempt in range(max_retries):
        try:
            logger.info(f"Attempt {attempt + 1}/{max_retries}: Loading SD 2.1 model...")
            pipe = StableDiffusionPipeline.from_pretrained(
                "stabilityai/stable-diffusion-2-1",
                torch_dtype=dtype,
                safety_checker=None,
                local_files_only=False,
                cache_dir=cache_dir,
                variant="fp16",
                use_safetensors=True,
            )
            logger.info("SD 2.1 base model loaded successfully.")
            break
        except Exception as e:
            logger.error(f"Load attempt {attempt + 1} failed: {e}")
            if attempt >= max_retries - 1:
                raise RuntimeError(f"Failed to load SD 2.1 model after {max_retries} attempts: {e}")
            logger.info("Retrying in 5 seconds...")
            time.sleep(5)
except Exception as e:
    logger.error(f"Failed to load SD 2.1 base model: {e}")
    raise

# Attach the IP-Adapter weights to the pipeline; abort start-up on failure.
logger.info(f"Loading IP-Adapter from {ip_adapter_weights}...")
try:
    pipe.load_ip_adapter(ip_adapter_path, subfolder=None, weight_name="ipa-faceid-plus.bin")
except Exception as e:
    logger.error(f"Failed to load IP-Adapter: {e}")
    raise
else:
    logger.info("IP-Adapter loaded successfully.")

# Place the whole pipeline on the configured device (CPU).
logger.info("Moving pipeline to CPU...")
pipe.to(device)
logger.info("Pipeline moved to CPU.")
118
-
119
def generate_image(uploaded_image, prompt):
    """Generate an image from ``prompt`` using a face found in ``uploaded_image``.

    Args:
        uploaded_image: Reference picture from the UI; it is converted with
            ``np.array`` and treated as RGB.
        prompt: Text prompt forwarded to the diffusion pipeline.

    Returns:
        A ``(status_message, image)`` tuple. ``image`` is ``None`` when no
        face is detected or when any step raises; in the latter case the
        status message carries the error text.
    """
    logger.info("Starting image generation...")
    try:
        # The face analyser is fed BGR pixels, hence the RGB -> BGR swap.
        bgr_pixels = cv2.cvtColor(np.array(uploaded_image), cv2.COLOR_RGB2BGR)
        detected_faces = face_app.get(bgr_pixels)
        if not detected_faces:
            logger.warning("No face detected in uploaded image.")
            return "No face detected!", None

        # The last detected face supplies the identity embedding.
        embedding = detected_faces[-1]["embedding"]

        logger.info(f"Generating image with prompt: {prompt}")
        # NOTE(review): confirm that the pipeline actually consumes the
        # ``image_embeds`` keyword; it is passed through unchanged here.
        result = pipe(
            prompt=prompt,
            image_embeds=embedding,
            num_inference_steps=10,
            guidance_scale=7.5,
            height=256,
            width=256,
        )
        picture = result.images[0]
        logger.info("Image generated successfully.")
        return "Image generated successfully!", picture
    except Exception as e:
        # Report the failure in the status box instead of crashing the UI.
        logger.error(f"Generation failed: {e}")
        return f"Generation failed: {e}", None
145
-
146
# Gradio UI: a reference image plus a prompt go in; a status line and the
# generated picture come out.
interface = gr.Interface(
    fn=generate_image,
    inputs=[
        gr.Image(type="pil", label="Upload Reference Image"),
        gr.Textbox(
            label="Enter Prompt",
            placeholder="e.g., A photorealistic astronaut in space",
        ),
    ],
    outputs=[
        gr.Textbox(label="Status"),
        gr.Image(label="Generated Image"),
    ],
    title="Face Reference Image Generator",
    description="Upload an image with a face and generate a new image.",
)

# Start serving the UI.
logger.info("Launching Gradio interface...")
interface.launch()
 
 
 
 
1
  import gradio as gr
2
+ from diffusers import StableDiffusionImg2ImgPipeline
3
+ import torch
4
+ from PIL import Image
5
+ from codeformer_app import CodeFormerFaceRestoration
6
+
7
# Load models
# Use the GPU in half precision when CUDA is available; otherwise fall back to
# the CPU in float32 so the app still starts on CPU-only hardware instead of
# crashing in ``.to("cuda")``.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=dtype,
    use_safetensors=True,
).to(device)

# Face-restoration helper imported from the project-local codeformer_app module.
codeformer = CodeFormerFaceRestoration()
15
+
16
# Define the image-to-image function
def generate_image(input_image, prompt, strength, fidelity):
    """Run img2img on ``input_image`` and pass the result through face restoration.

    Args:
        input_image: The uploaded picture, either an array accepted by
            ``Image.fromarray`` or an already-built ``PIL.Image.Image``.
        prompt: Text prompt forwarded to the diffusion pipeline.
        strength: Value forwarded as the pipeline's ``strength`` argument.
        fidelity: Value forwarded as ``fidelity`` to ``codeformer.restore``.

    Returns:
        Whatever ``codeformer.restore`` returns for the generated image.

    Raises:
        gr.Error: If no image was supplied.
    """
    if input_image is None:
        # An empty upload would otherwise fail inside Image.fromarray with an
        # opaque error; tell the user what is missing instead.
        raise gr.Error("Please upload an image before generating.")

    # Preprocess the input image: force RGB and the fixed 512x512 working size.
    if isinstance(input_image, Image.Image):
        source = input_image
    else:
        source = Image.fromarray(input_image)
    init_image = source.convert("RGB").resize((512, 512))

    # Generate the image
    generated_image = pipe(
        prompt=prompt,
        image=init_image,
        strength=strength,
        guidance_scale=7.5,
        num_inference_steps=50,
    ).images[0]

    # Restore the face
    return codeformer.restore(generated_image, fidelity=fidelity)
35
+
36
# Create the Gradio interface: image + prompt + two sliders in, one image out.
interface = gr.Interface(
    fn=generate_image,
    inputs=[
        # Image upload input
        gr.Image(label="Upload Your Image"),
        # Text input for the prompt
        gr.Textbox(label="Prompt"),
        # Strength slider
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.5,
            label="Strength (Lower = More Preservation)",
        ),
        # Fidelity slider
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.8,
            label="Fidelity (Higher = More Preservation)",
        ),
    ],
    # Output image
    outputs=gr.Image(label="Generated Image"),
    title="Image-to-Image with Face Preservation",
    description="Upload an image, enter a prompt, and generate a new image while preserving the face.",
)

# Launch the app
interface.launch()