LogicGoInfotechSpaces committed on
Commit
42ba2b6
·
1 Parent(s): fd2b9a9

Make CCO model default for colorization: change the default model from GAN (dummy) to CCO (working); add a fallback to GAN if CCO is not available; fix image format and saving issues.

Browse files
Files changed (1) hide show
  1. app/main.py +35 -11
app/main.py CHANGED
@@ -122,6 +122,10 @@ def colorize_image_cco(img: Image.Image, model_name: str = "eccv16"):
122
  if model is None:
123
  raise ValueError(f"CCO model '{model_name}' not loaded")
124
 
 
 
 
 
125
  # Convert PIL Image to numpy array
126
  oimg = np.asarray(img)
127
  if oimg.ndim == 2:
@@ -134,11 +138,15 @@ def colorize_image_cco(img: Image.Image, model_name: str = "eccv16"):
134
  with torch.no_grad():
135
  out_ab = model(tens_l_rs)
136
 
137
- # Postprocess output
138
  output_rgb = postprocess_tens(tens_l_orig, out_ab)
139
 
 
 
 
 
140
  # Convert numpy array back to PIL Image
141
- output_img = Image.fromarray((output_rgb * 255).astype(np.uint8))
142
  return output_img
143
 
144
  def colorize_image(img: Image.Image, model_type: str = "gan", cco_model: str = "eccv16"):
@@ -298,7 +306,7 @@ async def colorize(
298
  user_id: Optional[str] = Form(None),
299
  category_id: Optional[str] = Form(None),
300
  categoryId: Optional[str] = Form(None),
301
- model: Optional[str] = Form("gan"), # New parameter: "gan", "cco", "cco-eccv16", "cco-siggraph17"
302
  ):
303
  import time
304
  start_time = time.time()
@@ -314,13 +322,22 @@ async def colorize(
314
  effective_category_id = None
315
 
316
  # Parse model parameter
317
- model_type = "gan" # Default
318
- cco_model = "eccv16" # Default for CCO
319
- model_type_for_log = "gan" # For MongoDB logging
 
 
 
 
 
320
 
321
  if model:
322
  model = model.strip().lower()
323
- if model == "cco" or model.startswith("cco-"):
 
 
 
 
324
  if not CCO_AVAILABLE:
325
  error_msg = "CCO models are not available"
326
  log_api_call(
@@ -353,9 +370,8 @@ async def colorize(
353
  cco_model = "eccv16"
354
  model_type_for_log = "cco-eccv16"
355
  else:
356
- # Default to "gan" for any other value
357
- model_type = "gan"
358
- model_type_for_log = "gan"
359
 
360
  if not file.content_type.startswith("image/"):
361
  error_msg = "Invalid file type"
@@ -380,13 +396,21 @@ async def colorize(
380
 
381
  try:
382
  img = Image.open(io.BytesIO(await file.read()))
 
 
 
 
383
  output_img = colorize_image(img, model_type=model_type, cco_model=cco_model)
 
 
 
 
384
 
385
  processing_time = time.time() - start_time
386
 
387
  result_id = f"{uuid.uuid4()}.jpg"
388
  output_path = os.path.join(RESULTS_DIR, result_id)
389
- output_img.save(output_path)
390
 
391
  base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
392
 
 
122
  if model is None:
123
  raise ValueError(f"CCO model '{model_name}' not loaded")
124
 
125
+ # Ensure image is RGB
126
+ if img.mode != "RGB":
127
+ img = img.convert("RGB")
128
+
129
  # Convert PIL Image to numpy array
130
  oimg = np.asarray(img)
131
  if oimg.ndim == 2:
 
138
  with torch.no_grad():
139
  out_ab = model(tens_l_rs)
140
 
141
+ # Postprocess output (returns RGB in [0, 1] range)
142
  output_rgb = postprocess_tens(tens_l_orig, out_ab)
143
 
144
+ # Clamp values to [0, 1] and convert to uint8
145
+ output_rgb = np.clip(output_rgb, 0, 1)
146
+ output_array = (output_rgb * 255).astype(np.uint8)
147
+
148
  # Convert numpy array back to PIL Image
149
+ output_img = Image.fromarray(output_array, 'RGB')
150
  return output_img
151
 
152
  def colorize_image(img: Image.Image, model_type: str = "gan", cco_model: str = "eccv16"):
 
306
  user_id: Optional[str] = Form(None),
307
  category_id: Optional[str] = Form(None),
308
  categoryId: Optional[str] = Form(None),
309
+ model: Optional[str] = Form(None), # Model parameter: "gan", "cco", "cco-eccv16", "cco-siggraph17" (default: CCO if available)
310
  ):
311
  import time
312
  start_time = time.time()
 
322
  effective_category_id = None
323
 
324
  # Parse model parameter
325
+ # Default to CCO if available, otherwise fallback to GAN
326
+ if CCO_AVAILABLE:
327
+ model_type = "cco"
328
+ cco_model = "eccv16"
329
+ model_type_for_log = "cco-eccv16"
330
+ else:
331
+ model_type = "gan"
332
+ model_type_for_log = "gan"
333
 
334
  if model:
335
  model = model.strip().lower()
336
+ if model == "gan":
337
+ # Use GAN model (dummy implementation - doesn't actually colorize)
338
+ model_type = "gan"
339
+ model_type_for_log = "gan"
340
+ elif model == "cco" or model.startswith("cco-"):
341
  if not CCO_AVAILABLE:
342
  error_msg = "CCO models are not available"
343
  log_api_call(
 
370
  cco_model = "eccv16"
371
  model_type_for_log = "cco-eccv16"
372
  else:
373
+ # Unknown model, use default (CCO if available, else GAN)
374
+ pass
 
375
 
376
  if not file.content_type.startswith("image/"):
377
  error_msg = "Invalid file type"
 
396
 
397
  try:
398
  img = Image.open(io.BytesIO(await file.read()))
399
+ # Ensure image is RGB
400
+ if img.mode != "RGB":
401
+ img = img.convert("RGB")
402
+
403
  output_img = colorize_image(img, model_type=model_type, cco_model=cco_model)
404
+
405
+ # Ensure output is RGB
406
+ if output_img.mode != "RGB":
407
+ output_img = output_img.convert("RGB")
408
 
409
  processing_time = time.time() - start_time
410
 
411
  result_id = f"{uuid.uuid4()}.jpg"
412
  output_path = os.path.join(RESULTS_DIR, result_id)
413
+ output_img.save(output_path, "JPEG", quality=95)
414
 
415
  base_url = "https://logicgoinfotechspaces-text-guided-image-colorization.hf.space"
416