Commit
·
42ba2b6
1
Parent(s):
fd2b9a9
Make CCO model default for colorization - Change default model from GAN (dummy) to CCO (working) - Add fallback to GAN if CCO not available - Fix image format and saving issues
Browse files- app/main.py +35 -11
app/main.py
CHANGED
|
@@ -122,6 +122,10 @@ def colorize_image_cco(img: Image.Image, model_name: str = "eccv16"):
|
|
| 122 |
if model is None:
|
| 123 |
raise ValueError(f"CCO model '{model_name}' not loaded")
|
| 124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
# Convert PIL Image to numpy array
|
| 126 |
oimg = np.asarray(img)
|
| 127 |
if oimg.ndim == 2:
|
|
@@ -134,11 +138,15 @@ def colorize_image_cco(img: Image.Image, model_name: str = "eccv16"):
|
|
| 134 |
with torch.no_grad():
|
| 135 |
out_ab = model(tens_l_rs)
|
| 136 |
|
| 137 |
-
# Postprocess output
|
| 138 |
output_rgb = postprocess_tens(tens_l_orig, out_ab)
|
| 139 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 140 |
# Convert numpy array back to PIL Image
|
| 141 |
-
output_img = Image.fromarray(
|
| 142 |
return output_img
|
| 143 |
|
| 144 |
def colorize_image(img: Image.Image, model_type: str = "gan", cco_model: str = "eccv16"):
|
|
@@ -298,7 +306,7 @@ async def colorize(
|
|
| 298 |
user_id: Optional[str] = Form(None),
|
| 299 |
category_id: Optional[str] = Form(None),
|
| 300 |
categoryId: Optional[str] = Form(None),
|
| 301 |
-
model: Optional[str] = Form(
|
| 302 |
):
|
| 303 |
import time
|
| 304 |
start_time = time.time()
|
|
@@ -314,13 +322,22 @@ async def colorize(
|
|
| 314 |
effective_category_id = None
|
| 315 |
|
| 316 |
# Parse model parameter
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 320 |
|
| 321 |
if model:
|
| 322 |
model = model.strip().lower()
|
| 323 |
-
if model == "
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
if not CCO_AVAILABLE:
|
| 325 |
error_msg = "CCO models are not available"
|
| 326 |
log_api_call(
|
|
@@ -353,9 +370,8 @@ async def colorize(
|
|
| 353 |
cco_model = "eccv16"
|
| 354 |
model_type_for_log = "cco-eccv16"
|
| 355 |
else:
|
| 356 |
-
#
|
| 357 |
-
|
| 358 |
-
model_type_for_log = "gan"
|
| 359 |
|
| 360 |
if not file.content_type.startswith("image/"):
|
| 361 |
error_msg = "Invalid file type"
|
|
@@ -380,13 +396,21 @@ async def colorize(
|
|
| 380 |
|
| 381 |
try:
|
| 382 |
img = Image.open(io.BytesIO(await file.read()))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 383 |
output_img = colorize_image(img, model_type=model_type, cco_model=cco_model)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 384 |
|
| 385 |
processing_time = time.time() - start_time
|
| 386 |
|
| 387 |
result_id = f"{uuid.uuid4()}.jpg"
|
| 388 |
output_path = os.path.join(RESULTS_DIR, result_id)
|
| 389 |
-
output_img.save(output_path)
|
| 390 |
|
| 391 |
base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
|
| 392 |
|
|
|
|
| 122 |
if model is None:
|
| 123 |
raise ValueError(f"CCO model '{model_name}' not loaded")
|
| 124 |
|
| 125 |
+
# Ensure image is RGB
|
| 126 |
+
if img.mode != "RGB":
|
| 127 |
+
img = img.convert("RGB")
|
| 128 |
+
|
| 129 |
# Convert PIL Image to numpy array
|
| 130 |
oimg = np.asarray(img)
|
| 131 |
if oimg.ndim == 2:
|
|
|
|
| 138 |
with torch.no_grad():
|
| 139 |
out_ab = model(tens_l_rs)
|
| 140 |
|
| 141 |
+
# Postprocess output (returns RGB in [0, 1] range)
|
| 142 |
output_rgb = postprocess_tens(tens_l_orig, out_ab)
|
| 143 |
|
| 144 |
+
# Clamp values to [0, 1] and convert to uint8
|
| 145 |
+
output_rgb = np.clip(output_rgb, 0, 1)
|
| 146 |
+
output_array = (output_rgb * 255).astype(np.uint8)
|
| 147 |
+
|
| 148 |
# Convert numpy array back to PIL Image
|
| 149 |
+
output_img = Image.fromarray(output_array, 'RGB')
|
| 150 |
return output_img
|
| 151 |
|
| 152 |
def colorize_image(img: Image.Image, model_type: str = "gan", cco_model: str = "eccv16"):
|
|
|
|
| 306 |
user_id: Optional[str] = Form(None),
|
| 307 |
category_id: Optional[str] = Form(None),
|
| 308 |
categoryId: Optional[str] = Form(None),
|
| 309 |
+
model: Optional[str] = Form(None), # Model parameter: "gan", "cco", "cco-eccv16", "cco-siggraph17" (default: CCO if available)
|
| 310 |
):
|
| 311 |
import time
|
| 312 |
start_time = time.time()
|
|
|
|
| 322 |
effective_category_id = None
|
| 323 |
|
| 324 |
# Parse model parameter
|
| 325 |
+
# Default to CCO if available, otherwise fallback to GAN
|
| 326 |
+
if CCO_AVAILABLE:
|
| 327 |
+
model_type = "cco"
|
| 328 |
+
cco_model = "eccv16"
|
| 329 |
+
model_type_for_log = "cco-eccv16"
|
| 330 |
+
else:
|
| 331 |
+
model_type = "gan"
|
| 332 |
+
model_type_for_log = "gan"
|
| 333 |
|
| 334 |
if model:
|
| 335 |
model = model.strip().lower()
|
| 336 |
+
if model == "gan":
|
| 337 |
+
# Use GAN model (dummy implementation - doesn't actually colorize)
|
| 338 |
+
model_type = "gan"
|
| 339 |
+
model_type_for_log = "gan"
|
| 340 |
+
elif model == "cco" or model.startswith("cco-"):
|
| 341 |
if not CCO_AVAILABLE:
|
| 342 |
error_msg = "CCO models are not available"
|
| 343 |
log_api_call(
|
|
|
|
| 370 |
cco_model = "eccv16"
|
| 371 |
model_type_for_log = "cco-eccv16"
|
| 372 |
else:
|
| 373 |
+
# Unknown model, use default (CCO if available, else GAN)
|
| 374 |
+
pass
|
|
|
|
| 375 |
|
| 376 |
if not file.content_type.startswith("image/"):
|
| 377 |
error_msg = "Invalid file type"
|
|
|
|
| 396 |
|
| 397 |
try:
|
| 398 |
img = Image.open(io.BytesIO(await file.read()))
|
| 399 |
+
# Ensure image is RGB
|
| 400 |
+
if img.mode != "RGB":
|
| 401 |
+
img = img.convert("RGB")
|
| 402 |
+
|
| 403 |
output_img = colorize_image(img, model_type=model_type, cco_model=cco_model)
|
| 404 |
+
|
| 405 |
+
# Ensure output is RGB
|
| 406 |
+
if output_img.mode != "RGB":
|
| 407 |
+
output_img = output_img.convert("RGB")
|
| 408 |
|
| 409 |
processing_time = time.time() - start_time
|
| 410 |
|
| 411 |
result_id = f"{uuid.uuid4()}.jpg"
|
| 412 |
output_path = os.path.join(RESULTS_DIR, result_id)
|
| 413 |
+
output_img.save(output_path, "JPEG", quality=95)
|
| 414 |
|
| 415 |
base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
|
| 416 |
|