Spaces:
Sleeping
Sleeping
| import io | |
| import logging | |
| import os | |
| import re | |
| import threading | |
| from datetime import datetime, timezone | |
# Avoid invalid OMP setting from runtime environment (e.g. empty/non-numeric).
# Only plain ASCII digits count as valid: ``str.isdigit`` on its own also
# accepts characters such as "²" that ``int`` rejects, which previously made
# this guard itself crash the import.
_omp_threads = os.getenv("OMP_NUM_THREADS", "").strip()
if not (_omp_threads.isascii() and _omp_threads.isdigit()) or int(_omp_threads) < 1:
    os.environ["OMP_NUM_THREADS"] = "8"
| import torch | |
| from dotenv import load_dotenv | |
| from fastapi import FastAPI, Request, UploadFile | |
| from fastapi.exceptions import RequestValidationError | |
| from fastapi.responses import JSONResponse | |
| from PIL import Image, UnidentifiedImageError | |
| from pymongo import MongoClient | |
| from pymongo.errors import PyMongoError, ServerSelectionTimeoutError | |
| from starlette.datastructures import UploadFile as StarletteUploadFile | |
| from transformers import ( | |
| AutoModelForImageTextToText, | |
| AutoModelForSeq2SeqLM, | |
| AutoProcessor, | |
| AutoTokenizer, | |
| ) | |
load_dotenv()

# Model identifiers; both can be overridden through the environment / .env.
CAPTION_MODEL_ID = os.getenv("CAPTION_MODEL_ID", "vidhi0405/Qwen_I2T")
SUMMARIZER_MODEL_ID = os.getenv("SUMMARIZER_MODEL_ID", "facebook/bart-large-cnn")

# Compute placement: CUDA when available (unless DEVICE overrides it), with
# half precision only on CUDA.
DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
if DEVICE == "cuda":
    DTYPE = torch.float16
else:
    DTYPE = torch.float32

# Generation / upload limits.
MAX_NEW_TOKENS = 120
MAX_IMAGES = 5

# MongoDB settings; surrounding whitespace and quotes are stripped from the URI.
MONGO_URI = (
    (os.getenv("MONGO_URI") or os.getenv("MONGODB_URI") or "")
    .strip()
    .strip('"')
    .strip("'")
)
MONGO_DB_NAME = os.getenv("MONGO_DB_NAME", "image_to_speech")

# Primary instruction given to the caption model for every image.
CAPTION_PROMPT = (
    "Act as a professional news reporter delivering a live on-scene report in real time. "
    "Speak naturally, as if you are addressing viewers who are watching this unfold right now. "
    "Describe the scene in 3 to 4 complete, vivid sentences. "
    "Mention what is happening, the surrounding environment, and the overall mood, "
    "and convey the urgency or emotion of the moment when appropriate."
)
# Shorter instruction used when preprocessing with CAPTION_PROMPT reports an
# image-token count mismatch.
CAPTION_RETRY_PROMPT = (
    "Describe this image in 2 to 3 complete sentences. "
    "Mention the main subject, action, environment, and mood."
)

# Sentence-count bounds applied by the caption post-processing.
CAPTION_MIN_SENTENCES = 3
CAPTION_MAX_SENTENCES = 4
PROCESSOR_MAX_LENGTH = 8192

logger = logging.getLogger(__name__)
def ok(message: str, data):
    """Wrap *data* in the API's standard success envelope with HTTP 200."""
    envelope = {"success": True, "message": message, "data": data}
    return JSONResponse(content=envelope, status_code=200)
def fail(message: str, status_code: int = 400):
    """Build the API's standard error envelope (``data`` is always ``None``)."""
    envelope = {"success": False, "message": message, "data": None}
    return JSONResponse(content=envelope, status_code=status_code)
class AppError(Exception):
    """Application error carrying the HTTP status to report to the client.

    Attributes:
        message: Human-readable message returned in the error envelope.
        status_code: HTTP status code for the response (default 400).
    """

    def __init__(self, message: str, status_code: int = 400):
        self.message = message
        self.status_code = status_code
        super().__init__(message)
# Size torch's intra-op thread pool from OMP_NUM_THREADS, which the guard near
# the imports has already validated (or defaulted to "8"). Hard-coding 8 here
# silently overrode a valid operator-provided thread count.
torch.set_num_threads(int(os.environ.get("OMP_NUM_THREADS", "8")))
# Caption model state, populated lazily by _get_caption_runtime. The
# _caption_force_cpu flag is flipped by the CUDA fallback in
# generate_caption_text_safe.
_caption_lock = threading.Lock()
_caption_model = None
_caption_processor = None
_caption_force_cpu = False

# Summarizer state, populated lazily by _get_summarizer_runtime.
_summarizer_lock = threading.Lock()
_summarizer_model = None
_summarizer_tokenizer = None

app = FastAPI(title="Image to Text API")

# MongoDB handles, filled in by the initialisation below; db_init_error holds
# a message when initialisation is skipped or fails.
mongo_client = mongo_db = caption_collection = db_init_error = None
if not MONGO_URI:
    db_init_error = "MONGO_URI (or MONGODB_URI) is not set."
