vidhi0405 committed on
Commit
7014644
·
1 Parent(s): 0ffe62a

only for Image to Text

Browse files
Files changed (2) hide show
  1. app.py +158 -6
  2. requirements.txt +1 -0
app.py CHANGED
@@ -3,6 +3,7 @@ import logging
3
  import os
4
  import re
5
  import threading
 
6
 
7
  # Avoid invalid OMP setting from runtime environment (e.g. empty/non-numeric).
8
  _omp_threads = os.getenv("OMP_NUM_THREADS", "").strip()
@@ -15,16 +16,26 @@ from fastapi import FastAPI, File, UploadFile
15
  from fastapi.exceptions import RequestValidationError
16
  from fastapi.responses import JSONResponse
17
  from PIL import Image, UnidentifiedImageError
18
- from transformers import AutoModelForImageTextToText, AutoProcessor
 
 
 
 
 
 
 
19
 
20
 
21
  load_dotenv()
22
 
23
  CAPTION_MODEL_ID = os.getenv("CAPTION_MODEL_ID", "vidhi0405/Qwen_I2T")
 
24
  DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
25
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
26
  MAX_NEW_TOKENS = 120
27
  MAX_IMAGES = 5
 
 
28
 
29
  CAPTION_PROMPT = (
30
  "Act as a professional news reporter delivering a live on-scene report in real time. "
@@ -70,9 +81,30 @@ _caption_model = None
70
  _caption_processor = None
71
  _caption_lock = threading.Lock()
72
  _caption_force_cpu = False
 
 
 
73
 
74
  app = FastAPI(title="Image to Text API")
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
 
77
  @app.get("/")
78
  def root():
@@ -85,7 +117,40 @@ def root():
85
 
86
  @app.get("/health")
87
  def health():
88
- return {"success": True, "message": "ok", "data": {"caption_model_id": CAPTION_MODEL_ID}}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
 
91
  @app.exception_handler(AppError)
@@ -104,6 +169,11 @@ async def unhandled_error_handler(_, exc: Exception):
104
  return fail("Internal server error.", 500)
105
 
106
 
 
 
 
 
 
107
  def _finalize_caption(raw_text: str) -> str:
108
  text = " ".join(raw_text.split()).strip()
109
  if not text:
@@ -149,6 +219,59 @@ def _get_caption_runtime():
149
  return _caption_model, _caption_processor
150
 
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
  def generate_caption_text(image: Image.Image) -> str:
153
  runtime_model, runtime_processor = _get_caption_runtime()
154
  model_device = str(next(runtime_model.parameters()).device)
@@ -222,11 +345,21 @@ def generate_caption_text_safe(image: Image.Image) -> str:
222
  return generate_caption_text(image)
223
 
224
 
 
 
 
 
 
 
 
 
225
  @app.post("/generate-caption")
226
  async def generate_caption(
227
  file: UploadFile | None = File(default=None),
228
  files: list[UploadFile] | None = File(default=None),
229
  ):
 
 
230
  uploads = []
231
  if files:
232
  uploads.extend(files)
@@ -259,11 +392,30 @@ async def generate_caption(
259
 
260
  image_captions.append({"filename": upload.filename, "caption": caption})
261
 
262
- return ok(
263
- "Caption generated successfully.",
 
 
 
 
 
264
  {
265
- "caption": image_captions[0]["caption"] if len(image_captions) == 1 else None,
266
- "individual_captions": image_captions,
 
267
  "images_count": len(image_captions),
 
 
268
  },
269
  )
 
 
 
 
 
 
 
 
 
 
 
 
3
  import os
4
  import re
5
  import threading
6
+ from datetime import datetime, timezone
7
 
8
  # Avoid invalid OMP setting from runtime environment (e.g. empty/non-numeric).
9
  _omp_threads = os.getenv("OMP_NUM_THREADS", "").strip()
 
16
  from fastapi.exceptions import RequestValidationError
17
  from fastapi.responses import JSONResponse
18
  from PIL import Image, UnidentifiedImageError
19
+ from pymongo import MongoClient
20
+ from pymongo.errors import PyMongoError, ServerSelectionTimeoutError
21
+ from transformers import (
22
+ AutoModelForImageTextToText,
23
+ AutoModelForSeq2SeqLM,
24
+ AutoProcessor,
25
+ AutoTokenizer,
26
+ )
27
 
28
 
29
  load_dotenv()
30
 
31
  CAPTION_MODEL_ID = os.getenv("CAPTION_MODEL_ID", "vidhi0405/Qwen_I2T")
32
+ SUMMARIZER_MODEL_ID = os.getenv("SUMMARIZER_MODEL_ID", "facebook/bart-large-cnn")
33
  DEVICE = os.getenv("DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
34
  DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32
35
  MAX_NEW_TOKENS = 120
36
  MAX_IMAGES = 5
37
+ MONGO_URI = (os.getenv("MONGO_URI") or os.getenv("MONGODB_URI") or "").strip().strip('"').strip("'")
38
+ MONGO_DB_NAME = os.getenv("MONGO_DB_NAME", "image_to_speech")
39
 
40
  CAPTION_PROMPT = (
41
  "Act as a professional news reporter delivering a live on-scene report in real time. "
 
81
  _caption_processor = None
82
  _caption_lock = threading.Lock()
83
  _caption_force_cpu = False
84
+ _summarizer_model = None
85
+ _summarizer_tokenizer = None
86
+ _summarizer_lock = threading.Lock()
87
 
88
  app = FastAPI(title="Image to Text API")
89
 
90
# Module-level MongoDB bootstrap: attempt the connection once at import time
# and record any failure in ``db_init_error`` so request handlers can report
# a 503 instead of the whole app failing to start.
mongo_client = None
mongo_db = None
caption_collection = None
db_init_error = None

if MONGO_URI:
    try:
        mongo_client = MongoClient(MONGO_URI, serverSelectionTimeoutMS=5000)
        # Force a server round-trip so connection problems surface here,
        # not on the first insert.
        mongo_client.admin.command("ping")
        mongo_db = mongo_client[MONGO_DB_NAME]
        caption_collection = mongo_db["captions"]
    except ServerSelectionTimeoutError:
        db_init_error = "Unable to connect to MongoDB (timeout)."
    except PyMongoError as exc:
        db_init_error = "Unable to initialize MongoDB: {}".format(exc)
else:
    db_init_error = "MONGO_URI (or MONGODB_URI) is not set."
107
+
108
 
109
  @app.get("/")
110
  def root():
 
117
 
118
@app.get("/health")
def health():
    """Liveness probe: reports whether the MongoDB bootstrap succeeded.

    Always answers (never raises) so orchestrators can poll it; a failed DB
    init is surfaced via ``success: False`` plus the stored error message.
    """
    model_info = {
        "caption_model_id": CAPTION_MODEL_ID,
        "summarizer_model_id": SUMMARIZER_MODEL_ID,
    }
    if db_init_error:
        return {"success": False, "message": db_init_error, "data": model_info}
    return {"success": True, "message": "ok", "data": model_info}
137
+
138
+
139
@app.on_event("startup")
async def preload_runtime_models():
    """Warm both model runtimes at startup so the first request isn't slow.

    Preloading is opt-out via the PRELOAD_MODELS env var; load failures are
    logged but never abort startup (the lazy loaders will retry per request).
    """
    flag = os.getenv("PRELOAD_MODELS", "1").strip().lower()
    if flag in {"0", "false", "no"}:
        logger.info("Model preloading disabled via PRELOAD_MODELS.")
        return
    try:
        _get_caption_runtime()
    except Exception as exc:
        logger.warning("Caption model preload failed: %s", exc)
    else:
        logger.info("Caption model preloaded successfully.")
    try:
        _get_summarizer_runtime()
    except Exception as exc:
        logger.warning("Summarizer model preload failed: %s", exc)
    else:
        logger.info("Summarizer model preloaded successfully.")
154
 
155
 
156
  @app.exception_handler(AppError)
 
169
  return fail("Internal server error.", 500)
170
 
171
 
172
def _ensure_db_ready():
    """Abort the current request with 503 if the MongoDB bootstrap failed."""
    if not db_init_error:
        return
    raise AppError(db_init_error, 503)
175
+
176
+
177
  def _finalize_caption(raw_text: str) -> str:
178
  text = " ".join(raw_text.split()).strip()
179
  if not text:
 
219
  return _caption_model, _caption_processor
220
 
221
 
222
def _get_summarizer_runtime():
    """Lazily load and cache the summarizer; returns ``(model, tokenizer)``.

    Uses double-checked locking: a lock-free fast path when both pieces are
    already cached, then a re-check under ``_summarizer_lock`` so concurrent
    first callers load the model only once.

    Raises:
        AppError: 503 when the model or tokenizer cannot be loaded.
    """
    global _summarizer_model, _summarizer_tokenizer

    # Fast path: both pieces already cached.
    if _summarizer_model is not None and _summarizer_tokenizer is not None:
        return _summarizer_model, _summarizer_tokenizer

    with _summarizer_lock:
        # Re-check under the lock: another thread may have finished loading.
        if _summarizer_model is None or _summarizer_tokenizer is None:
            try:
                tok = AutoTokenizer.from_pretrained(SUMMARIZER_MODEL_ID)
                mdl = AutoModelForSeq2SeqLM.from_pretrained(SUMMARIZER_MODEL_ID)
            except Exception as exc:
                raise AppError("Failed to load summarization model.", 503) from exc
            mdl.eval()  # inference mode: disables dropout etc.
            _summarizer_tokenizer = tok
            _summarizer_model = mdl

    return _summarizer_model, _summarizer_tokenizer
239
+
240
+
241
def summarize_captions(captions: list[str]) -> str:
    """Condense multiple per-image captions into one narration string.

    Fast paths: returns "" for no usable input and the lone caption
    unchanged when only one was produced; otherwise joins the captions and
    runs the seq2seq summarizer over the combined text.

    Raises:
        AppError: 500 when summarization fails, 503 when the model can't load.
    """
    if not captions:
        return ""
    if len(captions) == 1:
        return captions[0]

    model, tokenizer = _get_summarizer_runtime()
    cleaned = [c.strip() for c in captions if c and c.strip()]
    combined = " ".join(cleaned)
    if not combined:
        return ""

    try:
        encoded = tokenizer(
            combined,
            max_length=1024,  # model context limit; longer input is truncated
            truncation=True,
            return_tensors="pt",
        )
        with torch.no_grad():
            generated = model.generate(
                **encoded,
                max_length=150,
                min_length=40,
                length_penalty=2.0,
                num_beams=4,
                early_stopping=True,
            )
        summary = tokenizer.decode(generated[0], skip_special_tokens=True).strip()
    except Exception as exc:
        raise AppError("Failed to summarize captions.", 500) from exc

    return _finalize_caption(summary)
273
+
274
+
275
  def generate_caption_text(image: Image.Image) -> str:
276
  runtime_model, runtime_processor = _get_caption_runtime()
277
  model_device = str(next(runtime_model.parameters()).device)
 
345
  return generate_caption_text(image)
346
 
347
 
348
def insert_record(collection, payload: dict) -> str:
    """Insert ``payload`` into ``collection`` and return the new id as a string.

    Raises:
        AppError: 503 when the MongoDB write fails.
    """
    try:
        inserted = collection.insert_one(payload)
        return str(inserted.inserted_id)
    except PyMongoError as exc:
        raise AppError("MongoDB insert failed.", 503) from exc
354
+
355
+
356
  @app.post("/generate-caption")
357
  async def generate_caption(
358
  file: UploadFile | None = File(default=None),
359
  files: list[UploadFile] | None = File(default=None),
360
  ):
361
+ _ensure_db_ready()
362
+
363
  uploads = []
364
  if files:
365
  uploads.extend(files)
 
392
 
393
  image_captions.append({"filename": upload.filename, "caption": caption})
394
 
395
+ caption_texts = [x["caption"] for x in image_captions]
396
+ caption = summarize_captions(caption_texts)
397
+ if not caption:
398
+ raise AppError("Caption summarization produced empty text.", 500)
399
+
400
+ audio_file_id = insert_record(
401
+ caption_collection,
402
  {
403
+ "caption": caption,
404
+ "source_filenames": [item["filename"] for item in image_captions],
405
+ "image_captions": image_captions,
406
  "images_count": len(image_captions),
407
+ "is_summarized": len(image_captions) > 1,
408
+ "created_at": datetime.now(timezone.utc),
409
  },
410
  )
411
+
412
+ response_data = {
413
+ "audio_file_id": audio_file_id,
414
+ "caption": caption,
415
+ "images_count": len(image_captions),
416
+ }
417
+ if len(image_captions) > 1:
418
+ response_data["individual_captions"] = image_captions
419
+ response_data["summarized_caption"] = caption
420
+
421
+ return ok("Caption generated successfully.", response_data)
requirements.txt CHANGED
@@ -20,3 +20,4 @@ opencv-python==4.9.0.80
20
  tqdm==4.66.0
21
  requests==2.31.0
22
  python-dotenv==1.0.1
 
 
20
  tqdm==4.66.0
21
  requests==2.31.0
22
  python-dotenv==1.0.1
23
+ pymongo[srv]==4.8.0