karthikeya1212 commited on
Commit
7bd9d8c
·
verified ·
1 Parent(s): 202da62

Update shadow_generator.py

Browse files
Files changed (1) hide show
  1. shadow_generator.py +8 -35
shadow_generator.py CHANGED
@@ -229,7 +229,6 @@ from PIL import Image, ImageFilter, ImageChops
229
  import torch
230
  from torchvision import transforms as T
231
  import numpy as np
232
- import cv2
233
 
234
  # -----------------------------
235
  # Paths & constants
@@ -345,8 +344,7 @@ class SSNWrapper:
345
  try:
346
  from SSN_Model import SSN_Model
347
  self.model = SSN_Model()
348
- with torch.serialization.safe_globals([np.core.multiarray.scalar]):
349
- state = torch.load(str(WEIGHT_FILE), map_location=self.device, weights_only=False)
350
  if isinstance(state, dict):
351
  if "model_state_dict" in state:
352
  sd = state["model_state_dict"]
@@ -407,7 +405,7 @@ def load_ssn_once() -> SSNWrapper:
407
  return _ssn_wrapper
408
 
409
  # -----------------------------
410
- # Shadow compositing (Attached Shadow)
411
  # -----------------------------
412
  def _hex_to_rgb(color: str) -> Tuple[int, int, int]:
413
  c = color.strip().lstrip("#")
@@ -417,32 +415,15 @@ def _hex_to_rgb(color: str) -> Tuple[int, int, int]:
417
  raise ValueError("Invalid color hex.")
418
  return tuple(int(c[i:i+2],16) for i in (0,2,4))
419
 
420
- def _warp_shadow(matte_img: Image.Image, w: int, h: int, direction: float, distance: float) -> Image.Image:
421
- matte = np.array(matte_img.resize((w, h)).convert("L"))
422
-
423
- angle = np.deg2rad(direction)
424
- dx = int(np.cos(angle) * distance)
425
- dy = int(np.sin(angle) * distance)
426
-
427
- src = np.float32([[0,0],[w,0],[0,h],[w,h]])
428
- dst = np.float32([[0,0],[w,0],[0+dx,h],[w+dx,h+dy]])
429
-
430
- M = cv2.getPerspectiveTransform(src, dst)
431
- warped = cv2.warpPerspective(matte, M, (w,h), borderValue=0)
432
-
433
- return Image.fromarray(warped)
434
-
435
  def _composite_shadow_and_image(original_img: Image.Image, matte_img: Image.Image, params: Dict[str, Any]) -> Image.Image:
436
- _log("Compositing shadow and image (anchored)...")
437
  w, h = original_img.size
438
  r, g, b = _hex_to_rgb(params.get("color", "#000000"))
439
 
440
- matte_warped = _warp_shadow(matte_img, w, h, params["direction"], params["distance"])
441
- obj_alpha = original_img.getchannel("A")
442
- matte_clean = ImageChops.multiply(matte_warped, obj_alpha)
443
 
444
  shadow_rgba = Image.new("RGBA", (w, h), (r, g, b, 0))
445
- alpha = matte_clean.point(lambda p: int(p * float(params.get("opacity", 0.7))))
446
  shadow_rgba.putalpha(alpha)
447
 
448
  softness = max(0.0, float(params.get("softness", 20.0)))
