Spaces:
Sleeping
Sleeping
Update main.py
Browse files
main.py
CHANGED
|
@@ -89,47 +89,69 @@ def create_the_final_results(fabric: Image.Image, person: Image.Image, mask: Ima
|
|
| 89 |
return results
|
| 90 |
|
| 91 |
def load_image_from_base64(s: str, m: str = 'RGB'):
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 95 |
|
| 96 |
# === API ENDPOINTS (CORRECT AND GUARANTEED) ===
|
| 97 |
|
| 98 |
@app.get("/")
|
| 99 |
-
def root():
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
@app.post("/generate")
|
| 103 |
async def api_generate(request: Request, inputs: ApiInput):
|
| 104 |
-
|
| 105 |
load_model()
|
|
|
|
| 106 |
API_KEY = os.environ.get("API_KEY")
|
| 107 |
-
if request.headers.get("x-api-key") != API_KEY:
|
| 108 |
-
|
| 109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
-
#
|
| 112 |
TARGET_SIZE = (1024, 1024)
|
| 113 |
person_resized = person.resize(TARGET_SIZE, Image.Resampling.LANCZOS)
|
| 114 |
|
|
|
|
| 115 |
if inputs.mask_base64:
|
| 116 |
-
print( "api_generate if start processing.." );
|
| 117 |
mask = load_image_from_base64(inputs.mask_base64, mode='L')
|
| 118 |
-
if mask is None:
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
else:
|
|
|
|
| 122 |
mask = generate_ultimate_mask(person_resized)
|
| 123 |
-
|
| 124 |
-
|
| 125 |
final_results = create_the_final_results(fabric, person_resized, mask)
|
| 126 |
|
|
|
|
| 127 |
def to_base64(img):
|
| 128 |
-
print( "to_base64 processing.." );
|
| 129 |
img_display = img.resize((512, 512), Image.Resampling.LANCZOS)
|
| 130 |
-
buf = io.BytesIO()
|
|
|
|
| 131 |
return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode('utf-8')}"
|
| 132 |
|
|
|
|
| 133 |
response_data = {
|
| 134 |
'ultimate_image': to_base64(final_results['ultimate_image']),
|
| 135 |
'fine_weave_image': to_base64(final_results['fine_weave_image']),
|
|
@@ -137,5 +159,5 @@ async def api_generate(request: Request, inputs: ApiInput):
|
|
| 137 |
'creative_variation_image': to_base64(final_results['creative_variation_image']),
|
| 138 |
'mask_image': to_base64(mask)
|
| 139 |
}
|
| 140 |
-
|
| 141 |
return response_data
|
|
|
|
| 89 |
return results
|
| 90 |
|
| 91 |
def load_image_from_base64(s: str, m: str = 'RGB'):
|
| 92 |
+
"""Decodes a Base64 string and opens it as a PIL Image."""
|
| 93 |
+
if "," not in s:
|
| 94 |
+
return None
|
| 95 |
+
try:
|
| 96 |
+
img_data = base64.b64decode(s.split(",")[1])
|
| 97 |
+
img = Image.open(io.BytesIO(img_data))
|
| 98 |
+
return img.convert(m)
|
| 99 |
+
except Exception as e:
|
| 100 |
+
print(f"Error loading image from base64: {e}")
|
| 101 |
+
return None
|
| 102 |
|
| 103 |
# === API ENDPOINTS (CORRECT AND GUARANTEED) ===
|
| 104 |
|
| 105 |
@app.get("/")
|
| 106 |
+
def root():
|
| 107 |
+
return {"status": "API server is running. Model will load on first call."}
|
| 108 |
+
|
| 109 |
+
class ApiInput(BaseModel):
|
| 110 |
+
person_base64: str
|
| 111 |
+
fabric_base64: str
|
| 112 |
+
mask_base64: Optional[str] = None
|
| 113 |
|
| 114 |
@app.post("/generate")
|
| 115 |
async def api_generate(request: Request, inputs: ApiInput):
|
| 116 |
+
# Ensure the model is loaded
|
| 117 |
load_model()
|
| 118 |
+
|
| 119 |
API_KEY = os.environ.get("API_KEY")
|
| 120 |
+
if request.headers.get("x-api-key") != API_KEY:
|
| 121 |
+
raise HTTPException(status_code=401, detail="Unauthorized")
|
| 122 |
+
|
| 123 |
+
# Load person and fabric images from base64
|
| 124 |
+
person = load_image_from_base64(inputs.person_base64)
|
| 125 |
+
fabric = load_image_from_base64(inputs.fabric_base64)
|
| 126 |
+
|
| 127 |
+
if person is None or fabric is None:
|
| 128 |
+
raise HTTPException(status_code=400, detail="Could not decode base64 images.")
|
| 129 |
|
| 130 |
+
# Resize person image to a standard size
|
| 131 |
TARGET_SIZE = (1024, 1024)
|
| 132 |
person_resized = person.resize(TARGET_SIZE, Image.Resampling.LANCZOS)
|
| 133 |
|
| 134 |
+
# Handle mask image if provided
|
| 135 |
if inputs.mask_base64:
|
|
|
|
| 136 |
mask = load_image_from_base64(inputs.mask_base64, mode='L')
|
| 137 |
+
if mask is None:
|
| 138 |
+
raise HTTPException(status_code=400, detail="Could not decode mask base64.")
|
| 139 |
+
mask = mask.resize(TARGET_SIZE, Image.Resampling.LANCZOS)
|
| 140 |
+
else:
|
| 141 |
+
# If no mask is provided, generate one
|
| 142 |
mask = generate_ultimate_mask(person_resized)
|
| 143 |
+
|
| 144 |
+
# Process and create the final results
|
| 145 |
final_results = create_the_final_results(fabric, person_resized, mask)
|
| 146 |
|
| 147 |
+
# Convert image to base64 for the response
|
| 148 |
def to_base64(img):
|
|
|
|
| 149 |
img_display = img.resize((512, 512), Image.Resampling.LANCZOS)
|
| 150 |
+
buf = io.BytesIO()
|
| 151 |
+
img_display.save(buf, format="PNG")
|
| 152 |
return f"data:image/png;base64,{base64.b64encode(buf.getvalue()).decode('utf-8')}"
|
| 153 |
|
| 154 |
+
# Prepare the response data with images converted to base64
|
| 155 |
response_data = {
|
| 156 |
'ultimate_image': to_base64(final_results['ultimate_image']),
|
| 157 |
'fine_weave_image': to_base64(final_results['fine_weave_image']),
|
|
|
|
| 159 |
'creative_variation_image': to_base64(final_results['creative_variation_image']),
|
| 160 |
'mask_image': to_base64(mask)
|
| 161 |
}
|
| 162 |
+
|
| 163 |
return response_data
|