else:
    try:
        mongo_client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=5000)
        # The ping forces a round-trip to the server so connectivity problems
        # surface here (bounded by the 5 s selection timeout) rather than on
        # the first request.
        mongo_client.admin.command("ping")
        mongo_db = mongo_client[MONGO_DB_NAME]
        caption_collection = mongo_db["captions"]
    except ServerSelectionTimeoutError:
        db_init_error = "Unable to connect to MongoDB (timeout)."
    except PyMongoError as exc:
        db_init_error = "Unable to initialize MongoDB: {}".format(exc)
    if db_init_error and mongo_client is not None:
        # The client was constructed but is unusable; close it instead of
        # leaving its background connection machinery running for the life of
        # the process.
        mongo_client.close()
        mongo_client = None
@app.get("/")
def root():
    """Return a usage hint describing how to call the captioning endpoint."""
    # The decorator was missing, so FastAPI never registered this route.
    return {
        "success": True,
        "message": "Use POST /generate-caption with form-data key 'file' or 'files' (up to 5 images).",
        "data": None,
    }
@app.get("/health")
def health():
    """Report service health plus the configured model identifiers.

    ``success`` is False (with the stored error message) when MongoDB failed
    to initialise at import time. The HTTP status is 200 in both cases.
    """
    model_info = {
        "caption_model_id": CAPTION_MODEL_ID,
        "summarizer_model_id": SUMMARIZER_MODEL_ID,
    }
    if db_init_error:
        return {"success": False, "message": db_init_error, "data": model_info}
    return {"success": True, "message": "ok", "data": model_info}
@app.on_event("startup")
async def preload_runtime_models():
    """Load the caption and summarizer models at application startup.

    Set ``PRELOAD_MODELS`` to ``0``/``false``/``no`` to skip this. A failure is
    logged as a warning and does not abort startup; the model is then loaded
    lazily on first use.
    """
    # NOTE(review): ``on_event`` is what the earlier version of this file
    # used; newer FastAPI releases prefer a lifespan handler — consider
    # migrating.
    if os.getenv("PRELOAD_MODELS", "1").strip().lower() in {"0", "false", "no"}:
        logger.info("Model preloading disabled via PRELOAD_MODELS.")
        return
    try:
        _get_caption_runtime()
        logger.info("Caption model preloaded successfully.")
    except Exception as exc:
        logger.warning("Caption model preload failed: %s", exc)
    try:
        _get_summarizer_runtime()
        logger.info("Summarizer model preloaded successfully.")
    except Exception as exc:
        logger.warning("Summarizer model preload failed: %s", exc)
@app.exception_handler(AppError)
async def app_error_handler(_, exc: AppError):
    """Render an ``AppError`` as the standard error envelope with its status code."""
    return fail(exc.message, exc.status_code)
@app.exception_handler(RequestValidationError)
async def validation_error_handler(_, exc: RequestValidationError):
    """Replace FastAPI's default validation error body with the standard 422 envelope."""
    return fail("Invalid request payload.", 422)
@app.exception_handler(Exception)
async def unhandled_error_handler(_, exc: Exception):
    """Log any uncaught exception (with traceback) and return a generic 500 envelope."""
    logger.exception("Unhandled server error: %s", exc, exc_info=exc)
    return fail("Internal server error.", 500)
def _ensure_db_ready():
    """Raise a 503 ``AppError`` carrying ``db_init_error`` if MongoDB setup failed."""
    if not db_init_error:
        return
    raise AppError(db_init_error, 503)
def _finalize_caption(raw_text: str, max_sentences: int = CAPTION_MAX_SENTENCES) -> str:
    """Normalise model output into clean text made of complete sentences.

    Whitespace is collapsed first. Then:

    * If at least ``CAPTION_MIN_SENTENCES`` complete sentences (runs ending in
      ``.``, ``!`` or ``?``) are present, the first ``max_sentences`` are
      returned and any trailing fragment is dropped.
    * Otherwise, text that already ends in terminal punctuation is returned
      as is.
    * Otherwise the output was cut off mid-sentence. If at least one complete
      sentence exists, only the complete sentences are kept.
    * Failing that, the last unfinished clause (after ``,``, ``;``, ``:`` or a
      spaced dash) is trimmed.

    Args:
        raw_text: Decoded model output.
        max_sentences: Maximum number of sentences kept when enough are present.

    Returns:
        The cleaned caption, or ``""`` for blank input.
    """
    text = " ".join(raw_text.split())
    if not text:
        return ""
    sentences = [s.strip() for s in re.findall(r"[^.!?]+[.!?]", text) if s.strip()]
    if len(sentences) >= CAPTION_MIN_SENTENCES:
        return " ".join(sentences[:max_sentences]).strip()
    if text[-1] in ".!?":
        return text
    if sentences:
        # Drop the dangling, unterminated tail (typically a token-limit cut).
        return " ".join(sentences)
    # No complete sentence at all: cut at the last clause separator. Dashes
    # only count when surrounded by spaces so hyphenated words ("well-lit")
    # are not split.
    separators = (",", ";", ":", " - ", " \u2013 ", " \u2014 ")
    cut = max(text.rfind(sep) for sep in separators)
    if cut < 0:
        return text
    return text[:cut].strip()
def _load_caption_runtime():
    """Load the caption model and processor; meant to be called with ``_caption_lock`` held.

    The model goes to CPU in float32 once ``_caption_force_cpu`` is set;
    otherwise it goes to ``DEVICE`` with ``DTYPE``.

    Raises:
        AppError: 503 if the model or the processor cannot be loaded.
    """
    target_device = "cpu" if _caption_force_cpu else DEVICE
    target_dtype = DTYPE if target_device != "cpu" else torch.float32
    try:
        model = AutoModelForImageTextToText.from_pretrained(
            CAPTION_MODEL_ID,
            trust_remote_code=True,
            torch_dtype=target_dtype,
            low_cpu_mem_usage=True,
        ).to(target_device)
        processor = AutoProcessor.from_pretrained(
            CAPTION_MODEL_ID,
            trust_remote_code=True,
        )
    except Exception as exc:
        raise AppError("Failed to load caption model.", 503) from exc
    model.eval()
    return model, processor


