cesar-tek committed on
Commit e242eca · verified · 1 Parent(s): 709dd18

Upload folder using huggingface_hub

Files changed (1)
  1. app.py +49 -23
app.py CHANGED
@@ -41,27 +41,56 @@ pipe = pipe.to(device)
 print("Model loaded successfully!")
 
 
-def preprocess_image(image: Image.Image) -> tuple[Image.Image, tuple[int, int]]:
-    """Pad image to multiple of 64 for SD3 compatibility."""
+def preprocess_image(image: Image.Image, max_size: int = 1536) -> tuple[Image.Image, tuple[int, int], float]:
+    """
+    Resize large images and pad to multiple of 64 for SD3 compatibility.
+
+    Returns:
+        Tuple of (processed_image, original_size, scale_factor)
+    """
     original_size = image.size
     w, h = original_size
 
-    # SD3 works best with dimensions that are multiples of 64
-    new_w = (w + 63) // 64 * 64
-    new_h = (h + 63) // 64 * 64
+    # Calculate scale factor if image is too large
+    scale_factor = 1.0
+    if max(w, h) > max_size:
+        scale_factor = max_size / max(w, h)
+        new_w = int(w * scale_factor)
+        new_h = int(h * scale_factor)
+        image = image.resize((new_w, new_h), Image.Resampling.LANCZOS)
+        print(f"Resized from {w}x{h} to {new_w}x{new_h} (scale: {scale_factor:.3f})")
+        w, h = new_w, new_h
 
-    if (new_w, new_h) != (w, h):
-        padded_img = Image.new('RGB', (new_w, new_h), (0, 0, 0))
+    # Pad to multiple of 64
+    pad_w = (w + 63) // 64 * 64
+    pad_h = (h + 63) // 64 * 64
+
+    if (pad_w, pad_h) != (w, h):
+        padded_img = Image.new('RGB', (pad_w, pad_h), (0, 0, 0))
         padded_img.paste(image, (0, 0))
-        return padded_img, original_size
+        return padded_img, original_size, scale_factor
 
-    return image, original_size
+    return image, original_size, scale_factor
 
 
-def postprocess_image(image: Image.Image, original_size: tuple[int, int]) -> Image.Image:
-    """Crop image back to original size."""
-    if image.size != original_size:
-        return image.crop((0, 0, original_size[0], original_size[1]))
+def postprocess_image(image: Image.Image, original_size: tuple[int, int], scale_factor: float) -> Image.Image:
+    """Crop padding and resize back to original dimensions."""
+    w, h = image.size
+    original_w, original_h = original_size
+
+    # First crop to the scaled size (remove padding)
+    if scale_factor < 1.0:
+        scaled_w = int(original_w * scale_factor)
+        scaled_h = int(original_h * scale_factor)
+        image = image.crop((0, 0, scaled_w, scaled_h))
+        # Then resize back to original
+        image = image.resize((original_w, original_h), Image.Resampling.LANCZOS)
+        print(f"Upscaled back to original size: {original_w}x{original_h}")
+    else:
+        # Just crop to original size
+        if image.size != original_size:
+            image = image.crop((0, 0, original_w, original_h))
+
     return image
 
 
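Review note: a quick way to sanity-check the two helpers above is a size round trip. This is a minimal sketch with hypothetical sizes chosen to exercise both code paths; it assumes preprocess_image and postprocess_image are already in scope (importing app.py directly would also trigger the model load at the top of the file).

from PIL import Image

# Padding path: 1000x700 is under max_size, so only padding applies.
img = Image.new('RGB', (1000, 700), (128, 128, 128))
padded, original_size, scale = preprocess_image(img)
assert padded.size == (1024, 704)   # (1000+63)//64*64 = 1024, (700+63)//64*64 = 704
assert scale == 1.0
assert postprocess_image(padded, original_size, scale).size == (1000, 700)

# Downscale path: 2000x1500 exceeds max_size=1536, so it is resized first.
big = Image.new('RGB', (2000, 1500), (128, 128, 128))
small, original_size, scale = preprocess_image(big)
assert small.size == (1536, 1152)   # scaled by 1536/2000; already multiples of 64
assert postprocess_image(small, original_size, scale).size == (2000, 1500)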
@@ -104,12 +133,12 @@ def remove_watermark(
 
     generator = torch.Generator(device=device).manual_seed(seed)
 
-    # Preprocess image - pad to multiple of 64
-    processed_image, _ = preprocess_image(input_image)
+    # Preprocess image - resize if too large and pad to multiple of 64
+    processed_image, original_size, scale_factor = preprocess_image(input_image)
     padded_w, padded_h = processed_image.size
-    print(f"Padded image size: {padded_w}x{padded_h}")
+    print(f"Processed image size: {padded_w}x{padded_h} (scale: {scale_factor:.3f})")
 
-    # Run regeneration with explicit dimensions
+    # Run regeneration
     result = pipe(
         prompt="",  # Empty prompt for pure regeneration
         image=processed_image,
@@ -117,16 +146,13 @@
         num_inference_steps=num_inference_steps,
         guidance_scale=0.0,  # No guidance for pure regeneration
         generator=generator,
-        width=padded_w,
-        height=padded_h,
     ).images[0]
 
     print(f"Pipeline output size: {result.size}")
 
-    # Crop back to ORIGINAL size (not padded size)
-    if result.size != (original_w, original_h):
-        result = result.crop((0, 0, original_w, original_h))
-        print(f"Cropped to original size: {result.size}")
+    # Postprocess - crop padding and resize back to original
+    result = postprocess_image(result, original_size, scale_factor)
+    print(f"Final output size: {result.size}")
 
     return result, seed
 
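Review note: the context lines above keep the seeded torch.Generator that the regeneration call consumes; seeding is what makes a rerun with the same seed reproducible. A minimal illustration in plain PyTorch, independent of the pipeline:

import torch

# Same seed and device -> identical noise draws -> identical regeneration.
g1 = torch.Generator(device="cpu").manual_seed(1234)
g2 = torch.Generator(device="cpu").manual_seed(1234)
assert torch.equal(torch.randn(4, generator=g1), torch.randn(4, generator=g2))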
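Taken together, the new flow in remove_watermark is: preprocess (downscale if needed, then pad), regenerate with an empty prompt at guidance_scale=0.0, then postprocess (crop the padding, upscale back to the original size). The sketch below shows that flow end to end, assuming the app drives a diffusers StableDiffusion3Img2ImgPipeline; the checkpoint name, strength, and step count are illustrative assumptions, not values from this commit.

import torch
from diffusers import StableDiffusion3Img2ImgPipeline
from PIL import Image

device = "cuda" if torch.cuda.is_available() else "cpu"
pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # assumed checkpoint
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
).to(device)

input_image = Image.open("input.png").convert("RGB")  # hypothetical input
processed_image, original_size, scale_factor = preprocess_image(input_image)

generator = torch.Generator(device=device).manual_seed(42)
result = pipe(
    prompt="",               # empty prompt for pure regeneration
    image=processed_image,
    strength=0.3,            # assumed; the real value sits outside these hunks
    num_inference_steps=28,  # assumed
    guidance_scale=0.0,      # no guidance, as in the diff
    generator=generator,
).images[0]

result = postprocess_image(result, original_size, scale_factor)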