vidhi0405 committed on
Commit
3d467cc
·
1 Parent(s): 4864dad

separate endpoints

Browse files
Files changed (1) hide show
  1. app.py +111 -15
app.py CHANGED
@@ -261,8 +261,8 @@ def summarize_captions(captions: list[str]) -> str:
261
  with torch.no_grad():
262
  output_ids = model.generate(
263
  **inputs,
264
- max_length=300,
265
- min_length=50,
266
  length_penalty=2.0,
267
  num_beams=4,
268
  early_stopping=True,
@@ -274,7 +274,7 @@ def summarize_captions(captions: list[str]) -> str:
274
  return _finalize_caption(summary, max_sentences=10)
275
 
276
 
277
- def generate_caption_text(image: Image.Image) -> str:
278
  runtime_model, runtime_processor = _get_caption_runtime()
279
  model_device = str(next(runtime_model.parameters()).device)
280
 
@@ -300,7 +300,7 @@ def generate_caption_text(image: Image.Image) -> str:
300
  )
301
 
302
  try:
303
- inputs = _build_inputs(CAPTION_PROMPT)
304
  except Exception as exc:
305
  if "Mismatch in `image` token count" not in str(exc):
306
  raise AppError("Failed to preprocess image for captioning.", 422) from exc
@@ -324,10 +324,10 @@ def generate_caption_text(image: Image.Image) -> str:
324
  return _finalize_caption(caption)
325
 
326
 
327
- def generate_caption_text_safe(image: Image.Image) -> str:
328
  global _caption_model, _caption_processor, _caption_force_cpu
329
  try:
330
- return generate_caption_text(image)
331
  except Exception as exc:
332
  msg = str(exc)
333
  if "CUDA error" not in msg and "device-side assert" not in msg:
@@ -344,7 +344,7 @@ def generate_caption_text_safe(image: Image.Image) -> str:
344
  except Exception:
345
  pass
346
 
347
- return generate_caption_text(image)
348
 
349
 
350
  def insert_record(collection, payload: dict) -> str:
@@ -355,10 +355,7 @@ def insert_record(collection, payload: dict) -> str:
355
  raise AppError("MongoDB insert failed.", 503) from exc
356
 
357
 
358
- @app.post("/generate-caption")
359
- async def generate_caption(request: Request):
360
- _ensure_db_ready()
361
-
362
  try:
363
  form = await request.form()
364
  except Exception as exc:
@@ -381,8 +378,8 @@ async def generate_caption(request: Request):
381
  if len(uploads) > MAX_IMAGES:
382
  raise AppError("You can upload a maximum of 5 images.", 400)
383
 
384
- image_captions = []
385
- for upload in uploads:
386
  if upload.content_type and not upload.content_type.startswith("image/"):
387
  raise AppError("All uploaded files must be images.", 400)
388
 
@@ -397,11 +394,23 @@ async def generate_caption(request: Request):
397
  except OSError as exc:
398
  raise AppError("Unable to read one of the uploaded images.", 400) from exc
399
 
 
 
 
 
 
 
 
 
 
 
 
 
 
400
  caption = generate_caption_text_safe(image)
401
  if not caption:
402
  raise AppError("Caption generation produced empty text.", 500)
403
-
404
- image_captions.append({"filename": upload.filename, "caption": caption})
405
 
406
  caption_texts = [x["caption"] for x in image_captions]
407
  caption = summarize_captions(caption_texts)
@@ -424,3 +433,90 @@ async def generate_caption(request: Request):
424
  response_data["created_at"] = response_data["created_at"].isoformat()
425
 
426
  return ok("Caption generated successfully.", response_data)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  with torch.no_grad():
262
  output_ids = model.generate(
263
  **inputs,
264
+ max_length=512,
265
+ min_length=100,
266
  length_penalty=2.0,
267
  num_beams=4,
268
  early_stopping=True,
 
274
  return _finalize_caption(summary, max_sentences=10)
275
 
276
 
277
+ def generate_caption_text(image: Image.Image, prompt: str = CAPTION_PROMPT) -> str:
278
  runtime_model, runtime_processor = _get_caption_runtime()
279
  model_device = str(next(runtime_model.parameters()).device)
280
 
 
300
  )
301
 
302
  try:
303
+ inputs = _build_inputs(prompt)
304
  except Exception as exc:
305
  if "Mismatch in `image` token count" not in str(exc):
306
  raise AppError("Failed to preprocess image for captioning.", 422) from exc
 
324
  return _finalize_caption(caption)
325
 
326
 
327
+ def generate_caption_text_safe(image: Image.Image, prompt: str = CAPTION_PROMPT) -> str:
328
  global _caption_model, _caption_processor, _caption_force_cpu
329
  try:
330
+ return generate_caption_text(image, prompt)
331
  except Exception as exc:
332
  msg = str(exc)
333
  if "CUDA error" not in msg and "device-side assert" not in msg:
 
344
  except Exception:
345
  pass
346
 
347
+ return generate_caption_text(image, prompt)
348
 
349
 
350
  def insert_record(collection, payload: dict) -> str:
 
355
  raise AppError("MongoDB insert failed.", 503) from exc
356
 
357
 
358
+ async def _parse_images(request: Request) -> list[tuple[str, Image.Image]]:
 
 
 
