khushalcodiste commited on
Commit
afd6ed3
·
1 Parent(s): da2a069
Files changed (2) hide show
  1. src/model.py +3 -15
  2. src/server.py +7 -67
src/model.py CHANGED
@@ -15,17 +15,7 @@ MAX_MAX_TOKENS = int(os.getenv("MAX_MAX_TOKENS", "256"))
15
  MAX_IMAGE_SIDE = int(os.getenv("MAX_IMAGE_SIDE", "896"))
16
  RESIZE_MULTIPLE = int(os.getenv("RESIZE_MULTIPLE", "32"))
17
  NUM_BEAMS = int(os.getenv("NUM_BEAMS", "3"))
18
-
19
- TASKS = {
20
- "caption": "<CAPTION>",
21
- "detailed_caption": "<DETAILED_CAPTION>",
22
- "more_detailed_caption": "<MORE_DETAILED_CAPTION>",
23
- "ocr": "<OCR>",
24
- "ocr_with_region": "<OCR_WITH_REGION>",
25
- "object_detection": "<OD>",
26
- "dense_region_caption": "<DENSE_REGION_CAPTION>",
27
- "region_proposal": "<REGION_PROPOSAL>",
28
- }
29
 
30
  _model = None
31
  _processor = None
@@ -78,13 +68,11 @@ def load_model() -> tuple[Any, Any]:
78
 
79
  def generate_caption(
80
  image_bytes: bytes,
81
- task: str = "caption",
82
  text_input: str | None = None,
83
  max_tokens: int = DEFAULT_MAX_TOKENS,
84
  ) -> dict[str, Any]:
85
  model, processor = load_model()
86
- prompt_task = TASKS.get(task, TASKS["caption"])
87
- prompt = f"{prompt_task} {text_input.strip()}" if text_input else prompt_task
88
 
89
  safe_max_tokens = min(max(int(max_tokens), 8), MAX_MAX_TOKENS)
90
  image = _prepare_image(image_bytes)