@@ -453,22 +434,14 @@ def _composite_shadow_and_image(original_img: Image.Image, matte_img: Image.Imag
453
  final_rgba.alpha_composite(shadow_rgba)
454
  final_rgba.alpha_composite(original_img)
455
 
456
- # --- NEW: Option to flatten background ---
457
- if params.get("background", "transparent") != "transparent":
458
- bg_color = params.get("background", "#FFFFFF")
459
- br, bgc, bb = _hex_to_rgb(bg_color)
460
- flattened = Image.new("RGB", (w, h), (br, bgc, bb))
461
- flattened.paste(final_rgba, mask=final_rgba.getchannel("A"))
462
- return flattened.convert("RGBA")
463
-
464
- _log("Composition complete with anchored shadow.")
465
  return final_rgba
466
 
467
  # -----------------------------
468
  # Public API
469
  # -----------------------------
470
  def _apply_params_defaults(params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
471
- defaults = dict(softness=20.0, opacity=0.7, color="#000000", direction=45.0, distance=80.0, background="transparent")
472
  merged = {**defaults, **(params or {})}
473
  merged["opacity"] = max(0.0, min(1.0, float(merged["opacity"])))
474
  merged["softness"] = max(0.0, float(merged["softness"]))
@@ -487,10 +460,10 @@ def generate_shadow_rgba(rgba_file_bytes: bytes, params: Optional[Dict[str, Any]
487
  raise RuntimeError("SSN model not available.")
488
 
489
  matte_img = ssn.infer_shadow_matte(img, params["direction"])
490
-
491
  if matte_img is None:
492
  _log("❌ Failed to generate shadow matte")
493
  raise RuntimeError("Failed to generate shadow matte.")
 
494
  final_img = _composite_shadow_and_image(img, matte_img, params)
495
  buf = io.BytesIO()
496
  final_img.save(buf, format="PNG")
 
229
  import torch
230
  from torchvision import transforms as T
231
  import numpy as np
 
232
 
233
  # -----------------------------
234
  # Paths & constants
 
344
  try:
345
  from SSN_Model import SSN_Model
346
  self.model = SSN_Model()
347
+ state = torch.load(str(WEIGHT_FILE), map_location=self.device)
 
348
  if isinstance(state, dict):
349
  if "model_state_dict" in state:
350
  sd = state["model_state_dict"]
 
405
  return _ssn_wrapper
406
 
407
  # -----------------------------
408
+ # Shadow compositing (REALISTIC SSN)
409
  # -----------------------------
410
  def _hex_to_rgb(color: str) -> Tuple[int, int, int]:
411
  c = color.strip().lstrip("#")
 
415
  raise ValueError("Invalid color hex.")
416
  return tuple(int(c[i:i+2],16) for i in (0,2,4))
417
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
418
  def _composite_shadow_and_image(original_img: Image.Image, matte_img: Image.Image, params: Dict[str, Any]) -> Image.Image:
419
+ _log("Compositing shadow and image (REAL SSN)...")
420
  w, h = original_img.size
421
  r, g, b = _hex_to_rgb(params.get("color", "#000000"))
422
 
423
+ matte_resized = matte_img.resize((w, h)).convert("L")
 
 
424
 
425
  shadow_rgba = Image.new("RGBA", (w, h), (r, g, b, 0))
426
+ alpha = matte_resized.point(lambda p: int(p * float(params.get("opacity", 0.7))))
427
  shadow_rgba.putalpha(alpha)
428
 
429
  softness = max(0.0, float(params.get("softness", 20.0)))
 
434
  final_rgba.alpha_composite(shadow_rgba)
435
  final_rgba.alpha_composite(original_img)
436
 
437
+ _log("✅ Composition complete with realistic SSN shadow.")
 
 
 
 
 
 
 
 
438
  return final_rgba
439
 
440
  # -----------------------------
441
  # Public API
442
  # -----------------------------
443
  def _apply_params_defaults(params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
444
+ defaults = dict(softness=20.0, opacity=0.7, color="#000000", direction=45.0, distance=80.0)
445
  merged = {**defaults, **(params or {})}
446
  merged["opacity"] = max(0.0, min(1.0, float(merged["opacity"])))
447
  merged["softness"] = max(0.0, float(merged["softness"]))
 
460
  raise RuntimeError("SSN model not available.")
461
 
462
  matte_img = ssn.infer_shadow_matte(img, params["direction"])
 
463
  if matte_img is None:
464
  _log("❌ Failed to generate shadow matte")
465
  raise RuntimeError("Failed to generate shadow matte.")
466
+
467
  final_img = _composite_shadow_and_image(img, matte_img, params)
468
  buf = io.BytesIO()
469
  final_img.save(buf, format="PNG")