khushalcodiste commited on
Commit
641b32e
·
1 Parent(s): b02d5c5

fix: port backend from Node.js (FastVLM ONNX) to Python FastAPI with Florence-2

Browse files
Files changed (11) hide show
  1. .dockerignore +3 -1
  2. .gitignore +3 -0
  3. Dockerfile +13 -15
  4. README.md +2 -2
  5. docker-compose.yml +1 -0
  6. package.json +0 -19
  7. requirements.txt +6 -0
  8. src/model.js +0 -106
  9. src/model.py +93 -0
  10. src/server.js +0 -238
  11. src/server.py +116 -0
.dockerignore CHANGED
@@ -1,3 +1,5 @@
1
- node_modules
2
  .git
3
  .env
 
 
 
 
 
1
  .git
2
  .env
3
+ __pycache__
4
+ *.pyc
5
+ .venv
.gitignore CHANGED
@@ -1 +1,4 @@
1
  token.txt
 
 
 
 
1
  token.txt
2
+ __pycache__/
3
+ *.pyc
4
+ .venv/
Dockerfile CHANGED
@@ -1,25 +1,23 @@
1
- FROM node:22-slim
2
 
3
- # sharp needs libvips
4
- RUN apt-get update && \
5
- apt-get install -y --no-install-recommends libvips-dev && \
6
- rm -rf /var/lib/apt/lists/*
7
 
8
  WORKDIR /app
9
 
10
- COPY package.json ./
11
- RUN npm install --omit=dev
12
-
13
- COPY src/ src/
14
 
15
- # Give node user ownership of everything (including node_modules/.cache)
16
- RUN chown -R node:node /app
17
 
18
- USER node
 
19
 
20
- # Download model at build time so container starts fast
21
- RUN node -e "import('./src/model.js').then(m => m.loadModel()).then(() => process.exit(0))"
22
 
23
  EXPOSE 7860
24
 
25
- CMD ["node", "src/server.js"]
 
1
+ FROM python:3.11-slim
2
 
3
+ ENV PYTHONDONTWRITEBYTECODE=1
4
+ ENV PYTHONUNBUFFERED=1
 
 
5
 
6
  WORKDIR /app
7
 
8
+ RUN apt-get update && \
9
+ apt-get install -y --no-install-recommends libgl1 libglib2.0-0 && \
10
+ rm -rf /var/lib/apt/lists/*
 
11
 
12
+ COPY requirements.txt ./
13
+ RUN pip install --no-cache-dir -r requirements.txt
14
 
15
+ COPY src/ src/
16
+ COPY README.md ./
17
 
18
+ # Download model weights at build time so cold start is faster.
19
+ RUN python -c "from src.model import load_model; load_model()"
20
 
21
  EXPOSE 7860
22
 
23
+ CMD ["uvicorn", "src.server:app", "--host", "0.0.0.0", "--port", "7860"]
README.md CHANGED
@@ -8,6 +8,6 @@ app_port: 7860
8
  pinned: false
9
  ---
10
 
11
- Image captioning API using FastVLM (ONNX). Open `/docs` for Swagger UI.
12
 
13
- Speed tuning env vars: `DEFAULT_MAX_TOKENS` (default `64`), `MAX_IMAGE_SIDE` (default `896`), `MAX_MAX_TOKENS` (default `256`).
 
8
  pinned: false
9
  ---
10
 
11
+ Image captioning API using `microsoft/Florence-2-base` with a Python FastAPI backend. Open `/docs` for Swagger UI.
12
 
13
+ Speed tuning env vars: `DEFAULT_MAX_TOKENS` (default `64`), `MAX_IMAGE_SIDE` (default `896`), `MAX_MAX_TOKENS` (default `256`), `MODEL_ID` (default `microsoft/Florence-2-base`).
docker-compose.yml CHANGED
@@ -8,4 +8,5 @@ services:
8
  - DEFAULT_MAX_TOKENS=64
9
  - MAX_IMAGE_SIDE=896
10
  - MAX_MAX_TOKENS=256
 
11
  restart: unless-stopped
 
8
  - DEFAULT_MAX_TOKENS=64
9
  - MAX_IMAGE_SIDE=896
10
  - MAX_MAX_TOKENS=256
11
+ - MODEL_ID=microsoft/Florence-2-base
12
  restart: unless-stopped
package.json DELETED
@@ -1,19 +0,0 @@
1
- {
2
- "name": "img3txt",
3
- "version": "1.0.0",
4
- "description": "Image captioning API using FastVLM ONNX model",
5
- "type": "module",
6
- "scripts": {
7
- "start": "node src/server.js",
8
- "dev": "node --watch src/server.js"
9
- },
10
- "dependencies": {
11
- "@huggingface/transformers": "^3.4.1",
12
- "fastify": "^5.2.1",
13
- "@fastify/multipart": "^9.0.3",
14
- "@fastify/swagger": "^9.4.2",
15
- "@fastify/swagger-ui": "^5.2.1",
16
- "@fastify/cors": "^10.0.2",
17
- "sharp": "^0.33.5"
18
- }
19
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ fastapi==0.116.1
2
+ uvicorn[standard]==0.35.0
3
+ transformers==4.55.4
4
+ torch==2.8.0
5
+ pillow==11.3.0
6
+ python-multipart==0.0.20
src/model.js DELETED
@@ -1,106 +0,0 @@
1
- import {
2
- AutoModelForImageTextToText,
3
- AutoProcessor,
4
- RawImage,
5
- } from "@huggingface/transformers";
6
- import sharp from "sharp";
7
-
8
- const MODEL_ID = "onnx-community/FastVLM-0.5B-ONNX";
9
- const DEFAULT_MAX_TOKENS = parseInt(process.env.DEFAULT_MAX_TOKENS || "64", 10);
10
- const MAX_MAX_TOKENS = parseInt(process.env.MAX_MAX_TOKENS || "256", 10);
11
- const MAX_IMAGE_SIDE = parseInt(process.env.MAX_IMAGE_SIDE || "896", 10);
12
-
13
- let model = null;
14
- let processor = null;
15
-
16
- /** Supported task instructions for FastVLM */
17
- export const TASKS = {
18
- caption: "Describe this image.",
19
- detailed_caption: "Describe this image in detail.",
20
- more_detailed_caption:
21
- "Provide a very detailed description of this image.",
22
- ocr: "Extract all readable text from this image.",
23
- ocr_with_region:
24
- "Extract all readable text and include where it appears in the image.",
25
- object_detection: "List the visible objects in this image.",
26
- dense_region_caption:
27
- "Describe this image region by region with detailed observations.",
28
- region_proposal:
29
- "Propose important regions in this image and explain what each region contains.",
30
- };
31
-
32
- export async function loadModel() {
33
- if (!model) {
34
- console.log("Loading FastVLM model...");
35
- model = await AutoModelForImageTextToText.from_pretrained(MODEL_ID, {
36
- dtype: {
37
- embed_tokens: "fp16",
38
- vision_encoder: "q4",
39
- decoder_model_merged: "q4",
40
- },
41
- });
42
- processor = await AutoProcessor.from_pretrained(MODEL_ID);
43
- console.log("Model loaded.");
44
- }
45
- return { model, processor };
46
- }
47
-
48
- /**
49
- * Generate text from an image buffer.
50
- * @param {Buffer} imageBuffer - Raw image bytes
51
- * @param {string} task - One of the TASKS keys (default: "caption")
52
- * @param {string|null} textInput - Optional extra text input for the task
53
- * @param {number} maxTokens - Max new tokens to generate
54
- * @returns {Promise<object>} Generated result from FastVLM
55
- */
56
- export async function generateCaption(
57
- imageBuffer,
58
- task = "caption",
59
- textInput = null,
60
- maxTokens = DEFAULT_MAX_TOKENS
61
- ) {
62
- const { model: m, processor: p } = await loadModel();
63
-
64
- const safeMaxTokens = Number.isFinite(maxTokens)
65
- ? Math.min(Math.max(maxTokens, 8), MAX_MAX_TOKENS)
66
- : DEFAULT_MAX_TOKENS;
67
-
68
- // Downscale large uploads to reduce encoder latency.
69
- const metadata = await sharp(imageBuffer).metadata();
70
- let preparedBuffer = imageBuffer;
71
- if (
72
- metadata.width &&
73
- metadata.height &&
74
- (metadata.width > MAX_IMAGE_SIDE || metadata.height > MAX_IMAGE_SIDE)
75
- ) {
76
- preparedBuffer = await sharp(imageBuffer)
77
- .resize({
78
- width: MAX_IMAGE_SIDE,
79
- height: MAX_IMAGE_SIDE,
80
- fit: "inside",
81
- withoutEnlargement: true,
82
- })
83
- .toBuffer();
84
- }
85
-
86
- const image = await RawImage.fromBlob(new Blob([preparedBuffer]));
87
-
88
- const baseInstruction = TASKS[task] || TASKS.caption;
89
- const instruction = textInput
90
- ? `${baseInstruction}\nAdditional instruction: ${textInput}`
91
- : baseInstruction;
92
- const messages = [{ role: "user", content: `<image>${instruction}` }];
93
- const prompt = p.apply_chat_template(messages, { add_generation_prompt: true });
94
- const inputs = await p(image, prompt, { add_special_tokens: false });
95
-
96
- const generatedIds = await m.generate({
97
- ...inputs,
98
- do_sample: false,
99
- max_new_tokens: safeMaxTokens,
100
- });
101
-
102
- const generatedText = p.batch_decode(generatedIds, {
103
- skip_special_tokens: true,
104
- })[0];
105
- return { text: generatedText.trim() };
106
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/model.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from __future__ import annotations

import os
from io import BytesIO
from typing import Any

import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

# Model choice and generation limits, all overridable via environment variables
# (see README / docker-compose.yml for the documented defaults).
MODEL_ID = os.getenv("MODEL_ID", "microsoft/Florence-2-base")
DEFAULT_MAX_TOKENS = int(os.getenv("DEFAULT_MAX_TOKENS", "64"))  # default max_new_tokens per request
MAX_MAX_TOKENS = int(os.getenv("MAX_MAX_TOKENS", "256"))  # hard ceiling a request may ask for
MAX_IMAGE_SIDE = int(os.getenv("MAX_IMAGE_SIDE", "896"))  # longest image side before downscaling

# Public task name (API form field) -> Florence-2 task prompt token.
TASKS = {
    "caption": "<CAPTION>",
    "detailed_caption": "<DETAILED_CAPTION>",
    "more_detailed_caption": "<MORE_DETAILED_CAPTION>",
    "ocr": "<OCR>",
    "ocr_with_region": "<OCR_WITH_REGION>",
    "object_detection": "<OD>",
    "dense_region_caption": "<DENSE_REGION_CAPTION>",
    "region_proposal": "<REGION_PROPOSAL>",
}

# Lazily-initialized singletons; populated on first load_model() call.
_model = None
_processor = None
# Prefer GPU with fp16 when available, otherwise CPU with fp32.
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
_dtype = torch.float16 if _device.type == "cuda" else torch.float32
32
+
def _prepare_image(image_bytes: bytes) -> Image.Image:
    """Decode raw bytes to an RGB image, downscaling so no side exceeds MAX_IMAGE_SIDE.

    Aspect ratio is preserved (LANCZOS resampling); images already within
    bounds are returned unchanged.
    """
    img = Image.open(BytesIO(image_bytes)).convert("RGB")
    w, h = img.size
    longest = max(w, h)
    if longest > MAX_IMAGE_SIDE:
        scale = MAX_IMAGE_SIDE / longest
        target = (max(1, int(w * scale)), max(1, int(h * scale)))
        img = img.resize(target, Image.Resampling.LANCZOS)
    return img
42
+
43
+
def load_model() -> tuple[Any, Any]:
    """Return the (model, processor) pair, initializing them on first use.

    Subsequent calls reuse the cached module-level singletons, so the
    expensive from_pretrained download/initialization runs only once.
    """
    global _model, _processor
    needs_init = _model is None or _processor is None
    if needs_init:
        _processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
        model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            torch_dtype=_dtype,
        )
        model = model.to(_device)
        model.eval()  # inference only — disable dropout etc.
        _model = model
    return _model, _processor
