OsamaAbdeljaber CarolineM5 committed
Commit 65e8581 · verified · 1 Parent(s): 0482213

Upload inference.py (#2)


- Upload inference.py (0352db982029a1d866eb5119cee262074b61fee9)


Co-authored-by: Caroline Marc <CarolineM5@users.noreply.huggingface.co>

Files changed (1)
1. inference.py +107 -71
inference.py CHANGED
@@ -13,79 +13,115 @@ import torch.nn as nn
  from transformers import CLIPTokenizer, CLIPTextModel, CLIPImageProcessor
  from PIL import Image
  import random
+ from contextlib import nullcontext

- class UNetNoCondWrapper(nn.Module):
-     def __init__(self, base_unet: UNet2DModel):
-         super().__init__()
-         self.unet = base_unet
-
-     def forward(
-         self,
-         sample,
-         timestep,
-         encoder_hidden_states=None,
-         added_cond_kwargs=None,
-         cross_attention_kwargs=None,
-         return_dict=False,
-         **kwargs
-     ):
-         return self.unet(sample, timestep, return_dict=return_dict, **kwargs)
-
-     def __getattr__(self, name):
-         if name in ("unet", "forward", "__getstate__", "__setstate__"):
-             return super().__getattr__(name)
-         return getattr(self.unet, name)
-
-     def save_pretrained(self, save_directory, **kwargs):
-         # delegates to the real UNet2DModel instance
-         return self.unet.save_pretrained(save_directory, **kwargs)
-
- def inference(pipe, img1, img2, num_steps):
-
+ def pil_from(x):
+     """Return a PIL.Image given either a PIL.Image or a path string."""
+     if isinstance(x, str):
+         return PIL.Image.open(x)
+     return x
+
+ def inference(pipe, fiber_imgs, ring_imgs, num_steps):
+     """
+     fiber_imgs: list/tuple of 4 PIL.Image or paths (order: TL, TR, BL, BR)
+     ring_imgs: list/tuple of 4 PIL.Image or paths (same order)
+     num_steps: int (number of inference steps)
+
+     returns: list of 4 PIL.Image (L mode), order [TL, TR, BL, BR]
+     """
+     # seed + generator
      seed = random.randrange(0, 2**32)
      torch.manual_seed(seed)
-
-     generator = torch.Generator("cpu").manual_seed(seed)
-
-     img1 = img1.resize((512, 512))
-     img2 = img2.resize((512, 512))
-
-     img1_np = np.array(img1)
-     if len(img1_np.shape) > 2:
-         img1_np = img1_np[:, :, 0]
-
-     img2_np = np.array(img2)
-     if len(img2_np.shape) > 2:
-         img2_np = img2_np[:, :, 0]
-
-     img1_np[img1_np > 200] = 255
-     img1_np[img1_np <= 200] = 0
-     img1_np = 255 - img1_np
-     img_np = np.stack([img1_np, img2_np, img2_np], axis=2)
-
-     image = PIL.Image.fromarray(img_np)
-     image = PIL.ImageOps.exif_transpose(image)
-
-     all_images = []
-
-     num_inference_steps = num_steps
-     image_guidance_scale = 1.9
-     guidance_scale = 10
-
-     edited_image = pipe(
-         prompt=[""],
-         image=image,
-         num_inference_steps=num_inference_steps,
-         image_guidance_scale=image_guidance_scale,
-         guidance_scale=guidance_scale,
-         generator=generator,
-         safety_checker=None,
-         num_images_per_prompt=1
-     ).images
-
-     edited_image = edited_image[0].convert("L")
-
-     return edited_image
+     generator = torch.Generator("cpu").manual_seed(seed)
+
+     # sizes
+     tile = 512
+     canvas_size = tile * 2  # 1024
+
+     # normalize/validate inputs: accept lists or separate args
+     if not (isinstance(fiber_imgs, (list, tuple)) and len(fiber_imgs) == 4):
+         raise ValueError("fiber_imgs must be a list/tuple of 4 PIL images or file paths.")
+     if not (isinstance(ring_imgs, (list, tuple)) and len(ring_imgs) == 4):
+         raise ValueError("ring_imgs must be a list/tuple of 4 PIL images or file paths.")
+
+     # load & preprocess each face
+     faces_f = []
+     faces_r = []
+     for fpath in fiber_imgs:
+         im = pil_from(fpath).convert("L").resize((tile, tile), PIL.Image.BILINEAR)
+         faces_f.append(im)
+     for rpath in ring_imgs:
+         im = pil_from(rpath).convert("L").resize((tile, tile), PIL.Image.BILINEAR)
+         # binarize like in your old code
+         arr = np.array(im)
+         arr[arr > 200] = 255
+         arr[arr <= 200] = 0
+         im_bin = PIL.Image.fromarray(arr.astype(np.uint8))
+         faces_r.append(im_bin)
+
+     # build canvases (L mode)
+     canvas_f = PIL.Image.new("L", (canvas_size, canvas_size))
+     canvas_r = PIL.Image.new("L", (canvas_size, canvas_size))
+
+     # paste into corners: order = [TL, TR, BL, BR]
+     canvas_f.paste(faces_f[0], (0, 0))        # TL
+     canvas_f.paste(faces_f[1], (tile, 0))     # TR
+     canvas_f.paste(faces_f[2], (0, tile))     # BL
+     canvas_f.paste(faces_f[3], (tile, tile))  # BR
+
+     canvas_r.paste(faces_r[0], (0, 0))
+     canvas_r.paste(faces_r[1], (tile, 0))
+     canvas_r.paste(faces_r[2], (0, tile))
+     canvas_r.paste(faces_r[3], (tile, tile))
+
+     # stack channels: [fiber, ring, ring] -> H,W,3
+     arr_f = np.array(canvas_f).astype(np.uint8)
+     arr_r = np.array(canvas_r).astype(np.uint8)
+     arr_in = np.stack([arr_f, arr_r, arr_r], axis=2)  # H,W,3
+     input_image = PIL.Image.fromarray(arr_in)  # PIL RGB
+
+     # run pipeline (use autocast consistent with device)
+     edited_images = []
+     if torch.backends.mps.is_available():
+         autocast_ctx = nullcontext()
+     else:
+         autocast_ctx = torch.autocast(torch.device("cuda").type if torch.cuda.is_available() else "cpu")
+
+     with autocast_ctx:
+         out = pipe(
+             prompt="",  # empty prompt (your model ignores prompt)
+             image=input_image,
+             num_inference_steps=num_steps,
+             image_guidance_scale=1.9,
+             guidance_scale=10.0,
+             generator=generator,
+             safety_checker=None,
+             num_images_per_prompt=1,
+         )
+     # out.images may be a list; take first
+     pred = out.images[0]
+
+     # ensure pred is canvas_size x canvas_size
+     if pred.size != (canvas_size, canvas_size):
+         pred = pred.resize((canvas_size, canvas_size), PIL.Image.BILINEAR)
+
+     # split into 4 tiles in same order TL, TR, BL, BR
+     tl = pred.crop((0, 0, tile, tile))
+     tr = pred.crop((tile, 0, canvas_size, tile))
+     bl = pred.crop((0, tile, tile, canvas_size))
+     br = pred.crop((tile, tile, canvas_size, canvas_size))
+
+     # close opened images to free handles
+     for im in faces_f + faces_r:
+         try:
+             im.close()
+         except Exception:
+             pass
+     try:
+         canvas_f.close(); canvas_r.close()
+     except Exception:
+         pass
+
+     return [tl, tr, bl, br]
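
For reference, a minimal sketch of how the new inference() entry point might be driven end to end. This is not part of the commit: the repo id, tile file names, and step count are placeholders, and the pipeline class is an assumption; the call signature used above (image, image_guidance_scale, guidance_scale) matches diffusers' StableDiffusionInstructPix2PixPipeline, so that class is used here.

# Usage sketch (assumptions: InstructPix2Pix-style checkpoint; placeholder repo id and file names).
import torch
from diffusers import StableDiffusionInstructPix2PixPipeline
from inference import inference

pipe = StableDiffusionInstructPix2PixPipeline.from_pretrained(
    "user/model-id",  # placeholder repo id
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
    safety_checker=None,
)
pipe.to("cuda" if torch.cuda.is_available() else "cpu")

# pil_from() accepts paths, so plain file names work here; [TL, TR, BL, BR] order.
fiber_imgs = ["fiber_tl.png", "fiber_tr.png", "fiber_bl.png", "fiber_br.png"]
ring_imgs = ["ring_tl.png", "ring_tr.png", "ring_bl.png", "ring_br.png"]

tiles = inference(pipe, fiber_imgs, ring_imgs, num_steps=50)  # returns [TL, TR, BL, BR]
for name, tile in zip(["tl", "tr", "bl", "br"], tiles):
    tile.save(f"pred_{name}.png")

Note that inference() draws its own seed via random.randrange, so repeated calls sample different results by design; calling random.seed(...) beforehand pins the draw if reproducibility is needed.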