VcRlAgent commited on
Commit
57407f7
·
1 Parent(s): 24df87c

InstantID For retaining facial identity

Browse files
Files changed (1) hide show
  1. app.py +12 -8
app.py CHANGED
@@ -13,18 +13,21 @@ from rembg import remove
13
  from diffusers import StableDiffusionImg2ImgPipeline
14
  from diffusers import StableDiffusionXLPipeline
15
  import io
16
- import os, sys, subprocess
 
 
 
17
 
18
  # --- Ensure InstantID is available ---
19
  if not os.path.exists("instantid"):
20
  print("🔄 Cloning InstantID repository...")
21
  subprocess.run(["git", "clone", "https://github.com/InstantID/InstantID.git"], check=True)
22
- os.rename("InstantID", "instantid")
23
- sys.path.append(os.path.abspath("instantid"))
24
- else:
25
- sys.path.append(os.path.abspath("instantid"))
26
 
27
- from instantid import InstantID
28
 
29
  import torchvision
30
  print("Printing Torch and TorchVision versions:")
@@ -171,8 +174,9 @@ def create_avatar(img: Image.Image, prompt: str, strength: float, guidance_scale
171
  torch_dtype=torch.float16
172
  ).to(device)
173
 
174
- instantid = InstantID.from_pretrained("InstantID/InstantID")
175
- pipe.load_ip_adapter(instantid)
 
176
 
177
  # --- Step 2: Optimize for ZeroGPU memory ---
178
  pipe.enable_attention_slicing()
 
13
  from diffusers import StableDiffusionImg2ImgPipeline
14
  from diffusers import StableDiffusionXLPipeline
15
  import io
16
+ import os, sys, subprocess, warnings
17
+
18
+ warnings.filterwarnings("ignore", category=UserWarning)
19
+ os.environ["CUDA_VISIBLE_DEVICES"] = ""
20
 
21
  # --- Ensure InstantID is available ---
22
  if not os.path.exists("instantid"):
23
  print("🔄 Cloning InstantID repository...")
24
  subprocess.run(["git", "clone", "https://github.com/InstantID/InstantID.git"], check=True)
25
+ if os.path.exists("InstantID") and not os.path.exists("instantid"):
26
+ os.rename("InstantID", "instantid")
27
+
28
+ sys.path.append(os.path.abspath("instantid"))
29
 
30
+ from pipelines.pipeline_instantid import InstantIDPipeline
31
 
32
  import torchvision
33
  print("Printing Torch and TorchVision versions:")
 
174
  torch_dtype=torch.float16
175
  ).to(device)
176
 
177
+ instantid = InstantIDPipeline.from_pretrained("InstantID/InstantID", torch_dtype=torch.float16,)
178
+ pipe.to("cuda" if torch.cuda.is_available() else "cpu")
179
+ #pipe.load_ip_adapter(instantid)
180
 
181
  # --- Step 2: Optimize for ZeroGPU memory ---
182
  pipe.enable_attention_slicing()