55
+
56
+
def generate_caption(
    image_bytes: bytes,
    task: str = "caption",
    text_input: str | None = None,
    max_tokens: int = DEFAULT_MAX_TOKENS,
) -> dict[str, Any]:
    """Run Florence-2 on raw image bytes and return the generated text.

    Args:
        image_bytes: Encoded image in any format Pillow can open.
        task: Key into TASKS; unknown tasks fall back to "caption".
        text_input: Optional extra text appended to the task prompt.
        max_tokens: Requested max_new_tokens; clamped to [8, MAX_MAX_TOKENS].

    Returns:
        {"text": <clean decoded text>} and, when the processor can parse the
        output for this task, an additional {"parsed": <structured result>}.
    """
    model, processor = load_model()
    prompt_task = TASKS.get(task, TASKS["caption"])
    prompt = f"{prompt_task} {text_input.strip()}" if text_input else prompt_task

    safe_max_tokens = min(max(int(max_tokens), 8), MAX_MAX_TOKENS)
    image = _prepare_image(image_bytes)

    inputs = processor(text=prompt, images=image, return_tensors="pt")
    input_ids = inputs["input_ids"].to(_device)
    pixel_values = inputs["pixel_values"].to(_device, _dtype)

    with torch.inference_mode():
        generated_ids = model.generate(
            input_ids=input_ids,
            pixel_values=pixel_values,
            do_sample=False,  # deterministic greedy decoding
            max_new_tokens=safe_max_tokens,
            num_beams=1,
        )

    # BUGFIX: Florence-2's post-processor needs the *raw* decoding. Region
    # tasks (<OD>, <OCR_WITH_REGION>, ...) emit special <loc_...> tokens that
    # skip_special_tokens=True strips, which broke structured parsing. Decode
    # twice: raw for post-processing, cleaned for the human-readable field.
    raw_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]
    clean_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()

    parsed = None
    post_process = getattr(processor, "post_process_generation", None)
    if callable(post_process):
        try:
            parsed = post_process(raw_text, task=prompt_task, image_size=image.size)
        except Exception:
            # Best-effort: fall back to plain text when parsing fails.
            parsed = None

    return {"text": clean_text, "parsed": parsed} if parsed else {"text": clean_text}