def _get_caption_runtime():
    """Return the shared ``(model, processor)`` pair used for captioning.

    The pair is loaded on first use. Loading happens under ``_caption_lock``
    and the condition is re-checked there, so racing callers load it once.

    Raises:
        AppError: 503 if loading fails.
    """
    global _caption_model, _caption_processor
    if _caption_model is not None and _caption_processor is not None:
        return _caption_model, _caption_processor
    with _caption_lock:
        if _caption_model is None or _caption_processor is None:
            _caption_model, _caption_processor = _load_caption_runtime()
        return _caption_model, _caption_processor
def _get_summarizer_runtime():
    """Return the shared ``(model, tokenizer)`` pair used for summarisation.

    Loaded on first use from ``SUMMARIZER_MODEL_ID`` onto ``DEVICE`` with
    ``DTYPE``, under ``_summarizer_lock`` with a re-check.

    Raises:
        AppError: 503 if the tokenizer or model cannot be loaded.
    """
    global _summarizer_model, _summarizer_tokenizer
    if _summarizer_model is None or _summarizer_tokenizer is None:
        with _summarizer_lock:
            if _summarizer_model is None or _summarizer_tokenizer is None:
                try:
                    loaded_tokenizer = AutoTokenizer.from_pretrained(SUMMARIZER_MODEL_ID)
                    loaded_model = AutoModelForSeq2SeqLM.from_pretrained(
                        SUMMARIZER_MODEL_ID, torch_dtype=DTYPE
                    ).to(DEVICE)
                except Exception as exc:
                    raise AppError("Failed to load summarization model.", 503) from exc
                loaded_model.eval()
                _summarizer_tokenizer = loaded_tokenizer
                _summarizer_model = loaded_model
                return _summarizer_model, _summarizer_tokenizer
    return _summarizer_model, _summarizer_tokenizer
def summarize_captions(captions: list[str]) -> str:
    """Merge several per-image captions into one summary with the seq2seq model.

    Args:
        captions: Captions to combine, in order.

    Returns:
        ``""`` for no input or only blank captions. A single caption is
        returned unchanged. Otherwise the beam-searched summary, post-processed
        by ``_finalize_caption`` (up to 10 sentences).

    Raises:
        AppError: 503 if the summarizer cannot be loaded, 500 if generation fails.
    """
    if not captions:
        return ""
    if len(captions) == 1:
        return captions[0]
    model, tokenizer = _get_summarizer_runtime()
    non_blank = [text.strip() for text in captions if text and text.strip()]
    combined = " ".join(non_blank)
    if not combined:
        return ""
    try:
        encoded = tokenizer(combined, max_length=1024, truncation=True, return_tensors="pt")
        encoded = {name: tensor.to(DEVICE) for name, tensor in encoded.items()}
        generation_kwargs = dict(
            max_length=512,
            min_length=100,
            length_penalty=2.0,
            num_beams=4,
            early_stopping=True,
        )
        with torch.no_grad():
            output_ids = model.generate(**encoded, **generation_kwargs)
        summary = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
    except Exception as exc:
        raise AppError("Failed to summarize captions.", 500) from exc
    return _finalize_caption(summary, max_sentences=10)
def generate_caption_text(image: Image.Image, prompt: str = CAPTION_PROMPT) -> str:
    """Caption a single image with the vision-language model.

    Args:
        image: Image to describe (callers pass RGB-converted PIL images).
        prompt: Instruction placed after the image in the chat message. If
            preprocessing reports an image-token count mismatch, the request is
            retried once with ``CAPTION_RETRY_PROMPT``.

    Returns:
        The caption, cleaned by ``_finalize_caption``.

    Raises:
        AppError: 422 if preprocessing fails, 500 if generation fails, 503 (from
            ``_get_caption_runtime``) if the model cannot be loaded.
    """
    runtime_model, runtime_processor = _get_caption_runtime()
    model_device = str(next(runtime_model.parameters()).device)

    def _build_inputs(text_prompt: str):
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image"},
                    {"type": "text", "text": text_prompt},
                ],
            }
        ]
        chat_text = runtime_processor.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return runtime_processor(
            text=chat_text,
            images=image,
            return_tensors="pt",
            truncation=False,
            max_length=PROCESSOR_MAX_LENGTH,
        )

    try:
        inputs = _build_inputs(prompt)
    except Exception as exc:
        if "Mismatch in `image` token count" not in str(exc):
            raise AppError("Failed to preprocess image for captioning.", 422) from exc
        # Retry once with the shorter prompt; a second failure is reported as
        # a preprocessing error instead of escaping unwrapped.
        try:
            inputs = _build_inputs(CAPTION_RETRY_PROMPT)
        except Exception as retry_exc:
            raise AppError("Failed to preprocess image for captioning.", 422) from retry_exc
    inputs = {k: v.to(model_device) for k, v in inputs.items()}

    try:
        with torch.no_grad():
            outputs = runtime_model.generate(
                **inputs,
                max_new_tokens=MAX_NEW_TOKENS,
                do_sample=False,
                num_beams=1,
            )
    except Exception as exc:
        raise AppError("Caption generation failed.", 500) from exc

    sequence = outputs[0]
    prompt_ids = inputs.get("input_ids")
    if prompt_ids is not None:
        prompt_len = prompt_ids.shape[-1]
        if sequence.shape[-1] >= prompt_len and torch.equal(sequence[:prompt_len], prompt_ids[0]):
            # The returned sequence starts with the prompt tokens: decode only
            # the newly generated part, so a caption that itself contains the
            # word "assistant" is not cut at that word.
            caption = runtime_processor.decode(
                sequence[prompt_len:], skip_special_tokens=True
            ).strip()
            return _finalize_caption(caption)

    # Fallback: locate the reply after the chat template's "assistant" marker.
    decoded = runtime_processor.decode(sequence, skip_special_tokens=True).strip()
    caption = decoded.split("assistant")[-1].lstrip(":\n ").strip()
    return _finalize_caption(caption)
