File size: 4,793 Bytes
e7bcb12
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# Install the necessary libraries (list them in your requirements.txt):
# pillow opencv-python mediapipe transformers diffusers accelerate torch
# Example install command: pip install pillow opencv-python mediapipe transformers diffusers accelerate torch
from PIL import Image
import cv2
import mediapipe as mp
import numpy as np
from transformers import pipeline
from diffusers import StableDiffusionInpaintPipeline
import torch
# --- 1. Pose Estimation (using Mediapipe) ---
def estimate_pose(image_path):
    """Run Mediapipe Pose on a single image file.

    Args:
        image_path: Path to the input image.

    Returns:
        A ``(results, image)`` tuple, where ``results`` is the object
        returned by ``Pose.process`` (segmentation is enabled, so it is
        expected to carry ``segmentation_mask`` as well as
        ``pose_landmarks``) and ``image`` is the array loaded by
        ``cv2.imread``. Returns ``(None, None)`` when no pose landmarks
        are detected.

    Raises:
        FileNotFoundError: If ``cv2.imread`` cannot load ``image_path``.
    """
    # cv2.imread reports failure by returning None rather than raising, which
    # would otherwise surface as an opaque error inside cvtColor.
    image = cv2.imread(image_path)
    if image is None:
        raise FileNotFoundError(f"Could not read image: {image_path}")
    # Convert from OpenCV's BGR channel order to RGB before processing.
    image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    mp_pose = mp.solutions.pose
    with mp_pose.Pose(
        static_image_mode=True,
        model_complexity=2,
        enable_segmentation=True,
        min_detection_confidence=0.5,
    ) as pose:
        results = pose.process(image_rgb)

    if results.pose_landmarks:
        # Return the entire result so callers can use the segmentation mask.
        return results, image
    return None, None
# --- 2. Clothing Segmentation (Example - using a placeholder function) ---
def segment_clothing(image, results): #Added result
    """Build a binary mask image from the pose segmentation result.

    This is a simplified placeholder: it thresholds
    ``results.segmentation_mask`` rather than running a dedicated clothing
    segmentation model, so the mask is only as specific as that input
    (presumably the whole person, not just the garment).

    Args:
        image: Currently unused; kept for interface compatibility.
        results: Object exposing a ``segmentation_mask`` array, as returned
            by ``estimate_pose``.

    Returns:
        A PIL image in mode ``"L"`` whose pixels are 255 where the mask
        value exceeds the threshold and 0 elsewhere.

    Raises:
        ValueError: If ``results`` carries no segmentation mask.
    """
    segmentation_mask = getattr(results, "segmentation_mask", None)
    if segmentation_mask is None:
        # Fail with a clear message instead of a TypeError from `None > 0.5`.
        raise ValueError(
            "results has no segmentation_mask; run pose estimation with "
            "segmentation enabled"
        )

    threshold = 0.5  # Adjust this threshold as needed.
    # Threshold the mask into a 0/255 uint8 array.
    binary_mask = (segmentation_mask > threshold).astype(np.uint8) * 255
    return Image.fromarray(binary_mask).convert("L")
# --- 3. Image Inpainting (Replacing Clothing - using Stable Diffusion Inpainting) ---
def inpaint_clothing(image, mask_img, garment_image_path, device="cuda" if torch.cuda.is_available() else "cpu"): # Changed input
    """Repaint the masked region of ``image`` with Stable Diffusion Inpainting.

    NOTE: the garment image is only checked for readability. It is not yet
    passed to the pipeline, so the result is driven by the fixed text prompt
    alone. Conditioning on the garment (style reference, latent manipulation,
    etc.) still needs to be implemented.

    Args:
        image: PIL image of the person.
        mask_img: PIL mask image marking the region to repaint.
        garment_image_path: Path to the garment image.
        device: Torch device to run the pipeline on.

    Returns:
        The first image produced by the pipeline, generated from 512x512
        inputs.
    """
    # Fail fast on an unreadable garment file before loading the model.
    with Image.open(garment_image_path) as garment_file:
        garment_file.convert("RGB")

    # Half precision is not generally usable on CPU, so fall back to float32
    # there; every other device keeps the original float16 weights.
    on_cpu = torch.device(device).type == "cpu"
    dtype = torch.float32 if on_cpu else torch.float16

    pipe = StableDiffusionInpaintPipeline.from_pretrained(
        "stabilityai/stable-diffusion-2-inpainting",
        torch_dtype=dtype,
    )
    pipe = pipe.to(device)

    # Resize the image and mask to the same size.  Important for inpainting.
    image = image.resize((512, 512))
    mask_img = mask_img.resize((512, 512))

    prompt = "A photo of a person wearing the uploaded garment"
    return pipe(prompt=prompt, image=image, mask_image=mask_img).images[0]
# --- 4. Main Function (Putting it all together) ---
def change_clothing(image_path, garment_image_path): # Changed input
    """Change the clothing in an image: pose -> mask -> inpainting.

    Args:
        image_path: Path to the photo of the person.
        garment_image_path: Path to the garment image.

    Returns:
        The modified PIL image, or ``None`` if no pose was detected.
    """
    from PIL import ImageOps

    # 1. Load the image. OpenCV applies the EXIF orientation tag when it
    #    reads the file for pose estimation, while PIL does not; transpose
    #    here so the image and the segmentation mask share one pixel grid.
    image = ImageOps.exif_transpose(Image.open(image_path)).convert("RGB")

    # 2. Estimate the pose (the OpenCV copy of the image is not needed here).
    results, _ = estimate_pose(image_path)
    if results is None:
        print("No pose detected.")
        return None

    # 3. Segment the clothing
    mask_img = segment_clothing(image, results)

    # 4. Inpaint the clothing
    return inpaint_clothing(image, mask_img, garment_image_path)
# --- Example Usage ---
if __name__ == "__main__":
    # Example run: point these at your own files.
    input_image_path = "person.jpg"  # Photo of the person
    garment_image_path = "garment.jpg"  # Photo of the garment

    modified_image = change_clothing(input_image_path, garment_image_path)

    # change_clothing returns None when no pose was detected.
    if not modified_image:
        print("Failed to change clothing.")
    else:
        modified_image.save("modified_image.jpg")
        print("Clothing changed and saved to modified_image.jpg")