burman-ai commited on
Commit
e7bcb12
·
verified ·
1 Parent(s): b87511a

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +31 -0
  2. main_code_script.py +108 -0
app.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ import os
4
+ # Import your core functions (estimate_pose, segment_clothing, inpaint_clothing, change_clothing) from your main script
5
+ from main_code_script import change_clothing # Replace your_main_script
6
+ def predict(image_path, garment_image_path): # Changed input
7
+ """
8
+ The prediction function for Gradio.
9
+ """
10
+ try:
11
+ modified_image = change_clothing(image_path, garment_image_path) # Changed input
12
+ if modified_image:
13
+ return modified_image
14
+ else:
15
+ return "Failed to change clothing. Please check the images."
16
+ except Exception as e:
17
+ return f"Error: {e}"
18
+ # Create the Gradio interface
19
+ iface = gr.Interface(
20
+ fn=predict,
21
+ inputs=[
22
+ gr.Image(type="filepath", label="Input Image (Person)"), # Changed label
23
+ gr.Image(type="filepath", label="Garment Image"), # Added input
24
+ ],
25
+ outputs=gr.Image(type="pil", label="Modified Image"),
26
+ title="AI Clothing Changer",
27
+ description="Try on different clothes with AI by uploading a garment image!", # Changed description
28
+ )
29
+ # Launch the Gradio interface
30
+ if __name__ == "__main__":
31
+ iface.launch()
main_code_script.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Install necessary libraries (in your requirements.txt)
2
+ # pillow opencv-python transformers mediapipe diffusers accelerate transformers
3
+ # Example install command: pip install pillow opencv-python transformers mediapipe diffusers accelerate transformers
4
+ from PIL import Image
5
+ import cv2
6
+ import mediapipe as mp
7
+ import numpy as np
8
+ from transformers import pipeline
9
+ from diffusers import StableDiffusionInpaintPipeline
10
+ import torch
11
+ # --- 1. Pose Estimation (using Mediapipe) ---
12
+ def estimate_pose(image_path):
13
+ """Detects the pose of a person in an image using Mediapipe.
14
+ Args:
15
+ image_path: Path to the input image.
16
+ Returns:
17
+ A list of landmarks (x, y, visibility)
18
+ or None if no pose is detected.
19
+ """
20
+ mp_drawing = mp.solutions.drawing_utils
21
+ mp_pose = mp.solutions.pose
22
+ with mp_pose.Pose(
23
+ static_image_mode=True,
24
+ model_complexity=2,
25
+ enable_segmentation=True,
26
+ min_detection_confidence=0.5) as pose:
27
+ image = cv2.imread(image_path)
28
+ image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
29
+ results = pose.process(image_rgb)
30
+
31
+ if results.pose_landmarks:
32
+ # Example: Draw the pose landmarks on the image (for visualization)
33
+ annotated_image = image.copy()
34
+ mp_drawing.draw_landmarks(
35
+ annotated_image, results.pose_landmarks, mp_pose.POSE_CONNECTIONS)
36
+ #cv2.imwrite("pose_annotated.jpg", annotated_image) # Save annotated image
37
+ #return results.pose_landmarks.landmark
38
+ # Return the landmarks
39
+ return results, image # Return the entire result
40
+ else:
41
+ return None, None # or raise an exception
42
+ # --- 2. Clothing Segmentation (Example - using a placeholder function) ---
43
+ def segment_clothing(image, results): #Added result
44
+ """Segments the clothing region in the image.
45
+ This is a simplified example. In reality, you would use a pre-trained
46
+ segmentation model.
47
+ """
48
+ # 1. Create a mask where the person is present.
49
+ segmentation_mask = results.segmentation_mask
50
+ threshold = 0.5 # Adjust this threshold as needed.
51
+ # Threshold the segmentation mask to create a binary mask.
52
+ binary_mask = (segmentation_mask > threshold).astype(np.uint8) * 255
53
+ # Convert binary mask to a PIL Image
54
+ mask_img = Image.fromarray(binary_mask).convert("L")
55
+ return mask_img
56
+ # --- 3. Image Inpainting (Replacing Clothing - using Stable Diffusion Inpainting) ---
57
+ def inpaint_clothing(image, mask_img, garment_image_path, device="cuda" if torch.cuda.is_available() else "cpu"): # Changed input
58
+ """
59
+ Replaces the clothing region in the image with the uploaded garment image,
60
+ using Stable Diffusion Inpainting.
61
+ """
62
+ pipe = StableDiffusionInpaintPipeline.from_pretrained(
63
+ "stabilityai/stable-diffusion-2-inpainting",
64
+ torch_dtype=torch.float16
65
+ )
66
+ pipe = pipe.to(device)
67
+ # Resize the image and mask to the same size. Important for inpainting.
68
+ image = image.resize((512, 512))
69
+ mask_img = mask_img.resize((512, 512))
70
+
71
+ # Load the garment image
72
+ garment_image = Image.open(garment_image_path).convert("RGB")
73
+ garment_image = garment_image.resize((512,512)) # Resize if necessary
74
+
75
+ # Inpaint using the garment image as a guide (This part might need further refinement)
76
+ # A simple approach is to use the garment image in the prompt.
77
+ # More advanced techniques might involve using the garment image as
78
+ # a style reference or directly manipulating the latent space.
79
+ prompt = f"A photo of a person wearing the uploaded garment"
80
+ image = pipe(prompt=prompt, image=image, mask_image=mask_img).images[0]
81
+ return image
82
+ # --- 4. Main Function (Putting it all together) ---
83
+ def change_clothing(image_path, garment_image_path): # Changed input
84
+ """
85
+ Main function to change the clothing in an image.
86
+ """
87
+ # 1. Load the image
88
+ image = Image.open(image_path).convert("RGB")
89
+ # 2. Estimate the pose
90
+ results, cv2_image = estimate_pose(image_path)
91
+ if results is None:
92
+ print("No pose detected.")
93
+ return None
94
+ # 3. Segment the clothing
95
+ mask_img = segment_clothing(image, results)
96
+ # 4. Inpaint the clothing
97
+ modified_image = inpaint_clothing(image, mask_img, garment_image_path) # Changed input
98
+ return modified_image
99
+ # --- Example Usage ---
100
+ if __name__ == "__main__":
101
+ input_image_path = "person.jpg" # Replace with your image
102
+ garment_image_path = "garment.jpg" # Replace with your garment image
103
+ modified_image = change_clothing(input_image_path, garment_image_path)
104
+ if modified_image:
105
+ modified_image.save("modified_image.jpg")
106
+ print("Clothing changed and saved to modified_image.jpg")
107
+ else:
108
+ print("Failed to change clothing.")