def _is_cuda_failure(exc: BaseException) -> bool:
    """Return True if *exc* or any exception in its cause/context chain reports a CUDA fault."""
    seen = set()
    current = exc
    while current is not None and id(current) not in seen:
        seen.add(id(current))
        message = str(current)
        if "CUDA error" in message or "device-side assert" in message:
            return True
        current = current.__cause__ or current.__context__
    return False


def generate_caption_text_safe(image: Image.Image, prompt: str = CAPTION_PROMPT) -> str:
    """Caption *image*, falling back to a CPU-loaded model after a CUDA failure.

    ``generate_caption_text`` wraps model errors in ``AppError`` with a generic
    message. The CUDA check therefore walks the exception chain instead of
    only the top-level message; otherwise the fallback could never trigger for
    load or generation failures. Non-CUDA errors are re-raised unchanged.

    Args:
        image: Image to caption.
        prompt: Prompt forwarded to ``generate_caption_text``.

    Returns:
        The caption text.
    """
    global _caption_model, _caption_processor, _caption_force_cpu
    try:
        return generate_caption_text(image, prompt)
    except Exception as exc:
        if not _is_cuda_failure(exc):
            raise
        logger.warning("CUDA failure during captioning; retrying on CPU: %s", exc)
    # Reached only after a CUDA failure. Leaving the except block first drops
    # the traceback (and the tensors its frames reference) before the cache
    # is emptied.
    with _caption_lock:
        _caption_force_cpu = True
        _caption_model = None
        _caption_processor = None
    if torch.cuda.is_available():
        try:
            torch.cuda.empty_cache()
        except Exception:
            # Best effort: after a device-side assert the CUDA context may be
            # unusable, and the CPU retry below does not depend on this.
            pass
    return generate_caption_text(image, prompt)
def insert_record(collection, payload: dict) -> str:
    """Insert *payload* into *collection* and return the new document id as a string.

    Raises:
        AppError: 503 when the driver raises a ``PyMongoError``.
    """
    try:
        new_id = collection.insert_one(payload).inserted_id
    except PyMongoError as exc:
        raise AppError("MongoDB insert failed.", 503) from exc
    return str(new_id)
async def _parse_images(request: Request) -> list[tuple[str, Image.Image]]:
    """Read the uploaded images from a multipart form and decode them to RGB.

    Files are taken from the ``files``, ``files[]`` and ``file`` keys, in that
    order. If none are found, any file field in the form is accepted.

    Args:
        request: Incoming request with a multipart body.

    Returns:
        ``(filename, image)`` pairs. A missing filename becomes ``image_<n>``
        (1-based).

    Raises:
        AppError: 422 if the form cannot be parsed. 400 if there are no files,
            more than ``MAX_IMAGES`` files, a non-image content type, an empty
            file, an undecodable image, or an image exceeding Pillow's
            decompression-bomb limit.
    """
    try:
        form = await request.form()
    except Exception as exc:
        raise AppError("Invalid request payload.", 422) from exc

    upload_types = (UploadFile, StarletteUploadFile)
    uploads = [
        value
        for key in ("files", "files[]", "file")
        for value in form.getlist(key)
        if isinstance(value, upload_types)
    ]
    # Fallback for clients that send non-standard multipart keys.
    if not uploads:
        uploads = [value for _, value in form.multi_items() if isinstance(value, upload_types)]

    if not uploads:
        raise AppError("At least one image is required.", 400)
    if len(uploads) > MAX_IMAGES:
        raise AppError(f"You can upload a maximum of {MAX_IMAGES} images.", 400)

    parsed_images = []
    for index, upload in enumerate(uploads, start=1):
        if upload.content_type and not upload.content_type.startswith("image/"):
            raise AppError("All uploaded files must be images.", 400)
        file_bytes = await upload.read()
        if not file_bytes:
            raise AppError("One of the uploaded images is empty.", 400)
        try:
            with Image.open(io.BytesIO(file_bytes)) as source:
                image = source.convert("RGB")
        except Image.DecompressionBombError as exc:
            # Not an OSError subclass, so without this clause an oversized
            # image would surface as a generic 500.
            raise AppError("One of the uploaded images is too large.", 400) from exc
        except UnidentifiedImageError as exc:
            raise AppError("One of the uploaded files is not a valid image.", 400) from exc
        except OSError as exc:
            raise AppError("Unable to read one of the uploaded images.", 400) from exc
        parsed_images.append((upload.filename or f"image_{index}", image))
    return parsed_images