src/server.js DELETED
@@ -1,238 +0,0 @@
1
- import Fastify from "fastify";
2
- import multipart from "@fastify/multipart";
3
- import swagger from "@fastify/swagger";
4
- import swaggerUi from "@fastify/swagger-ui";
5
- import cors from "@fastify/cors";
6
- import { generateCaption, loadModel, TASKS } from "./model.js";
7
-
8
- const app = Fastify({ logger: true });
9
- const DEFAULT_MAX_TOKENS = parseInt(process.env.DEFAULT_MAX_TOKENS || "64", 10);
10
-
11
- // --- Plugins ---
12
- await app.register(cors);
13
- await app.register(multipart, { limits: { fileSize: 20 * 1024 * 1024 } });
14
-
15
- await app.register(swagger, {
16
- openapi: {
17
- info: {
18
- title: "img3txt — Image Captioning API",
19
- description:
20
- "Generate captions, OCR, object detection and more from images using FastVLM (ONNX).",
21
- version: "1.0.0",
22
- },
23
- servers: [{ url: "/" }],
24
- tags: [
25
- { name: "caption", description: "Image captioning endpoints" },
26
- { name: "health", description: "Health check" },
27
- ],
28
- },
29
- });
30
-
31
- await app.register(swaggerUi, {
32
- routePrefix: "/docs",
33
- uiConfig: { docExpansion: "list", deepLinking: true },
34
- });
35
-
36
- // --- Schemas ---
37
- const taskEnum = Object.keys(TASKS);
38
-
39
- const captionResponseSchema = {
40
- type: "object",
41
- properties: {
42
- task: { type: "string", example: "caption" },
43
- result: { type: "object", additionalProperties: true },
44
- },
45
- };
46
-
47
- const batchResponseSchema = {
48
- type: "object",
49
- properties: {
50
- results: {
51
- type: "array",
52
- items: {
53
- type: "object",
54
- properties: {
55
- filename: { type: "string" },
56
- task: { type: "string" },
57
- result: { type: "object", additionalProperties: true },
58
- },
59
- },
60
- },
61
- },
62
- };
63
-
64
- const errorSchema = {
65
- type: "object",
66
- properties: {
67
- error: { type: "string" },
68
- },
69
- };
70
-
71
- // --- Routes ---
72
-
73
- // Landing page — HF Spaces iframe shows this
74
- app.get(
75
- "/",
76
- { schema: { hide: true } },
77
- async (req, reply) => {
78
- reply.type("text/html").send(`<!DOCTYPE html>
79
- <html lang="en"><head><meta charset="utf-8">
80
- <meta name="viewport" content="width=device-width,initial-scale=1">
81
- <title>img3txt — FastVLM Image Captioning API</title>
82
- <style>
83
- *{margin:0;padding:0;box-sizing:border-box}
84
- body{font-family:system-ui,sans-serif;background:#0f172a;color:#e2e8f0;display:flex;align-items:center;justify-content:center;min-height:100vh}
85
- .card{background:#1e293b;border-radius:16px;padding:2.5rem;max-width:520px;width:90%;text-align:center;box-shadow:0 25px 50px rgba(0,0,0,.4)}
86
- h1{font-size:1.8rem;margin-bottom:.5rem}
87
- .sub{color:#94a3b8;margin-bottom:1.5rem}
88
- .btn{display:inline-block;padding:.75rem 1.5rem;background:#3b82f6;color:#fff;border-radius:8px;text-decoration:none;font-weight:600;margin:.25rem}
89
- .btn:hover{background:#2563eb}
90
- .tasks{margin-top:1.5rem;text-align:left;background:#0f172a;border-radius:8px;padding:1rem}
91
- .tasks code{color:#38bdf8}
92
- </style></head><body>
93
- <div class="card">
94
- <h1>img3txt</h1>
95
- <p class="sub">Image captioning, OCR &amp; object detection powered by FastVLM (ONNX)</p>
96
- <a class="btn" href="/docs">Swagger UI</a>
97
- <a class="btn" href="/health">Health Check</a>
98
- <div class="tasks">
99
- <p><strong>POST /caption</strong> with form fields:</p>
100
- <ul style="margin:.5rem 0 0 1.2rem;color:#94a3b8">
101
- <li><code>file</code> — image (required)</li>
102
- <li><code>task</code> — caption, detailed_caption, more_detailed_caption, ocr, ocr_with_region, object_detection, dense_region_caption, region_proposal</li>
103
- <li><code>max_tokens</code> — default 64 (smaller = faster)</li>
104
- </ul>
105
- </div>
106
- </div></body></html>`);
107
- }
108
- );
109
-
110
- app.get(
111
- "/health",
112
- {
113
- schema: {
114
- tags: ["health"],
115
- summary: "Health check",
116
- response: {
117
- 200: {
118
- type: "object",
119
- properties: {
120
- status: { type: "string", example: "ok" },
121
- model: { type: "string" },
122
- tasks: { type: "array", items: { type: "string" } },
123
- },
124
- },
125
- },
126
- },
127
- },
128
- async () => ({
129
- status: "ok",
130
- model: "onnx-community/FastVLM-0.5B-ONNX",
131
- tasks: taskEnum,
132
- })
133
- );
134
-
135
- app.post(
136
- "/caption",
137
- {
138
- schema: {
139
- tags: ["caption"],
140
- summary: "Generate caption / OCR / detection for a single image",
141
- description: `Upload an image as multipart form data. Supported tasks: ${taskEnum.join(", ")}`,
142
- consumes: ["multipart/form-data"],
143
- response: {
144
- 200: captionResponseSchema,
145
- 400: errorSchema,
146
- },
147
- },
148
- },
149
- async (req, reply) => {
150
- const data = await req.file();
151
- if (!data) {
152
- return reply.code(400).send({ error: "No file uploaded" });
153
- }
154
-
155
- const task = data.fields.task?.value || "caption";
156
- const textInput = data.fields.text?.value || null;
157
- const maxTokens = parseInt(
158
- data.fields.max_tokens?.value || String(DEFAULT_MAX_TOKENS),
159
- 10
160
- );
161
-
162
- if (!TASKS[task]) {
163
- return reply
164
- .code(400)
165
- .send({ error: `Invalid task. Choose from: ${taskEnum.join(", ")}` });
166
- }
167
-
168
- const buffer = await data.toBuffer();
169
- const result = await generateCaption(buffer, task, textInput, maxTokens);
170
-
171
- return { task, result };
172
- }
173
- );
174
-
175
- app.post(
176
- "/caption/batch",
177
- {
178
- schema: {
179
- tags: ["caption"],
180
- summary: "Generate captions for multiple images",
181
- description:
182
- "Upload multiple images as multipart form data. All images share the same task and settings.",
183
- consumes: ["multipart/form-data"],
184
- response: {
185
- 200: batchResponseSchema,
186
- 400: errorSchema,
187
- },
188
- },
189
- },
190
- async (req, reply) => {
191
- const parts = await req.parts();
192
- const files = [];
193
- let task = "caption";
194
- let textInput = null;
195
- let maxTokens = DEFAULT_MAX_TOKENS;
196
-
197
- for await (const part of parts) {
198
- if (part.type === "file") {
199
- files.push({ filename: part.filename, buffer: await part.toBuffer() });
200
- } else if (part.fieldname === "task") {
201
- task = part.value;
202
- } else if (part.fieldname === "text") {
203
- textInput = part.value;
204
- } else if (part.fieldname === "max_tokens") {
205
- maxTokens = parseInt(part.value, 10);
206
- }
207
- }
208
-
209
- if (files.length === 0) {
210
- return reply.code(400).send({ error: "No files uploaded" });
211
- }
212
- if (!TASKS[task]) {
213
- return reply
214
- .code(400)
215
- .send({ error: `Invalid task. Choose from: ${taskEnum.join(", ")}` });
216
- }
217
-
218
- const results = [];
219
- for (const f of files) {
220
- const result = await generateCaption(f.buffer, task, textInput, maxTokens);
221
- results.push({ filename: f.filename, task, result });
222
- }
223
-
224
- return { results };
225
- }
226
- );
227
-
228
- // --- Start ---
229
- const PORT = process.env.PORT || 7860;
230
-
231
- // Pre-load model then start server
232
- await loadModel();
233
- app.listen({ host: "0.0.0.0", port: PORT }, (err) => {
234
- if (err) {
235
- app.log.error(err);
236
- process.exit(1);
237
- }
238
- });
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/server.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from __future__ import annotations

