khushalcodiste committed on
Commit
49f8ccd
·
1 Parent(s): afd6ed3

fix: added

Browse files
Files changed (3) hide show
  1. README.md +2 -0
  2. src/model.py +31 -3
  3. src/server.py +4 -1
README.md CHANGED
@@ -11,3 +11,5 @@ pinned: false
11
  Image captioning API using `microsoft/Florence-2-base` with a Python FastAPI backend. Open `/docs` for Swagger UI.
12
 
13
  Speed tuning env vars: `DEFAULT_MAX_TOKENS` (default `64`), `MAX_IMAGE_SIDE` (default `896`), `MAX_MAX_TOKENS` (default `256`), `MODEL_ID` (default `microsoft/Florence-2-base`), `MODEL_REVISION` (pin to a commit SHA, e.g. `5ca5edf5bd017b9919c05d08aebef5e4c7ac3bac`).
 
 
 
11
  Image captioning API using `microsoft/Florence-2-base` with a Python FastAPI backend. Open `/docs` for Swagger UI.
12
 
13
  Speed tuning env vars: `DEFAULT_MAX_TOKENS` (default `64`), `MAX_IMAGE_SIDE` (default `896`), `MAX_MAX_TOKENS` (default `256`), `MODEL_ID` (default `microsoft/Florence-2-base`), `MODEL_REVISION` (pin to a commit SHA, e.g. `5ca5edf5bd017b9919c05d08aebef5e4c7ac3bac`).
14
+
15
+ `POST /predict` form field `text` is the full Florence-2 task prompt. For standard captioning use `<CAPTION>` only (or omit `text` to use the default). Do not append extra words to `<CAPTION>`.
src/model.py CHANGED
@@ -1,6 +1,7 @@
1
  from __future__ import annotations
2
 
3
  import os
 
4
  from io import BytesIO
5
  from typing import Any
6
 
@@ -16,6 +17,7 @@ MAX_IMAGE_SIDE = int(os.getenv("MAX_IMAGE_SIDE", "896"))
16
  RESIZE_MULTIPLE = int(os.getenv("RESIZE_MULTIPLE", "32"))
17
  NUM_BEAMS = int(os.getenv("NUM_BEAMS", "3"))
18
  DEFAULT_PROMPT = os.getenv("DEFAULT_PROMPT", "<CAPTION>")
 
19
 
20
  _model = None
21
  _processor = None
@@ -66,18 +68,44 @@ def load_model() -> tuple[Any, Any]:
66
  return _model, _processor
67
 
68
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
  def generate_caption(
70
  image_bytes: bytes,
71
  text_input: str | None = None,
72
  max_tokens: int = DEFAULT_MAX_TOKENS,
73
  ) -> dict[str, Any]:
74
  model, processor = load_model()
75
- prompt = f"{DEFAULT_PROMPT} {text_input.strip()}" if text_input else DEFAULT_PROMPT
76
 
77
  safe_max_tokens = min(max(int(max_tokens), 8), MAX_MAX_TOKENS)
78
  image = _prepare_image(image_bytes)
79
 
80
- inputs = processor(text=prompt, images=image, return_tensors="pt")
 
 
 
 
 
 
81
  input_ids = inputs["input_ids"].to(_device)
82
  pixel_values = inputs["pixel_values"].to(_device, _dtype)
83
 
@@ -98,7 +126,7 @@ def generate_caption(
98
  try:
99
  parsed = post_process(
100
  generated_text,
101
- task=DEFAULT_PROMPT,
102
  image_size=(image.width, image.height),
103
  )
104
  except Exception:
 
1
  from __future__ import annotations
2
 
3
  import os
4
+ import re
5
  from io import BytesIO
6
  from typing import Any
7
 
 
17
  RESIZE_MULTIPLE = int(os.getenv("RESIZE_MULTIPLE", "32"))
18
  NUM_BEAMS = int(os.getenv("NUM_BEAMS", "3"))
19
  DEFAULT_PROMPT = os.getenv("DEFAULT_PROMPT", "<CAPTION>")
20
+ TASK_TOKEN_PATTERN = re.compile(r"^<[^>\s]+>")
21
 
22
  _model = None
23
  _processor = None
 
68
  return _model, _processor
69
 
70
 
71
+ def _build_prompt(text_input: str | None) -> str:
72
+ if text_input is None:
73
+ return DEFAULT_PROMPT
74
+
75
+ prompt = text_input.strip()
76
+ if not prompt:
77
+ return DEFAULT_PROMPT
78
+ if not prompt.startswith("<"):
79
+ raise ValueError(
80
+ "Invalid prompt in `text`: expected a Florence-2 task token like "
81
+ "'<CAPTION>' or '<CAPTION_TO_PHRASE_GROUNDING>phrase'."
82
+ )
83
+ return prompt
84
+
85
+
86
def _task_token_from_prompt(prompt: str) -> str:
    """Return the leading task token (e.g. ``'<CAPTION>'``) of *prompt*.

    Falls back to ``DEFAULT_PROMPT`` when no token is found at the start,
    so post-processing always receives a valid Florence-2 task name.
    """
    head = TASK_TOKEN_PATTERN.match(prompt)
    if head is None:
        return DEFAULT_PROMPT
    return head.group(0)
89
+
90
+
91
  def generate_caption(
92
  image_bytes: bytes,
93
  text_input: str | None = None,
94
  max_tokens: int = DEFAULT_MAX_TOKENS,
95
  ) -> dict[str, Any]:
96
  model, processor = load_model()
97
+ prompt = _build_prompt(text_input)
98
 
99
  safe_max_tokens = min(max(int(max_tokens), 8), MAX_MAX_TOKENS)
100
  image = _prepare_image(image_bytes)
101
 
102
+ try:
103
+ inputs = processor(text=prompt, images=image, return_tensors="pt")
104
+ except AssertionError as exc:
105
+ raise ValueError(
106
+ "Invalid Florence-2 task format in `text`. For plain captioning, use only "
107
+ "'<CAPTION>' with no extra words."
108
+ ) from exc
109
  input_ids = inputs["input_ids"].to(_device)
110
  pixel_values = inputs["pixel_values"].to(_device, _dtype)
111
 
 
126
  try:
127
  parsed = post_process(
128
  generated_text,
129
+ task=_task_token_from_prompt(prompt),
130
  image_size=(image.width, image.height),
131
  )
132
  except Exception:
src/server.py CHANGED
@@ -45,7 +45,10 @@ async def predict(
45
  if not image_bytes:
46
  raise HTTPException(status_code=400, detail="Empty file uploaded")
47
 
48
- result = generate_caption(image_bytes, text, max_tokens)
 
 
 
49
  return {"result": result}
50
 
51
 
 
45
  if not image_bytes:
46
  raise HTTPException(status_code=400, detail="Empty file uploaded")
47
 
48
+ try:
49
+ result = generate_caption(image_bytes, text, max_tokens)
50
+ except ValueError as exc:
51
+ raise HTTPException(status_code=400, detail=str(exc)) from exc
52
  return {"result": result}
53
 
54