@app.post("/generate-caption")
async def generate_caption_summary(request: Request):
    """Caption each uploaded image and summarise the captions into one text.

    The record is stored in the ``captions`` collection. The inserted document
    id is returned as ``audio_file_id``.

    Raises:
        AppError: 503 if MongoDB is unavailable, 4xx for invalid uploads, 500 if
            captioning or summarisation yields empty text.
    """
    # NOTE(review): route path taken from the usage hint in ``root`` and the
    # earlier ``/generate-caption`` handler this replaces — confirm against
    # clients. Without a decorator the endpoint was never registered.
    _ensure_db_ready()
    images = await _parse_images(request)

    image_captions = []
    for filename, image in images:
        caption = generate_caption_text_safe(image)
        if not caption:
            raise AppError("Caption generation produced empty text.", 500)
        image_captions.append({"filename": filename, "caption": caption})

    caption = summarize_captions([item["caption"] for item in image_captions])
    if not caption:
        raise AppError("Caption summarization produced empty text.", 500)

    mongo_payload = {
        "caption": caption,
        "source_filenames": [item["filename"] for item in image_captions],
        "image_captions": image_captions,
        "images_count": len(image_captions),
        "is_summarized": len(image_captions) > 1,
        "created_at": datetime.now(timezone.utc),
    }
    audio_file_id = insert_record(caption_collection, mongo_payload)
    response_data = {**mongo_payload, "audio_file_id": audio_file_id}
    response_data.pop("_id", None)  # Remove ObjectId as it is not JSON serializable
    response_data["created_at"] = response_data["created_at"].isoformat()
    return ok("Caption generated successfully.", response_data)
@app.post("/generate-caption-collage")
async def generate_caption_collage(request: Request):
    """Stitch the uploaded images into one horizontal strip and caption it once.

    Each image is resized to a height of 512 px, keeping its aspect ratio, and
    the results are pasted left to right.

    Raises:
        AppError: 503 if MongoDB is unavailable, 4xx for invalid uploads, 500 if
            the collage caption is empty.
    """
    # NOTE(review): route path inferred from the function name — confirm
    # against clients. Without a decorator the endpoint was never registered.
    _ensure_db_ready()
    images = await _parse_images(request)

    target_height = 512
    resized_images = []
    for _, img in images:
        # Clamp to at least 1 px: a very tall, narrow image would otherwise
        # truncate to width 0, which Pillow's resize rejects.
        new_width = max(1, int(target_height * img.width / img.height))
        resized_images.append(img.resize((new_width, target_height), Image.Resampling.LANCZOS))

    collage = Image.new("RGB", (sum(img.width for img in resized_images), target_height))
    x_offset = 0
    for img in resized_images:
        collage.paste(img, (x_offset, 0))
        x_offset += img.width

    caption = generate_caption_text_safe(collage)
    if not caption:
        raise AppError("Collage caption generation produced empty text.", 500)

    # For database storage, we list source filenames but the 'image_captions'
    # will just contain the single collage caption to avoid confusion.
    mongo_payload = {
        "caption": caption,
        "source_filenames": [fname for fname, _ in images],
        "image_captions": [{"filename": "collage", "caption": caption}],
        "images_count": len(images),
        "is_summarized": False,  # It's a direct caption of a collage
        "created_at": datetime.now(timezone.utc),
    }
    audio_file_id = insert_record(caption_collection, mongo_payload)
    response_data = {**mongo_payload, "audio_file_id": audio_file_id}
    response_data.pop("_id", None)
    response_data["created_at"] = response_data["created_at"].isoformat()
    return ok("Collage caption generated successfully.", response_data)
@app.post("/generate-caption-context")
async def generate_caption_context(request: Request):
    """Caption images in upload order, giving each prompt the previous image's caption.

    An empty caption is recorded as ``"No caption generated."``; that
    placeholder is not passed on as context to the next image. The stored
    ``caption`` is all per-image captions joined with spaces.

    Raises:
        AppError: 503 if MongoDB is unavailable, 4xx for invalid uploads.
    """
    # NOTE(review): route path inferred from the function name — confirm
    # against clients. Without a decorator the endpoint was never registered.
    _ensure_db_ready()
    images = await _parse_images(request)

    image_captions = []
    previous_context = ""
    for filename, image in images:
        prompt = CAPTION_PROMPT
        if previous_context:
            # Captions normally end with terminal punctuation already; only add
            # a full stop when it is missing so the prompt never reads "..".
            if previous_context[-1] in ".!?":
                context = previous_context
            else:
                context = previous_context + "."
            prompt = f"Context from previous image: {context} {CAPTION_PROMPT}"
        caption = generate_caption_text_safe(image, prompt=prompt)
        if caption:
            previous_context = caption
        else:
            caption = "No caption generated."
            # Don't describe the placeholder to the model as a real scene.
            previous_context = ""
        image_captions.append({"filename": filename, "caption": caption})

    # Combine captions for the main 'caption' field
    full_text = " ".join(item["caption"] for item in image_captions)
    mongo_payload = {
        "caption": full_text,
        "source_filenames": [fname for fname, _ in images],
        "image_captions": image_captions,
        "images_count": len(images),
        "is_summarized": False,
        "created_at": datetime.now(timezone.utc),
    }
    audio_file_id = insert_record(caption_collection, mongo_payload)
    response_data = {**mongo_payload, "audio_file_id": audio_file_id}
    response_data.pop("_id", None)
    response_data["created_at"] = response_data["created_at"].isoformat()
    return ok("Contextual captions generated successfully.", response_data)