import os
from typing import Any

from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse

from .model import DEFAULT_MAX_TOKENS, MODEL_ID, TASKS, generate_caption, load_model
11
+
app = FastAPI(
    title="img3txt - Florence-2 API",
    description="Generate captions, OCR, object detection and more from images using Florence-2.",
    version="1.0.0",
)

# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive CORS posture — confirm this API is truly meant to be callable from
# any site with credentials attached.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
25
+
26
+
# Eagerly load model weights at server startup so the first request does not
# pay the multi-second initialization cost.
# NOTE(review): @app.on_event is deprecated in recent FastAPI in favor of
# lifespan handlers; works today, worth migrating eventually.
@app.on_event("startup")
def warmup_model() -> None:
    load_model()
30
+
31
+
@app.get("/", response_class=HTMLResponse, include_in_schema=False)
def root() -> str:
    """Serve a small static landing page (what the HF Spaces iframe displays)."""
    return """<!DOCTYPE html>
<html lang=\"en\"><head><meta charset=\"utf-8\">
<meta name=\"viewport\" content=\"width=device-width,initial-scale=1\">
<title>img3txt - Florence-2 Image Captioning API</title>
<style>
*{margin:0;padding:0;box-sizing:border-box}
body{font-family:system-ui,sans-serif;background:#0f172a;color:#e2e8f0;display:flex;align-items:center;justify-content:center;min-height:100vh}
.card{background:#1e293b;border-radius:16px;padding:2.5rem;max-width:520px;width:90%;text-align:center;box-shadow:0 25px 50px rgba(0,0,0,.4)}
h1{font-size:1.8rem;margin-bottom:.5rem}
.sub{color:#94a3b8;margin-bottom:1.5rem}
.btn{display:inline-block;padding:.75rem 1.5rem;background:#3b82f6;color:#fff;border-radius:8px;text-decoration:none;font-weight:600;margin:.25rem}
.btn:hover{background:#2563eb}
.tasks{margin-top:1.5rem;text-align:left;background:#0f172a;border-radius:8px;padding:1rem}
.tasks code{color:#38bdf8}
</style></head><body>
<div class=\"card\">
<h1>img3txt</h1>
<p class=\"sub\">Image captioning, OCR &amp; object detection powered by Florence-2</p>
<a class=\"btn\" href=\"/docs\">Swagger UI</a>
<a class=\"btn\" href=\"/health\">Health Check</a>
<div class=\"tasks\">
<p><strong>POST /caption</strong> with form fields:</p>
<ul style=\"margin:.5rem 0 0 1.2rem;color:#94a3b8\">
<li><code>file</code> - image (required)</li>
<li><code>task</code> - caption, detailed_caption, more_detailed_caption, ocr, ocr_with_region, object_detection, dense_region_caption, region_proposal</li>
<li><code>max_tokens</code> - default 64 (smaller = faster)</li>
</ul>
</div>
</div></body></html>"""
63
+
64
+
@app.get("/health")
def health() -> dict[str, Any]:
    """Liveness probe: reports the configured model id and supported task names."""
    payload: dict[str, Any] = {
        "status": "ok",
        "model": MODEL_ID,
        "tasks": [*TASKS],
    }
    return payload