@@ -110,7 +98,7 @@ def generate_caption(
110
  try:
111
  parsed = post_process(
112
  generated_text,
113
- task=prompt_task,
114
  image_size=(image.width, image.height),
115
  )
116
  except Exception:
 
15
  MAX_IMAGE_SIDE = int(os.getenv("MAX_IMAGE_SIDE", "896"))
16
  RESIZE_MULTIPLE = int(os.getenv("RESIZE_MULTIPLE", "32"))
17
  NUM_BEAMS = int(os.getenv("NUM_BEAMS", "3"))
18
+ DEFAULT_PROMPT = os.getenv("DEFAULT_PROMPT", "<CAPTION>")
 
 
 
 
 
 
 
 
 
 
19
 
20
  _model = None
21
  _processor = None
 
68
 
69
  def generate_caption(
70
  image_bytes: bytes,
 
71
  text_input: str | None = None,
72
  max_tokens: int = DEFAULT_MAX_TOKENS,
73
  ) -> dict[str, Any]:
74
  model, processor = load_model()
75
+ prompt = f"{DEFAULT_PROMPT} {text_input.strip()}" if text_input else DEFAULT_PROMPT
 
76
 
77
  safe_max_tokens = min(max(int(max_tokens), 8), MAX_MAX_TOKENS)
78
  image = _prepare_image(image_bytes)
 
98
  try:
99
  parsed = post_process(
100
  generated_text,
101
+ task=DEFAULT_PROMPT,
102
  image_size=(image.width, image.height),
103
  )
104
  except Exception:
src/server.py CHANGED
@@ -5,13 +5,12 @@ from typing import Any
5
 
6
  from fastapi import FastAPI, File, Form, HTTPException, UploadFile
7
  from fastapi.middleware.cors import CORSMiddleware
8
- from fastapi.responses import HTMLResponse
9
 
10
- from .model import MODEL_ID, TASKS, DEFAULT_MAX_TOKENS, generate_caption, load_model
11
 
12
  app = FastAPI(
13
  title="img3txt - Florence-2 API",
14
- description="Generate captions, OCR, object detection and more from images using Florence-2.",
15
  version="1.0.0",
16
  )
17
 
@@ -29,84 +28,25 @@ def warmup_model() -> None:
29
  load_model()
30
 
31
 
32
- @app.get("/", response_class=HTMLResponse, include_in_schema=False)
33
- def root() -> str:
34
- return """<!DOCTYPE html>
35
- <html lang=\"en\"><head><meta charset=\"utf-8\">
36
- <meta name=\"viewport\" content=\"width=device-width,initial-scale=1\">
37
- <title>img3txt - Florence-2 Image Captioning API</title>
38
- <style>
39
- *{margin:0;padding:0;box-sizing:border-box}
40
- body{font-family:system-ui,sans-serif;background:#0f172a;color:#e2e8f0;display:flex;align-items:center;justify-content:center;min-height:100vh}
41
- .card{background:#1e293b;border-radius:16px;padding:2.5rem;max-width:520px;width:90%;text-align:center;box-shadow:0 25px 50px rgba(0,0,0,.4)}
42
- h1{font-size:1.8rem;margin-bottom:.5rem}
43
- .sub{color:#94a3b8;margin-bottom:1.5rem}
44
- .btn{display:inline-block;padding:.75rem 1.5rem;background:#3b82f6;color:#fff;border-radius:8px;text-decoration:none;font-weight:600;margin:.25rem}
45
- .btn:hover{background:#2563eb}
46
- .tasks{margin-top:1.5rem;text-align:left;background:#0f172a;border-radius:8px;padding:1rem}
47
- .tasks code{color:#38bdf8}
48
- </style></head><body>
49
- <div class=\"card\">
50
- <h1>img3txt</h1>
51
- <p class=\"sub\">Image captioning, OCR &amp; object detection powered by Florence-2</p>
52
- <a class=\"btn\" href=\"/docs\">Swagger UI</a>
53
- <a class=\"btn\" href=\"/health\">Health Check</a>
54
- <div class=\"tasks\">
55
- <p><strong>POST /caption</strong> with form fields:</p>
56
- <ul style=\"margin:.5rem 0 0 1.2rem;color:#94a3b8\">
57
- <li><code>file</code> - image (required)</li>
58
- <li><code>task</code> - caption, detailed_caption, more_detailed_caption, ocr, ocr_with_region, object_detection, dense_region_caption, region_proposal</li>
59
- <li><code>max_tokens</code> - default 64 (smaller = faster)</li>
60
- </ul>
61
- </div>
62
- </div></body></html>"""
63
 
64
 
65
  @app.get("/health")
66
  def health() -> dict[str, Any]:
67
- return {"status": "ok", "model": MODEL_ID, "tasks": list(TASKS.keys())}
68
 
69
 
70
- @app.post("/caption")
71
- async def caption(
72
  file: UploadFile = File(...),
73
- task: str = Form("caption"),
74
  text: str | None = Form(None),
75
  max_tokens: int = Form(DEFAULT_MAX_TOKENS),
76
  ) -> dict[str, Any]:
77
- if task not in TASKS:
78
- raise HTTPException(status_code=400, detail=f"Invalid task. Choose from: {', '.join(TASKS.keys())}")
79
-
80
  image_bytes = await file.read()
81
  if not image_bytes:
82
  raise HTTPException(status_code=400, detail="Empty file uploaded")
83
 
84
- result = generate_caption(image_bytes, task, text, max_tokens)
85
- return {"task": task, "result": result}
86
-
87
-
88
- @app.post("/caption/batch")
89
- async def caption_batch(
90
- files: list[UploadFile] = File(...),
91
- task: str = Form("caption"),
92
- text: str | None = Form(None),
93
- max_tokens: int = Form(DEFAULT_MAX_TOKENS),
94
- ) -> dict[str, Any]:
95
- if task not in TASKS:
96
- raise HTTPException(status_code=400, detail=f"Invalid task. Choose from: {', '.join(TASKS.keys())}")
97
-
98
- results: list[dict[str, Any]] = []
99
- for upload in files:
100
- image_bytes = await upload.read()
101
- if not image_bytes:
102
- continue
103
- result = generate_caption(image_bytes, task, text, max_tokens)
104
- results.append({"filename": upload.filename, "task": task, "result": result})
105
-
106
- if not results:
107
- raise HTTPException(status_code=400, detail="No files uploaded")
108
-
109
- return {"results": results}
110
 
111
 
112
  if __name__ == "__main__":
 
5
 
6
  from fastapi import FastAPI, File, Form, HTTPException, UploadFile
7
  from fastapi.middleware.cors import CORSMiddleware
 
8
 
9
+ from .model import MODEL_ID, DEFAULT_MAX_TOKENS, generate_caption, load_model
10
 
11
  app = FastAPI(
12
  title="img3txt - Florence-2 API",
13
+ description="Simple image-to-text endpoint powered by Florence-2-base.",
14
  version="1.0.0",
15
  )
16
 
 
28
  load_model()
29
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
 
33
  @app.get("/health")
34
  def health() -> dict[str, Any]:
35
+ return {"status": "ok", "model": MODEL_ID}
36
 
37
 
38
+ @app.post("/predict")
39
+ async def predict(
40
  file: UploadFile = File(...),
 
41
  text: str | None = Form(None),
42
  max_tokens: int = Form(DEFAULT_MAX_TOKENS),
43
  ) -> dict[str, Any]:
 
 
 
44
  image_bytes = await file.read()
45
  if not image_bytes:
46
  raise HTTPException(status_code=400, detail="Empty file uploaded")
47
 
48
+ result = generate_caption(image_bytes, text, max_tokens)
49
+ return {"result": result}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
 
52
  if __name__ == "__main__":