Spaces:
Sleeping
Sleeping
Update shadow_generator.py
Browse files- shadow_generator.py +8 -35
shadow_generator.py
CHANGED
|
@@ -229,7 +229,6 @@ from PIL import Image, ImageFilter, ImageChops
|
|
| 229 |
import torch
|
| 230 |
from torchvision import transforms as T
|
| 231 |
import numpy as np
|
| 232 |
-
import cv2
|
| 233 |
|
| 234 |
# -----------------------------
|
| 235 |
# Paths & constants
|
|
@@ -345,8 +344,7 @@ class SSNWrapper:
|
|
| 345 |
try:
|
| 346 |
from SSN_Model import SSN_Model
|
| 347 |
self.model = SSN_Model()
|
| 348 |
-
|
| 349 |
-
state = torch.load(str(WEIGHT_FILE), map_location=self.device, weights_only=False)
|
| 350 |
if isinstance(state, dict):
|
| 351 |
if "model_state_dict" in state:
|
| 352 |
sd = state["model_state_dict"]
|
|
@@ -407,7 +405,7 @@ def load_ssn_once() -> SSNWrapper:
|
|
| 407 |
return _ssn_wrapper
|
| 408 |
|
| 409 |
# -----------------------------
|
| 410 |
-
# Shadow compositing (
|
| 411 |
# -----------------------------
|
| 412 |
def _hex_to_rgb(color: str) -> Tuple[int, int, int]:
|
| 413 |
c = color.strip().lstrip("#")
|
|
@@ -417,32 +415,15 @@ def _hex_to_rgb(color: str) -> Tuple[int, int, int]:
|
|
| 417 |
raise ValueError("Invalid color hex.")
|
| 418 |
return tuple(int(c[i:i+2],16) for i in (0,2,4))
|
| 419 |
|
| 420 |
-
def _warp_shadow(matte_img: Image.Image, w: int, h: int, direction: float, distance: float) -> Image.Image:
|
| 421 |
-
matte = np.array(matte_img.resize((w, h)).convert("L"))
|
| 422 |
-
|
| 423 |
-
angle = np.deg2rad(direction)
|
| 424 |
-
dx = int(np.cos(angle) * distance)
|
| 425 |
-
dy = int(np.sin(angle) * distance)
|
| 426 |
-
|
| 427 |
-
src = np.float32([[0,0],[w,0],[0,h],[w,h]])
|
| 428 |
-
dst = np.float32([[0,0],[w,0],[0+dx,h],[w+dx,h+dy]])
|
| 429 |
-
|
| 430 |
-
M = cv2.getPerspectiveTransform(src, dst)
|
| 431 |
-
warped = cv2.warpPerspective(matte, M, (w,h), borderValue=0)
|
| 432 |
-
|
| 433 |
-
return Image.fromarray(warped)
|
| 434 |
-
|
| 435 |
def _composite_shadow_and_image(original_img: Image.Image, matte_img: Image.Image, params: Dict[str, Any]) -> Image.Image:
|
| 436 |
-
_log("Compositing shadow and image (
|
| 437 |
w, h = original_img.size
|
| 438 |
r, g, b = _hex_to_rgb(params.get("color", "#000000"))
|
| 439 |
|
| 440 |
-
|
| 441 |
-
obj_alpha = original_img.getchannel("A")
|
| 442 |
-
matte_clean = ImageChops.multiply(matte_warped, obj_alpha)
|
| 443 |
|
| 444 |
shadow_rgba = Image.new("RGBA", (w, h), (r, g, b, 0))
|
| 445 |
-
alpha =
|
| 446 |
shadow_rgba.putalpha(alpha)
|
| 447 |
|
| 448 |
softness = max(0.0, float(params.get("softness", 20.0)))
|
|
@@ -453,22 +434,14 @@ def _composite_shadow_and_image(original_img: Image.Image, matte_img: Image.Imag
|
|
| 453 |
final_rgba.alpha_composite(shadow_rgba)
|
| 454 |
final_rgba.alpha_composite(original_img)
|
| 455 |
|
| 456 |
-
|
| 457 |
-
if params.get("background", "transparent") != "transparent":
|
| 458 |
-
bg_color = params.get("background", "#FFFFFF")
|
| 459 |
-
br, bgc, bb = _hex_to_rgb(bg_color)
|
| 460 |
-
flattened = Image.new("RGB", (w, h), (br, bgc, bb))
|
| 461 |
-
flattened.paste(final_rgba, mask=final_rgba.getchannel("A"))
|
| 462 |
-
return flattened.convert("RGBA")
|
| 463 |
-
|
| 464 |
-
_log("Composition complete with anchored shadow.")
|
| 465 |
return final_rgba
|
| 466 |
|
| 467 |
# -----------------------------
|
| 468 |
# Public API
|
| 469 |
# -----------------------------
|
| 470 |
def _apply_params_defaults(params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
| 471 |
-
defaults = dict(softness=20.0, opacity=0.7, color="#000000", direction=45.0, distance=80.0
|
| 472 |
merged = {**defaults, **(params or {})}
|
| 473 |
merged["opacity"] = max(0.0, min(1.0, float(merged["opacity"])))
|
| 474 |
merged["softness"] = max(0.0, float(merged["softness"]))
|
|
@@ -487,10 +460,10 @@ def generate_shadow_rgba(rgba_file_bytes: bytes, params: Optional[Dict[str, Any]
|
|
| 487 |
raise RuntimeError("SSN model not available.")
|
| 488 |
|
| 489 |
matte_img = ssn.infer_shadow_matte(img, params["direction"])
|
| 490 |
-
|
| 491 |
if matte_img is None:
|
| 492 |
_log("❌ Failed to generate shadow matte")
|
| 493 |
raise RuntimeError("Failed to generate shadow matte.")
|
|
|
|
| 494 |
final_img = _composite_shadow_and_image(img, matte_img, params)
|
| 495 |
buf = io.BytesIO()
|
| 496 |
final_img.save(buf, format="PNG")
|
|
|
|
| 229 |
import torch
|
| 230 |
from torchvision import transforms as T
|
| 231 |
import numpy as np
|
|
|
|
| 232 |
|
| 233 |
# -----------------------------
|
| 234 |
# Paths & constants
|
|
|
|
| 344 |
try:
|
| 345 |
from SSN_Model import SSN_Model
|
| 346 |
self.model = SSN_Model()
|
| 347 |
+
state = torch.load(str(WEIGHT_FILE), map_location=self.device)
|
|
|
|
| 348 |
if isinstance(state, dict):
|
| 349 |
if "model_state_dict" in state:
|
| 350 |
sd = state["model_state_dict"]
|
|
|
|
| 405 |
return _ssn_wrapper
|
| 406 |
|
| 407 |
# -----------------------------
|
| 408 |
+
# Shadow compositing (REALISTIC SSN)
|
| 409 |
# -----------------------------
|
| 410 |
def _hex_to_rgb(color: str) -> Tuple[int, int, int]:
|
| 411 |
c = color.strip().lstrip("#")
|
|
|
|
| 415 |
raise ValueError("Invalid color hex.")
|
| 416 |
return tuple(int(c[i:i+2],16) for i in (0,2,4))
|
| 417 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
def _composite_shadow_and_image(original_img: Image.Image, matte_img: Image.Image, params: Dict[str, Any]) -> Image.Image:
|
| 419 |
+
_log("Compositing shadow and image (REAL SSN)...")
|
| 420 |
w, h = original_img.size
|
| 421 |
r, g, b = _hex_to_rgb(params.get("color", "#000000"))
|
| 422 |
|
| 423 |
+
matte_resized = matte_img.resize((w, h)).convert("L")
|
|
|
|
|
|
|
| 424 |
|
| 425 |
shadow_rgba = Image.new("RGBA", (w, h), (r, g, b, 0))
|
| 426 |
+
alpha = matte_resized.point(lambda p: int(p * float(params.get("opacity", 0.7))))
|
| 427 |
shadow_rgba.putalpha(alpha)
|
| 428 |
|
| 429 |
softness = max(0.0, float(params.get("softness", 20.0)))
|
|
|
|
| 434 |
final_rgba.alpha_composite(shadow_rgba)
|
| 435 |
final_rgba.alpha_composite(original_img)
|
| 436 |
|
| 437 |
+
_log("✅ Composition complete with realistic SSN shadow.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
return final_rgba
|
| 439 |
|
| 440 |
# -----------------------------
|
| 441 |
# Public API
|
| 442 |
# -----------------------------
|
| 443 |
def _apply_params_defaults(params: Optional[Dict[str, Any]]) -> Dict[str, Any]:
|
| 444 |
+
defaults = dict(softness=20.0, opacity=0.7, color="#000000", direction=45.0, distance=80.0)
|
| 445 |
merged = {**defaults, **(params or {})}
|
| 446 |
merged["opacity"] = max(0.0, min(1.0, float(merged["opacity"])))
|
| 447 |
merged["softness"] = max(0.0, float(merged["softness"]))
|
|
|
|
| 460 |
raise RuntimeError("SSN model not available.")
|
| 461 |
|
| 462 |
matte_img = ssn.infer_shadow_matte(img, params["direction"])
|
|
|
|
| 463 |
if matte_img is None:
|
| 464 |
_log("❌ Failed to generate shadow matte")
|
| 465 |
raise RuntimeError("Failed to generate shadow matte.")
|
| 466 |
+
|
| 467 |
final_img = _composite_shadow_and_image(img, matte_img, params)
|
| 468 |
buf = io.BytesIO()
|
| 469 |
final_img.save(buf, format="PNG")
|