68
+
69
+
@app.post("/caption")
async def caption(
    file: UploadFile = File(...),
    task: str = Form("caption"),
    text: str | None = Form(None),
    max_tokens: int = Form(DEFAULT_MAX_TOKENS),
) -> dict[str, Any]:
    """Generate a caption / OCR / detection result for one uploaded image.

    Returns 400 for an unknown task, an empty upload, or bytes that cannot be
    decoded as an image.
    """
    if task not in TASKS:
        raise HTTPException(status_code=400, detail=f"Invalid task. Choose from: {', '.join(TASKS.keys())}")

    image_bytes = await file.read()
    if not image_bytes:
        raise HTTPException(status_code=400, detail="Empty file uploaded")

    try:
        # BUGFIX: inference is synchronous and CPU/GPU-bound; calling it
        # directly in an async endpoint blocked the event loop for every
        # other request. Run it in Starlette's worker thread pool instead.
        result = await run_in_threadpool(generate_caption, image_bytes, task, text, max_tokens)
    except (OSError, ValueError) as exc:
        # Pillow raises UnidentifiedImageError (an OSError) on undecodable
        # bytes — report it as a client error rather than a 500.
        raise HTTPException(status_code=400, detail=f"Could not process image: {exc}") from exc

    return {"task": task, "result": result}
