khushalcodiste committed on
Commit
210def2
·
1 Parent(s): 641b32e

fix: added

Browse files
Files changed (4) hide show
  1. README.md +1 -1
  2. docker-compose.yml +2 -0
  3. requirements.txt +2 -0
  4. src/model.py +34 -9
README.md CHANGED
@@ -10,4 +10,4 @@ pinned: false
10
 
11
  Image captioning API using `microsoft/Florence-2-base` with a Python FastAPI backend. Open `/docs` for Swagger UI.
12
 
13
- Speed tuning env vars: `DEFAULT_MAX_TOKENS` (default `64`), `MAX_IMAGE_SIDE` (default `896`), `MAX_MAX_TOKENS` (default `256`), `MODEL_ID` (default `microsoft/Florence-2-base`).
 
10
 
11
  Image captioning API using `microsoft/Florence-2-base` with a Python FastAPI backend. Open `/docs` for Swagger UI.
12
 
13
+ Speed tuning env vars: `DEFAULT_MAX_TOKENS` (default `64`), `MAX_IMAGE_SIDE` (default `896`), `MAX_MAX_TOKENS` (default `256`), `MODEL_ID` (default `microsoft/Florence-2-base`), `MODEL_REVISION` (optional commit SHA to pin remote model code).
docker-compose.yml CHANGED
@@ -9,4 +9,6 @@ services:
9
  - MAX_IMAGE_SIDE=896
10
  - MAX_MAX_TOKENS=256
11
  - MODEL_ID=microsoft/Florence-2-base
 
 
12
  restart: unless-stopped
 
9
  - MAX_IMAGE_SIDE=896
10
  - MAX_MAX_TOKENS=256
11
  - MODEL_ID=microsoft/Florence-2-base
12
+ # Optional: pin to a specific commit SHA from huggingface.co/microsoft/Florence-2-base
13
+ # - MODEL_REVISION=<commit_sha>
14
  restart: unless-stopped
requirements.txt CHANGED
@@ -4,3 +4,5 @@ transformers==4.55.4
4
  torch==2.8.0
5
  pillow==11.3.0
6
  python-multipart==0.0.20
 
 
 
4
  torch==2.8.0
5
  pillow==11.3.0
6
  python-multipart==0.0.20
7
+ einops==0.8.1
8
+ timm==1.0.19
src/model.py CHANGED
@@ -9,9 +9,12 @@ from PIL import Image
9
  from transformers import AutoModelForCausalLM, AutoProcessor
10
 
11
  MODEL_ID = os.getenv("MODEL_ID", "microsoft/Florence-2-base")
 
12
  DEFAULT_MAX_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "64"))
13
  MAX_MAX_TOKENS = int(os.getenv("MAX_MAX_TOKENS", "256"))
14
  MAX_IMAGE_SIDE = int(os.getenv("MAX_IMAGE_SIDE", "896"))
 
 
15
 
16
  TASKS = {
17
  "caption": "<CAPTION>",
@@ -26,8 +29,8 @@ TASKS = {
26
 
27
  _model = None
28
  _processor = None
29
- _device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
- _dtype = torch.float16 if _device.type == "cuda" else torch.float32
31
 
32
 
33
  def _prepare_image(image_bytes: bytes) -> Image.Image:
@@ -36,19 +39,37 @@ def _prepare_image(image_bytes: bytes) -> Image.Image:
36
  if width <= MAX_IMAGE_SIDE and height <= MAX_IMAGE_SIDE:
37
  return image
38
 
39
- ratio = min(MAX_IMAGE_SIDE / width, MAX_IMAGE_SIDE / height)
40
- new_size = (max(1, int(width * ratio)), max(1, int(height * ratio)))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
  return image.resize(new_size, Image.Resampling.LANCZOS)
42
 
43
 
44
  def load_model() -> tuple[Any, Any]:
45
  global _model, _processor
46
  if _model is None or _processor is None:
47
- _processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
 
 
 
 
48
  _model = AutoModelForCausalLM.from_pretrained(
49
  MODEL_ID,
50
- trust_remote_code=True,
51
  torch_dtype=_dtype,
 
52
  ).to(_device)
53
  _model.eval()
54
  return _model, _processor
@@ -77,16 +98,20 @@ def generate_caption(
77
  pixel_values=pixel_values,
78
  do_sample=False,
79
  max_new_tokens=safe_max_tokens,
80
- num_beams=1,
81
  )
82
 
83
- generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
84
 
85
  parsed = None
86
  post_process = getattr(processor, "post_process_generation", None)
87
  if callable(post_process):
88
  try:
89
- parsed = post_process(generated_text, task=prompt_task, image_size=image.size)
 
 
 
 
90
  except Exception:
91
  parsed = None
92
 
 
9
  from transformers import AutoModelForCausalLM, AutoProcessor
10
 
11
  MODEL_ID = os.getenv("MODEL_ID", "microsoft/Florence-2-base")
12
+ MODEL_REVISION = os.getenv("MODEL_REVISION")
13
  DEFAULT_MAX_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "64"))
