CarolineM5 commited on
Commit
a6e0f3a
·
verified ·
1 Parent(s): 884ceca

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +27 -26
  2. inference.py +2 -2
app.py CHANGED
@@ -44,31 +44,31 @@ class UNetNoCondWrapper(nn.Module):
44
 
45
  # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
46
 
47
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
48
-
49
- model_id = "CarolineM5/InstructPix2Pix_WithoutPrompt_4_faces"
50
- vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
51
- scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
52
- tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
53
- text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
54
- feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")
55
-
56
- # 2) Chargez votre UNet non‑conditionné et wrappez‑le
57
- base_unet = UNet2DModel.from_pretrained(model_id, subfolder="unet").to(device)
58
- wrapped_unet = UNetNoCondWrapper(base_unet).to(device)
59
-
60
- # 3) Construisez la pipeline manuellement
61
- pipe = StableDiffusionInstructPix2PixPipeline(
62
- vae=vae,
63
- text_encoder=text_encoder,
64
- tokenizer=tokenizer,
65
- unet=wrapped_unet,
66
- scheduler=scheduler,
67
- safety_checker=None,
68
- feature_extractor=feature_extractor,
69
- )
70
-
71
- pipe = pipe.to(torch.float32).to(device)
72
  # @spaces.GPU
73
 
74
  def pil_to_data_uri(img: Image.Image) -> str:
@@ -122,7 +122,8 @@ def pil_to_data_uri(img: Image.Image) -> str:
122
  def run(fibers: Image.Image, rings: Image.Image, num_steps: int):
123
  try:
124
  # appelle la fonction d'inference de l'utilisateur
125
- outputs = inference(pipe, fibers, rings, int(num_steps))
 
126
  if not (isinstance(outputs, (list,tuple)) and len(outputs) >= 4):
127
  raise ValueError("La fonction d'inference doit renvoyer une liste/tuple de 4 images.")
128
  # Prendre les 4 premières images
 
44
 
45
  # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
46
 
47
+ # device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
48
+
49
+ # model_id = "CarolineM5/InstructPix2Pix_WithoutPrompt_4_faces"
50
+ # vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae").to(device)
51
+ # scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")
52
+ # tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
53
+ # text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)
54
+ # feature_extractor = CLIPImageProcessor.from_pretrained(model_id, subfolder="feature_extractor")
55
+
56
+ # # 2) Chargez votre UNet non‑conditionné et wrappez‑le
57
+ # base_unet = UNet2DModel.from_pretrained(model_id, subfolder="unet").to(device)
58
+ # wrapped_unet = UNetNoCondWrapper(base_unet).to(device)
59
+
60
+ # # 3) Construisez la pipeline manuellement
61
+ # pipe = StableDiffusionInstructPix2PixPipeline(
62
+ # vae=vae,
63
+ # text_encoder=text_encoder,
64
+ # tokenizer=tokenizer,
65
+ # unet=wrapped_unet,
66
+ # scheduler=scheduler,
67
+ # safety_checker=None,
68
+ # feature_extractor=feature_extractor,
69
+ # )
70
+
71
+ # pipe = pipe.to(torch.float32).to(device)
72
  # @spaces.GPU
73
 
74
  def pil_to_data_uri(img: Image.Image) -> str:
 
122
  def run(fibers: Image.Image, rings: Image.Image, num_steps: int):
123
  try:
124
  # appelle la fonction d'inference de l'utilisateur
125
+ # outputs = inference(pipe, fibers, rings, int(num_steps))
126
+ outputs = inference(fibers, rings, int(num_steps))
127
  if not (isinstance(outputs, (list,tuple)) and len(outputs) >= 4):
128
  raise ValueError("La fonction d'inference doit renvoyer une liste/tuple de 4 images.")
129
  # Prendre les 4 premières images
inference.py CHANGED
@@ -20,8 +20,8 @@ def pil_from(x):
20
  if isinstance(x, str):
21
  return PIL.Image.open(x)
22
  return x
23
-
24
- def inference(pipe, fiber_imgs, ring_imgs, num_steps):
25
  """
26
  fiber_imgs: PIL.Image or paths
27
  ring_imgs: PIL.Image or paths
 
20
  if isinstance(x, str):
21
  return PIL.Image.open(x)
22
  return x
23
+ def inference(fiber_imgs, ring_imgs, num_steps):
24
+ # def inference(pipe, fiber_imgs, ring_imgs, num_steps):
25
  """
26
  fiber_imgs: PIL.Image or paths
27
  ring_imgs: PIL.Image or paths