86
+
87
+
@app.post("/caption/batch")
async def caption_batch(
    files: list[UploadFile] = File(...),
    task: str = Form("caption"),
    text: str | None = Form(None),
    max_tokens: int = Form(DEFAULT_MAX_TOKENS),
) -> dict[str, Any]:
    """Caption multiple uploads with shared settings.

    Empty or undecodable uploads are skipped; returns 400 when the task is
    unknown or no upload produced a result.
    """
    if task not in TASKS:
        raise HTTPException(status_code=400, detail=f"Invalid task. Choose from: {', '.join(TASKS.keys())}")

    results: list[dict[str, Any]] = []
    for upload in files:
        image_bytes = await upload.read()
        if not image_bytes:
            continue
        try:
            # BUGFIX: run blocking inference in the worker thread pool so the
            # event loop stays responsive (matches the /caption endpoint).
            result = await run_in_threadpool(generate_caption, image_bytes, task, text, max_tokens)
        except (OSError, ValueError):
            # BUGFIX: one corrupt image previously failed the whole batch
            # with a 500; skip it and keep processing the rest.
            continue
        results.append({"filename": upload.filename, "task": task, "result": result})

    if not results:
        raise HTTPException(status_code=400, detail="No files uploaded")

    return {"results": results}
110
+
111
+
# Local-development entry point. In Docker the server is started by the
# Dockerfile CMD (`uvicorn src.server:app ...`), so this branch is not taken there.
if __name__ == "__main__":
    import uvicorn

    # PORT defaults to 7860 to match the Space's app_port and the EXPOSE line.
    port = int(os.getenv("PORT", "7860"))
    uvicorn.run("src.server:app", host="0.0.0.0", port=port)