| # import io | |
| # import logging | |
| # import os | |
| # import re | |
| # import threading | |
| # from datetime import datetime, timezone | |
| # # Avoid invalid OMP setting from runtime environment (e.g. empty/non-numeric). | |
| # _omp_threads = os.getenv("OMP_NUM_THREADS", "").strip() | |
| # if not _omp_threads.isdigit() or int(_omp_threads) < 1: | |
| # os.environ["OMP_NUM_THREADS"] = "8" | |
| # import torch | |
| # from dotenv import load_dotenv | |
| # from fastapi import FastAPI, Request, UploadFile | |
| # from fastapi.exceptions import RequestValidationError | |
| # from fastapi.responses import JSONResponse | |
| # from PIL import Image, UnidentifiedImageError | |
| # from pymongo import MongoClient | |
| # from pymongo.errors import PyMongoError, ServerSelectionTimeoutError | |
| # from starlette.datastructures import UploadFile as StarletteUploadFile | |
| # from transformers import ( | |
| # AutoModelForImageTextToText, | |
| # AutoModelForSeq2SeqLM, | |
| # AutoProcessor, | |
| # AutoTokenizer, | |
| # ) | |
| # load_dotenv() | |
| # CAPTION_MODEL_ID = os.getenv("CAPTION_MODEL_ID", "vidhi0405/Qwen_I2T") | |
| # SUMMARIZER_MODEL_ID = os.getenv("SUMMARIZER_MODEL_ID", "facebook/bart-large-cnn") | |
| # DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu") | |
| # DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32 | |
| # MAX_NEW_TOKENS = 120 | |
| # MAX_IMAGES = 5 | |
| # MONGO_URI = (os.getenv("MONGO_URI") or os.getenv("MONGODB_URI") or "").strip().strip('"').strip("'") | |
| # MONGO_DB_NAME = os.getenv("MONGO_DB_NAME", "image_to_speech") | |
| # CAPTION_PROMPT = ( | |
| # "Act as a professional news reporter delivering a live on-scene report in real time. " | |
| # "Speak naturally, as if you are addressing viewers who are watching this unfold right now. " | |
| # "Describe the scene in 3 to 4 complete, vivid sentences. " | |
| # "Mention what is happening, the surrounding environment, and the overall mood, " | |
| # "and convey the urgency or emotion of the moment when appropriate." | |
| # ) | |
| # CAPTION_RETRY_PROMPT = ( | |
| # "Describe this image in 2 to 3 complete sentences. " | |
| # "Mention the main subject, action, environment, and mood." | |
| # ) | |
| # CAPTION_MIN_SENTENCES = 3 | |
| # CAPTION_MAX_SENTENCES = 4 | |
| # PROCESSOR_MAX_LENGTH = 8192 | |
| # logger = logging.getLogger(__name__) | |
| # def ok(message: str, data): | |
| # return JSONResponse( | |
| # status_code=200, | |
| # content={"success": True, "message": message, "data": data}, | |
| # ) | |
| # def fail(message: str, status_code: int = 400): | |
| # return JSONResponse( | |
| # status_code=status_code, | |
| # content={"success": False, "message": message, "data": None}, | |
| # ) | |
| # class AppError(Exception): | |
| # def __init__(self, message: str, status_code: int = 400): | |
| # super().__init__(message) | |
| # self.message = message | |
| # self.status_code = status_code | |
| # torch.set_num_threads(8) | |
| # _caption_model = None | |
| # _caption_processor = None | |
| # _caption_lock = threading.Lock() | |
| # _caption_force_cpu = False | |
| # _summarizer_model = None | |
| # _summarizer_tokenizer = None | |
| # _summarizer_lock = threading.Lock() | |
| # app = FastAPI(title="Image to Text API") | |
| # mongo_client = None | |
| # mongo_db = None | |
| # caption_collection = None | |
| # db_init_error = None | |
| # if not MONGO_URI: | |
| # db_init_error = "MONGO_URI (or MONGODB_URI) is not set." | |
| # else: | |
| # try: | |
| # mongo_client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=5000) | |
| # mongo_client.admin.command("ping") | |
| # mongo_db = mongo_client[MONGO_DB_NAME] | |
| # caption_collection = mongo_db["captions"] | |
| # except ServerSelectionTimeoutError: | |
| # db_init_error = "Unable to connect to MongoDB (timeout)." | |
| # except PyMongoError as exc: | |
| # db_init_error = "Unable to initialize MongoDB: {}".format(exc) | |
| # @app.get("/") | |
| # def root(): | |
| # return { | |
| # "success": True, | |
| # "message": "Use POST /generate-caption with form-data key 'file' or 'files' (up to 5 images).", | |
| # "data": None, | |
| # } | |
| # @app.get("/health") | |
| # def health(): | |
| # if db_init_error: | |
| # return { | |
| # "success": False, | |
| # "message": db_init_error, | |
| # "data": { | |
| # "caption_model_id": CAPTION_MODEL_ID, | |
| # "summarizer_model_id": SUMMARIZER_MODEL_ID, | |
| # }, | |
| # } | |
| # return { | |
| # "success": True, | |
| # "message": "ok", | |
| # "data": { | |
| # "caption_model_id": CAPTION_MODEL_ID, | |
| # "summarizer_model_id": SUMMARIZER_MODEL_ID, | |
| # }, | |
| # } | |
| # @app.on_event("startup") | |
| # async def preload_runtime_models(): | |
| # if os.getenv("PRELOAD_MODELS", "1").strip().lower() in {"0", "false", "no"}: | |
| # logger.info("Model preloading disabled via PRELOAD_MODELS.") | |
| # return | |
| # try: | |
| # _get_caption_runtime() | |
| # logger.info("Caption model preloaded successfully.") | |
| # except Exception as exc: | |
| # logger.warning("Caption model preload failed: %s", exc) | |
| # try: | |
| # _get_summarizer_runtime() | |
| # logger.info("Summarizer model preloaded successfully.") | |
| # except Exception as exc: | |
| # logger.warning("Summarizer model preload failed: %s", exc) | |
| # @app.exception_handler(AppError) | |
| # async def app_error_handler(_, exc: AppError): | |
| # return fail(exc.message, exc.status_code) | |
| # @app.exception_handler(RequestValidationError) | |
| # async def validation_error_handler(_, exc: RequestValidationError): | |
| # return fail("Invalid request payload.", 422) | |
| # @app.exception_handler(Exception) | |
| # async def unhandled_error_handler(_, exc: Exception): | |
| # logger.exception("Unhandled server error: %s", exc) | |
| # return fail("Internal server error.", 500) | |
| # def _ensure_db_ready(): | |
| # if db_init_error: | |
| # raise AppError(db_init_error, 503) | |
| # def _finalize_caption(raw_text: str, max_sentences: int = CAPTION_MAX_SENTENCES) -> str: | |
| # text = " ".join(raw_text.split()).strip() | |
| # if not text: | |
| # return "" | |
| # sentences = re.findall(r"[^.!?]+[.!?]", text) | |
| # sentences = [s.strip() for s in sentences if s.strip()] | |
| # if len(sentences) >= CAPTION_MIN_SENTENCES: | |
| # return " ".join(sentences[:max_sentences]).strip() | |
| # if text and text[-1] not in ".!?": | |
| # text = re.sub(r"[,:;\-]\s*[^,:;\-]*$", "", text).strip() | |
| # return text | |
| # def _get_caption_runtime(): | |
| # global _caption_model, _caption_processor, _caption_force_cpu | |
| # if _caption_model is not None and _caption_processor is not None: | |
| # return _caption_model, _caption_processor | |
| # with _caption_lock: | |
| # if _caption_model is None or _caption_processor is None: | |
| # device = "cpu" if _caption_force_cpu else DEVICE | |
| # dtype = torch.float32 if device == "cpu" else DTYPE | |
| # try: | |
| # loaded_model = AutoModelForImageTextToText.from_pretrained( | |
| # CAPTION_MODEL_ID, | |
| # trust_remote_code=True, | |
| # torch_dtype=dtype, | |
| # low_cpu_mem_usage=True, | |
| # ).to(device) | |
| # loaded_processor = AutoProcessor.from_pretrained( | |
| # CAPTION_MODEL_ID, | |
| # trust_remote_code=True, | |
| # ) | |
| # except Exception as exc: | |
| # raise AppError("Failed to load caption model.", 503) from exc | |
| # loaded_model.eval() | |
| # _caption_model = loaded_model | |
| # _caption_processor = loaded_processor | |
| # return _caption_model, _caption_processor | |
| # def _get_summarizer_runtime(): | |
| # global _summarizer_model, _summarizer_tokenizer | |
| # if _summarizer_model is not None and _summarizer_tokenizer is not None: | |
| # return _summarizer_model, _summarizer_tokenizer | |
| # with _summarizer_lock: | |
| # if _summarizer_model is None or _summarizer_tokenizer is None: | |
| # try: | |
| # tokenizer = AutoTokenizer.from_pretrained(SUMMARIZER_MODEL_ID) | |
| # model = AutoModelForSeq2SeqLM.from_pretrained(SUMMARIZER_MODEL_ID, torch_dtype=DTYPE).to(DEVICE) | |
| # except Exception as exc: | |
| # raise AppError("Failed to load summarization model.", 503) from exc | |
| # model.eval() | |
| # _summarizer_tokenizer = tokenizer | |
| # _summarizer_model = model | |
| # return _summarizer_model, _summarizer_tokenizer | |
| # def summarize_captions(captions: list[str]) -> str: | |
| # if not captions: | |
| # return "" | |
| # if len(captions) == 1: | |
| # return captions[0] | |
| # model, tokenizer = _get_summarizer_runtime() | |
| # combined = " ".join(c.strip() for c in captions if c and c.strip()) | |
| # if not combined: | |
| # return "" | |
| # try: | |
| # inputs = tokenizer( | |
| # combined, | |
| # max_length=1024, | |
| # truncation=True, | |
| # return_tensors="pt", | |
| # ) | |
| # inputs = {k: v.to(DEVICE) for k, v in inputs.items()} | |
| # with torch.no_grad(): | |
| # output_ids = model.generate( | |
| # **inputs, | |
| # max_length=300, | |
| # min_length=50, | |
| # length_penalty=2.0, | |
| # num_beams=4, | |
| # early_stopping=True, | |
| # ) | |
| # summary = tokenizer.decode(output_ids[0], skip_special_tokens=True).strip() | |
| # except Exception as exc: | |
| # raise AppError("Failed to summarize captions.", 500) from exc | |
| # return _finalize_caption(summary, max_sentences=10) | |
| # def generate_caption_text(image: Image.Image) -> str: | |
| # runtime_model, runtime_processor = _get_caption_runtime() | |
| # model_device = str(next(runtime_model.parameters()).device) | |
| # def _build_inputs(prompt: str): | |
| # messages = [ | |
| # { | |
| # "role": "user", | |
| # "content": [ | |
| # {"type": "image"}, | |
| # {"type": "text", "text": prompt}, | |
| # ], | |
| # } | |
| # ] | |
| # text = runtime_processor.apply_chat_template( | |
| # messages, tokenize=False, add_generation_prompt=True | |
| # ) | |
| # return runtime_processor( | |
| # text=text, | |
| # images=image, | |
| # return_tensors="pt", | |
| # truncation=False, | |
| # max_length=PROCESSOR_MAX_LENGTH, | |
| # ) | |
| # try: | |
| # inputs = _build_inputs(CAPTION_PROMPT) | |
| # except Exception as exc: | |
| # if "Mismatch in `image` token count" not in str(exc): | |
| # raise AppError("Failed to preprocess image for captioning.", 422) from exc | |
| # inputs = _build_inputs(CAPTION_RETRY_PROMPT) | |
| # inputs = {k: v.to(model_device) for k, v in inputs.items()} | |
| # try: | |
| # with torch.no_grad(): | |
| # outputs = runtime_model.generate( | |
| # **inputs, | |
| # max_new_tokens=MAX_NEW_TOKENS, | |
| # do_sample=False, | |
| # num_beams=1, | |
| # ) | |
| # except Exception as exc: | |
| # raise AppError("Caption generation failed.", 500) from exc | |
| # decoded = runtime_processor.decode(outputs[0], skip_special_tokens=True).strip() | |
| # caption = decoded.split("assistant")[-1].lstrip(":\n ").strip() | |
| # return _finalize_caption(caption) | |
| # def generate_caption_text_safe(image: Image.Image) -> str: | |
| # global _caption_model, _caption_processor, _caption_force_cpu | |
| # try: | |
| # return generate_caption_text(image) | |
| # except Exception as exc: | |
| # msg = str(exc) | |
| # if "CUDA error" not in msg and "device-side assert" not in msg: | |
| # raise | |
| # with _caption_lock: | |
| # _caption_force_cpu = True | |
| # _caption_model = None | |
| # _caption_processor = None | |
| # if torch.cuda.is_available(): | |
| # try: | |
| # torch.cuda.empty_cache() | |
| # except Exception: | |
| # pass | |
| # return generate_caption_text(image) | |
| # def insert_record(collection, payload: dict) -> str: | |
| # try: | |
| # result = collection.insert_one(payload) | |
| # return str(result.inserted_id) | |
| # except PyMongoError as exc: | |
| # raise AppError("MongoDB insert failed.", 503) from exc | |
| # @app.post("/generate-caption") | |
| # async def generate_caption(request: Request): | |
| # _ensure_db_ready() | |
| # try: | |
| # form = await request.form() | |
| # except Exception as exc: | |
| # raise AppError("Invalid request payload.", 422) from exc | |
| # uploads: list[UploadFile | StarletteUploadFile] = [] | |
| # for key in ("files", "files[]", "file"): | |
| # for value in form.getlist(key): | |
| # if isinstance(value, (UploadFile, StarletteUploadFile)): | |
| # uploads.append(value) | |
| # # Fallback for clients that send non-standard multipart keys. | |
| # if not uploads: | |
| # for _, value in form.multi_items(): | |
| # if isinstance(value, (UploadFile, StarletteUploadFile)): | |
| # uploads.append(value) | |
| # if not uploads: | |
| # raise AppError("At least one image is required.", 400) | |
| # if len(uploads) > MAX_IMAGES: | |
| # raise AppError("You can upload a maximum of 5 images.", 400) | |
| # image_captions = [] | |
| # for upload in uploads: | |
| # if upload.content_type and not upload.content_type.startswith("image/"): | |
| # raise AppError("All uploaded files must be images.", 400) | |
| # file_bytes = await upload.read() | |
| # if not file_bytes: | |
| # raise AppError("One of the uploaded images is empty.", 400) | |
| # try: | |
| # image = Image.open(io.BytesIO(file_bytes)).convert("RGB") | |
| # except UnidentifiedImageError as exc: | |
| # raise AppError("One of the uploaded files is not a valid image.", 400) from exc | |
| # except OSError as exc: | |
| # raise AppError("Unable to read one of the uploaded images.", 400) from exc | |
| # caption = generate_caption_text_safe(image) | |
| # if not caption: | |
| # raise AppError("Caption generation produced empty text.", 500) | |
| # image_captions.append({"filename": upload.filename, "caption": caption}) | |
| # caption_texts = [x["caption"] for x in image_captions] | |
| # caption = summarize_captions(caption_texts) | |
| # if not caption: | |
| # raise AppError("Caption summarization produced empty text.", 500) | |
| # mongo_payload = { | |
| # "caption": caption, | |
| # "source_filenames": [item["filename"] for item in image_captions], | |
| # "image_captions": image_captions, | |
| # "images_count": len(image_captions), | |
| # "is_summarized": len(image_captions) > 1, | |
| # "created_at": datetime.now(timezone.utc), | |
| # } | |
| # audio_file_id = insert_record(caption_collection, mongo_payload) | |
| # response_data = {**mongo_payload, "audio_file_id": audio_file_id} | |
| # response_data.pop("_id", None) # Remove ObjectId as it is not JSON serializable | |
| # response_data["created_at"] = response_data["created_at"].isoformat() | |
| # return ok("Caption generated successfully.", response_data) | |