14
  MAX_MAX_TOKENS = int(os.getenv("MAX_MAX_TOKENS", "256"))
15
  MAX_IMAGE_SIDE = int(os.getenv("MAX_IMAGE_SIDE", "896"))
16
+ RESIZE_MULTIPLE = int(os.getenv("RESIZE_MULTIPLE", "32"))
17
+ NUM_BEAMS = int(os.getenv("NUM_BEAMS", "3"))
18
 
19
  TASKS = {
20
  "caption": "<CAPTION>",
 
29
 
30
  _model = None
31
  _processor = None
32
+ _device = torch.device("cpu")
33
+ _dtype = torch.float32
34
 
35
 
36
  def _prepare_image(image_bytes: bytes) -> Image.Image:
 
39
  if width <= MAX_IMAGE_SIDE and height <= MAX_IMAGE_SIDE:
40
  return image
41
 
42
+ if width >= height:
43
+ # Landscape: cap width, preserve aspect ratio.
44
+ ratio = MAX_IMAGE_SIDE / width
45
+ else:
46
+ # Portrait: cap height, preserve aspect ratio.
47
+ ratio = MAX_IMAGE_SIDE / height
48
+
49
+ new_w = max(1, int(width * ratio))
50
+ new_h = max(1, int(height * ratio))
51
+
52
+ # Align dimensions to multiples of RESIZE_MULTIPLE for tensor-core-friendly shapes.
53
+ if RESIZE_MULTIPLE > 1:
54
+ new_w = max(RESIZE_MULTIPLE, (new_w // RESIZE_MULTIPLE) * RESIZE_MULTIPLE)
55
+ new_h = max(RESIZE_MULTIPLE, (new_h // RESIZE_MULTIPLE) * RESIZE_MULTIPLE)
56
+
57
+ new_size = (new_w, new_h)
58
  return image.resize(new_size, Image.Resampling.LANCZOS)
59
 
60
 
61
  def load_model() -> tuple[Any, Any]:
62
  global _model, _processor
63
  if _model is None or _processor is None:
64
+ pretrained_kwargs: dict[str, Any] = {"trust_remote_code": True}
65
+ if MODEL_REVISION:
66
+ pretrained_kwargs["revision"] = MODEL_REVISION
67
+
68
+ _processor = AutoProcessor.from_pretrained(MODEL_ID, **pretrained_kwargs)
69
  _model = AutoModelForCausalLM.from_pretrained(
70
  MODEL_ID,
 
71
  torch_dtype=_dtype,
72
+ **pretrained_kwargs,
73
  ).to(_device)
74
  _model.eval()
75
  return _model, _processor
 
98
  pixel_values=pixel_values,
99
  do_sample=False,
100
  max_new_tokens=safe_max_tokens,
101
+ num_beams=max(1, NUM_BEAMS),
102
  )
103
 
104
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0].strip()
105
 
106
  parsed = None
107
  post_process = getattr(processor, "post_process_generation", None)
108
  if callable(post_process):
109
  try:
110
+ parsed = post_process(
111
+ generated_text,
112
+ task=prompt_task,
113
+ image_size=(image.width, image.height),
114
+ )
115
  except Exception:
116
  parsed = None
117