Spaces:
Sleeping
Sleeping
separate endpoints
Browse files
app.py
CHANGED
|
@@ -261,8 +261,8 @@ def summarize_captions(captions: list[str]) -> str:
|
|
| 261 |
with torch.no_grad():
|
| 262 |
output_ids = model.generate(
|
| 263 |
**inputs,
|
| 264 |
-
max_length=
|
| 265 |
-
min_length=
|
| 266 |
length_penalty=2.0,
|
| 267 |
num_beams=4,
|
| 268 |
early_stopping=True,
|
|
@@ -274,7 +274,7 @@ def summarize_captions(captions: list[str]) -> str:
|
|
| 274 |
return _finalize_caption(summary, max_sentences=10)
|
| 275 |
|
| 276 |
|
| 277 |
-
def generate_caption_text(image: Image.Image) -> str:
|
| 278 |
runtime_model, runtime_processor = _get_caption_runtime()
|
| 279 |
model_device = str(next(runtime_model.parameters()).device)
|
| 280 |
|
|
@@ -300,7 +300,7 @@ def generate_caption_text(image: Image.Image) -> str:
|
|
| 300 |
)
|
| 301 |
|
| 302 |
try:
|
| 303 |
-
inputs = _build_inputs(
|
| 304 |
except Exception as exc:
|
| 305 |
if "Mismatch in `image` token count" not in str(exc):
|
| 306 |
raise AppError("Failed to preprocess image for captioning.", 422) from exc
|
|
@@ -324,10 +324,10 @@ def generate_caption_text(image: Image.Image) -> str:
|
|
| 324 |
return _finalize_caption(caption)
|
| 325 |
|
| 326 |
|
| 327 |
-
def generate_caption_text_safe(image: Image.Image) -> str:
|
| 328 |
global _caption_model, _caption_processor, _caption_force_cpu
|
| 329 |
try:
|
| 330 |
-
return generate_caption_text(image)
|
| 331 |
except Exception as exc:
|
| 332 |
msg = str(exc)
|
| 333 |
if "CUDA error" not in msg and "device-side assert" not in msg:
|
|
@@ -344,7 +344,7 @@ def generate_caption_text_safe(image: Image.Image) -> str:
|
|
| 344 |
except Exception:
|
| 345 |
pass
|
| 346 |
|
| 347 |
-
return generate_caption_text(image)
|
| 348 |
|
| 349 |
|
| 350 |
def insert_record(collection, payload: dict) -> str:
|
|
@@ -355,10 +355,7 @@ def insert_record(collection, payload: dict) -> str:
|
|
| 355 |
raise AppError("MongoDB insert failed.", 503) from exc
|
| 356 |
|
| 357 |
|
| 358 |
-
|
| 359 |
-
async def generate_caption(request: Request):
|
| 360 |
-
_ensure_db_ready()
|
| 361 |
-
|
| 362 |
try:
|
| 363 |
form = await request.form()
|
| 364 |
except Exception as exc:
|
|
@@ -381,8 +378,8 @@ async def generate_caption(request: Request):
|
|
| 381 |
if len(uploads) > MAX_IMAGES:
|
| 382 |
raise AppError("You can upload a maximum of 5 images.", 400)
|
| 383 |
|
| 384 |
-
|
| 385 |
-
for upload in uploads:
|
| 386 |
if upload.content_type and not upload.content_type.startswith("image/"):
|
| 387 |
raise AppError("All uploaded files must be images.", 400)
|
| 388 |
|
|
@@ -397,11 +394,23 @@ async def generate_caption(request: Request):
|
|
| 397 |
except OSError as exc:
|
| 398 |
raise AppError("Unable to read one of the uploaded images.", 400) from exc
|
| 399 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
caption = generate_caption_text_safe(image)
|
| 401 |
if not caption:
|
| 402 |
raise AppError("Caption generation produced empty text.", 500)
|
| 403 |
-
|
| 404 |
-
image_captions.append({"filename": upload.filename, "caption": caption})
|
| 405 |
|
| 406 |
caption_texts = [x["caption"] for x in image_captions]
|
| 407 |
caption = summarize_captions(caption_texts)
|
|
@@ -424,3 +433,90 @@ async def generate_caption(request: Request):
|
|
| 424 |
response_data["created_at"] = response_data["created_at"].isoformat()
|
| 425 |
|
| 426 |
return ok("Caption generated successfully.", response_data)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
with torch.no_grad():
|
| 262 |
output_ids = model.generate(
|
| 263 |
**inputs,
|
| 264 |
+
max_length=512,
|
| 265 |
+
min_length=100,
|
| 266 |
length_penalty=2.0,
|
| 267 |
num_beams=4,
|
| 268 |
early_stopping=True,
|
|
|
|
| 274 |
return _finalize_caption(summary, max_sentences=10)
|
| 275 |
|
| 276 |
|
| 277 |
+
def generate_caption_text(image: Image.Image, prompt: str = CAPTION_PROMPT) -> str:
|
| 278 |
runtime_model, runtime_processor = _get_caption_runtime()
|
| 279 |
model_device = str(next(runtime_model.parameters()).device)
|
| 280 |
|
|
|
|
| 300 |
)
|
| 301 |
|
| 302 |
try:
|
| 303 |
+
inputs = _build_inputs(prompt)
|
| 304 |
except Exception as exc:
|
| 305 |
if "Mismatch in `image` token count" not in str(exc):
|
| 306 |
raise AppError("Failed to preprocess image for captioning.", 422) from exc
|
|
|
|
| 324 |
return _finalize_caption(caption)
|
| 325 |
|
| 326 |
|
| 327 |
+
def generate_caption_text_safe(image: Image.Image, prompt: str = CAPTION_PROMPT) -> str:
|
| 328 |
global _caption_model, _caption_processor, _caption_force_cpu
|
| 329 |
try:
|
| 330 |
+
return generate_caption_text(image, prompt)
|
| 331 |
except Exception as exc:
|
| 332 |
msg = str(exc)
|
| 333 |
if "CUDA error" not in msg and "device-side assert" not in msg:
|
|
|
|
| 344 |
except Exception:
|
| 345 |
pass
|
| 346 |
|
| 347 |
+
return generate_caption_text(image, prompt)
|
| 348 |
|
| 349 |
|
| 350 |
def insert_record(collection, payload: dict) -> str:
|
|
|
|
| 355 |
raise AppError("MongoDB insert failed.", 503) from exc
|
| 356 |
|
| 357 |
|
| 358 |
+
async def _parse_images(request: Request) -> list[tuple[str, Image.Image]]:
|
|
|
|
|
|
|
|
|
|
| 359 |
try:
|
| 360 |
form = await request.form()
|
| 361 |
except Exception as exc:
|
|
|
|
| 378 |
if len(uploads) > MAX_IMAGES:
|
| 379 |
raise AppError("You can upload a maximum of 5 images.", 400)
|
| 380 |
|
| 381 |
+
parsed_images = []
|
| 382 |
+
for i, upload in enumerate(uploads):
|
| 383 |
if upload.content_type and not upload.content_type.startswith("image/"):
|
| 384 |
raise AppError("All uploaded files must be images.", 400)
|
| 385 |
|
|
|
|
| 394 |
except OSError as exc:
|
| 395 |
raise AppError("Unable to read one of the uploaded images.", 400) from exc
|
| 396 |
|
| 397 |
+
filename = upload.filename or f"image_{i+1}"
|
| 398 |
+
parsed_images.append((filename, image))
|
| 399 |
+
|
| 400 |
+
return parsed_images
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
@app.post("/generate-caption-summary")
|
| 404 |
+
async def generate_caption_summary(request: Request):
|
| 405 |
+
_ensure_db_ready()
|
| 406 |
+
images = await _parse_images(request)
|
| 407 |
+
|
| 408 |
+
image_captions = []
|
| 409 |
+
for filename, image in images:
|
| 410 |
caption = generate_caption_text_safe(image)
|
| 411 |
if not caption:
|
| 412 |
raise AppError("Caption generation produced empty text.", 500)
|
| 413 |
+
image_captions.append({"filename": filename, "caption": caption})
|
|
|
|
| 414 |
|
| 415 |
caption_texts = [x["caption"] for x in image_captions]
|
| 416 |
caption = summarize_captions(caption_texts)
|
|
|
|
| 433 |
response_data["created_at"] = response_data["created_at"].isoformat()
|
| 434 |
|
| 435 |
return ok("Caption generated successfully.", response_data)
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
@app.post("/generate-caption-collage")
|
| 439 |
+
async def generate_caption_collage(request: Request):
|
| 440 |
+
_ensure_db_ready()
|
| 441 |
+
images = await _parse_images(request)
|
| 442 |
+
|
| 443 |
+
# Create collage (horizontal strip, resized to height 512 for consistency)
|
| 444 |
+
resized_images = []
|
| 445 |
+
target_height = 512
|
| 446 |
+
for _, img in images:
|
| 447 |
+
aspect_ratio = img.width / img.height
|
| 448 |
+
new_width = int(target_height * aspect_ratio)
|
| 449 |
+
resized_images.append(img.resize((new_width, target_height), Image.Resampling.LANCZOS))
|
| 450 |
+
|
| 451 |
+
total_width = sum(img.width for img in resized_images)
|
| 452 |
+
collage = Image.new("RGB", (total_width, target_height))
|
| 453 |
+
x_offset = 0
|
| 454 |
+
for img in resized_images:
|
| 455 |
+
collage.paste(img, (x_offset, 0))
|
| 456 |
+
x_offset += img.width
|
| 457 |
+
|
| 458 |
+
caption = generate_caption_text_safe(collage)
|
| 459 |
+
if not caption:
|
| 460 |
+
raise AppError("Collage caption generation produced empty text.", 500)
|
| 461 |
+
|
| 462 |
+
# For database storage, we list source filenames but the 'image_captions'
|
| 463 |
+
# will just contain the single collage caption to avoid confusion.
|
| 464 |
+
source_filenames = [fname for fname, _ in images]
|
| 465 |
+
|
| 466 |
+
mongo_payload = {
|
| 467 |
+
"caption": caption,
|
| 468 |
+
"source_filenames": source_filenames,
|
| 469 |
+
"image_captions": [{"filename": "collage", "caption": caption}],
|
| 470 |
+
"images_count": len(images),
|
| 471 |
+
"is_summarized": False, # It's a direct caption of a collage
|
| 472 |
+
"created_at": datetime.now(timezone.utc),
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
audio_file_id = insert_record(caption_collection, mongo_payload)
|
| 476 |
+
|
| 477 |
+
response_data = {**mongo_payload, "audio_file_id": audio_file_id}
|
| 478 |
+
response_data.pop("_id", None)
|
| 479 |
+
response_data["created_at"] = response_data["created_at"].isoformat()
|
| 480 |
+
|
| 481 |
+
return ok("Collage caption generated successfully.", response_data)
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
@app.post("/generate-caption-context")
|
| 485 |
+
async def generate_caption_context(request: Request):
|
| 486 |
+
_ensure_db_ready()
|
| 487 |
+
images = await _parse_images(request)
|
| 488 |
+
|
| 489 |
+
image_captions = []
|
| 490 |
+
previous_context = ""
|
| 491 |
+
|
| 492 |
+
for i, (filename, image) in enumerate(images):
|
| 493 |
+
prompt = CAPTION_PROMPT
|
| 494 |
+
if i > 0 and previous_context:
|
| 495 |
+
prompt = f"Context from previous image: {previous_context}. {CAPTION_PROMPT}"
|
| 496 |
+
|
| 497 |
+
caption = generate_caption_text_safe(image, prompt=prompt)
|
| 498 |
+
if not caption:
|
| 499 |
+
caption = "No caption generated."
|
| 500 |
+
|
| 501 |
+
image_captions.append({"filename": filename, "caption": caption})
|
| 502 |
+
previous_context = caption
|
| 503 |
+
|
| 504 |
+
# Combine captions for the main 'caption' field
|
| 505 |
+
full_text = " ".join([ic["caption"] for ic in image_captions])
|
| 506 |
+
|
| 507 |
+
mongo_payload = {
|
| 508 |
+
"caption": full_text,
|
| 509 |
+
"source_filenames": [fname for fname, _ in images],
|
| 510 |
+
"image_captions": image_captions,
|
| 511 |
+
"images_count": len(images),
|
| 512 |
+
"is_summarized": False,
|
| 513 |
+
"created_at": datetime.now(timezone.utc),
|
| 514 |
+
}
|
| 515 |
+
|
| 516 |
+
audio_file_id = insert_record(caption_collection, mongo_payload)
|
| 517 |
+
|
| 518 |
+
response_data = {**mongo_payload, "audio_file_id": audio_file_id}
|
| 519 |
+
response_data.pop("_id", None)
|
| 520 |
+
response_data["created_at"] = response_data["created_at"].isoformat()
|
| 521 |
+
|
| 522 |
+
return ok("Contextual captions generated successfully.", response_data)
|