359
  try:
360
  form = await request.form()
361
  except Exception as exc:
 
378
  if len(uploads) > MAX_IMAGES:
379
  raise AppError("You can upload a maximum of 5 images.", 400)
380
 
381
+ parsed_images = []
382
+ for i, upload in enumerate(uploads):
383
  if upload.content_type and not upload.content_type.startswith("image/"):
384
  raise AppError("All uploaded files must be images.", 400)
385
 
 
394
  except OSError as exc:
395
  raise AppError("Unable to read one of the uploaded images.", 400) from exc
396
 
397
+ filename = upload.filename or f"image_{i+1}"
398
+ parsed_images.append((filename, image))
399
+
400
+ return parsed_images
401
+
402
+
403
+ @app.post("/generate-caption-summary")
404
+ async def generate_caption_summary(request: Request):
405
+ _ensure_db_ready()
406
+ images = await _parse_images(request)
407
+
408
+ image_captions = []
409
+ for filename, image in images:
410
  caption = generate_caption_text_safe(image)
411
  if not caption:
412
  raise AppError("Caption generation produced empty text.", 500)
413
+ image_captions.append({"filename": filename, "caption": caption})
 
414
 
415
  caption_texts = [x["caption"] for x in image_captions]
416
  caption = summarize_captions(caption_texts)
 
433
  response_data["created_at"] = response_data["created_at"].isoformat()
434
 
435
  return ok("Caption generated successfully.", response_data)
436
+
437
+
438
@app.post("/generate-caption-collage")
async def generate_caption_collage(request: Request):
    """Caption all uploaded images as a single horizontal collage.

    The uploads are parsed and validated by ``_parse_images``, resized to a
    common height, pasted side by side into one strip, and captioned in a
    single model pass. The record is inserted into MongoDB and echoed back.

    Raises:
        AppError: 500 if caption generation produces empty text (and
            whatever ``_parse_images`` / ``insert_record`` raise upstream).
    """
    _ensure_db_ready()
    images = await _parse_images(request)

    # Create collage (horizontal strip, resized to height 512 for consistency)
    target_height = 512
    resized_images = []
    for _, img in images:
        aspect_ratio = img.width / img.height
        # max(1, ...) guards against a zero-width resize for extremely
        # tall/narrow inputs, which would make Image.resize raise.
        new_width = max(1, int(target_height * aspect_ratio))
        resized_images.append(
            img.resize((new_width, target_height), Image.Resampling.LANCZOS)
        )

    total_width = sum(img.width for img in resized_images)
    collage = Image.new("RGB", (total_width, target_height))
    x_offset = 0
    for img in resized_images:
        collage.paste(img, (x_offset, 0))
        x_offset += img.width

    caption = generate_caption_text_safe(collage)
    if not caption:
        raise AppError("Collage caption generation produced empty text.", 500)

    # For database storage, we list source filenames but the 'image_captions'
    # will just contain the single collage caption to avoid confusion.
    source_filenames = [fname for fname, _ in images]

    mongo_payload = {
        "caption": caption,
        "source_filenames": source_filenames,
        "image_captions": [{"filename": "collage", "caption": caption}],
        "images_count": len(images),
        "is_summarized": False,  # It's a direct caption of a collage
        "created_at": datetime.now(timezone.utc),
    }

    # NOTE(review): the key name "audio_file_id" matches the existing
    # response contract elsewhere in this file — kept for compatibility.
    audio_file_id = insert_record(caption_collection, mongo_payload)

    response_data = {**mongo_payload, "audio_file_id": audio_file_id}
    response_data.pop("_id", None)
    response_data["created_at"] = response_data["created_at"].isoformat()

    return ok("Collage caption generated successfully.", response_data)
482
+
483
+
484
@app.post("/generate-caption-context")
async def generate_caption_context(request: Request):
    """Caption uploaded images sequentially, feeding each caption forward.

    Every image after the first is captioned with a prompt that embeds the
    previous image's caption as context, so the per-image captions form a
    loosely connected narrative. All captions are stored and returned.
    """
    _ensure_db_ready()
    images = await _parse_images(request)

    image_captions = []
    previous_context = ""

    for index, (filename, image) in enumerate(images):
        if index > 0 and previous_context:
            prompt = f"Context from previous image: {previous_context}. {CAPTION_PROMPT}"
        else:
            prompt = CAPTION_PROMPT

        # Fall back to a placeholder so one failed image doesn't abort the run.
        caption = generate_caption_text_safe(image, prompt=prompt) or "No caption generated."

        image_captions.append({"filename": filename, "caption": caption})
        previous_context = caption

    # Combine captions for the main 'caption' field
    full_text = " ".join(entry["caption"] for entry in image_captions)

    mongo_payload = {
        "caption": full_text,
        "source_filenames": [fname for fname, _ in images],
        "image_captions": image_captions,
        "images_count": len(images),
        "is_summarized": False,
        "created_at": datetime.now(timezone.utc),
    }

    audio_file_id = insert_record(caption_collection, mongo_payload)

    response_data = {**mongo_payload, "audio_file_id": audio_file_id}
    response_data.pop("_id", None)
    response_data["created_at"] = response_data["created_at"].isoformat()

    return ok("Contextual captions generated successfully.", response_data)