videopix committed on
Commit
256d97e
·
verified ·
1 Parent(s): 2afc806

Update app_working_api.py

Browse files
Files changed (1) hide show
  1. app_working_api.py +114 -24
app_working_api.py CHANGED
@@ -3,43 +3,53 @@ import asyncio
3
  import threading
4
  import time
5
  from fastapi import FastAPI, File, UploadFile
6
- from fastapi.responses import JSONResponse
7
  from PIL import Image
8
  import torch
9
  from transformers import AutoProcessor, AutoModelForCausalLM
10
  import requests
11
 
12
- app = FastAPI(title="Image Caption API")
 
 
 
13
 
14
- # Load model once at startup
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
 
17
- processor = AutoProcessor.from_pretrained(
18
- "microsoft/Florence-2-base",
19
- trust_remote_code=True
20
- )
21
 
22
- model = AutoModelForCausalLM.from_pretrained(
23
- "microsoft/Florence-2-base",
24
- trust_remote_code=True
25
- ).to(device).eval()
26
 
27
- # A lock to allow multiple requests safely
28
- inference_lock = asyncio.Lock()
 
29
 
 
 
 
 
 
 
 
 
 
30
 
31
- def caption_image(image: Image.Image) -> str:
 
 
32
  inputs = processor(
33
  text="<MORE_DETAILED_CAPTION>",
34
  images=image,
35
- return_tensors="pt",
36
  ).to(device)
37
 
38
  output_ids = model.generate(
39
  input_ids=inputs["input_ids"],
40
  pixel_values=inputs["pixel_values"],
41
  max_new_tokens=256,
42
- num_beams=3,
43
  )
44
 
45
  decoded = processor.batch_decode(output_ids, skip_special_tokens=False)[0]
@@ -47,22 +57,28 @@ def caption_image(image: Image.Image) -> str:
47
  parsed = processor.post_process_generation(
48
  decoded,
49
  task="<MORE_DETAILED_CAPTION>",
50
- image_size=(image.width, image.height),
51
  )
52
 
53
  return parsed["<MORE_DETAILED_CAPTION>"]
54
 
55
 
 
 
 
56
  @app.post("/img2caption")
57
  async def img2caption(file: UploadFile = File(...)):
58
  try:
59
- # Read image
 
 
 
 
60
  data = await file.read()
61
  image = Image.open(io.BytesIO(data)).convert("RGB")
62
 
63
- # Protect inference in async server
64
- async with inference_lock:
65
- caption = caption_image(image)
66
 
67
  return {"caption": caption}
68
 
@@ -70,6 +86,80 @@ async def img2caption(file: UploadFile = File(...)):
70
  return JSONResponse({"error": str(e)}, status_code=500)
71
 
72
 
73
- @app.get("/health")
74
- async def health():
75
- return {"status": "ok"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  import threading
4
  import time
5
  from fastapi import FastAPI, File, UploadFile
6
+ from fastapi.responses import JSONResponse, HTMLResponse
7
  from PIL import Image
8
  import torch
9
  from transformers import AutoProcessor, AutoModelForCausalLM
10
  import requests
11
 
12
# ---------------------------------------------------
# FastAPI App
# ---------------------------------------------------
app = FastAPI(title="Florence Image Caption API")

# Use the GPU when torch reports CUDA as available, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# The processor and model start out unset and are filled in by load_model()
# on the first request (lazy loading prevents the HF startup timeout).
processor = None
model = None

# Held by the request handler around load_model().
model_lock = asyncio.Lock()
23
 
 
 
 
 
24
 
25
async def load_model():
    """Load the Florence-2 processor and model the first time they are needed.

    Does nothing when ``model`` is already set. Otherwise the blocking
    ``from_pretrained`` calls run in the event loop's default executor so
    that the server keeps answering other requests while the weights are
    fetched and moved to ``device``.

    The module-level ``processor`` and ``model`` are assigned only after
    both loads have succeeded. The visible request handler calls this while
    holding ``model_lock``, which keeps two requests from loading at once.
    """
    global processor, model

    if model is not None:
        return

    def _load():
        # Plain blocking work; executed in a worker thread below.
        loaded_processor = AutoProcessor.from_pretrained(
            "microsoft/Florence-2-base",
            trust_remote_code=True,
        )
        loaded_model = AutoModelForCausalLM.from_pretrained(
            "microsoft/Florence-2-base",
            trust_remote_code=True,
        ).to(device).eval()
        return loaded_processor, loaded_model

    loop = asyncio.get_running_loop()
    loaded_processor, loaded_model = await loop.run_in_executor(None, _load)

    # Publish the processor first: readers test ``model`` to decide whether
    # loading has finished.
    processor = loaded_processor
    model = loaded_model
38
 
39
+
40
def run_caption(
    image: Image.Image,
    task: str = "<MORE_DETAILED_CAPTION>",
    *,
    max_new_tokens: int = 256,
    num_beams: int = 3,
) -> str:
    """Generate a caption for ``image`` with the loaded Florence-2 model.

    Args:
        image: The PIL image to describe.
        task: Florence-2 task prompt token. It is used as the prompt text,
            passed to post-processing, and used as the key of the result.
        max_new_tokens: Upper bound on generated tokens.
        num_beams: Beam width used for generation.

    Returns:
        The entry of the post-processed output stored under ``task``.

    Raises:
        RuntimeError: If the module-level ``processor`` or ``model`` has not
            been loaded yet.
    """
    if processor is None or model is None:
        # Without this guard the failure is an opaque
        # "'NoneType' object is not callable".
        raise RuntimeError("Model is not loaded; call load_model() first")

    inputs = processor(
        text=task,
        images=image,
        return_tensors="pt",
    ).to(device)

    output_ids = model.generate(
        input_ids=inputs["input_ids"],
        pixel_values=inputs["pixel_values"],
        max_new_tokens=max_new_tokens,
        num_beams=num_beams,
    )

    # Special tokens are kept because post_process_generation parses them.
    decoded = processor.batch_decode(output_ids, skip_special_tokens=False)[0]

    parsed = processor.post_process_generation(
        decoded,
        task=task,
        image_size=(image.width, image.height),
    )

    return parsed[task]
64
 
65
 
66
+ # ---------------------------------------------------
67
+ # API Endpoint
68
+ # ---------------------------------------------------
69
@app.post("/img2caption")
async def img2caption(file: UploadFile = File(...)):
    """Return a caption for an uploaded image.

    Responses:
        200 ``{"caption": ...}`` on success.
        400 ``{"error": ...}`` when the upload cannot be decoded as an image.
        500 ``{"error": ...}`` for any other failure.
    """
    try:
        # Read and decode first, so a bad upload is rejected without
        # triggering the (slow) first-time model load.
        data = await file.read()
        try:
            image = Image.open(io.BytesIO(data)).convert("RGB")
        except OSError as e:
            # PIL reports unreadable or truncated images as OSError
            # subclasses; that is a client error, not a server fault.
            return JSONResponse({"error": str(e)}, status_code=400)

        # The lock covers both the lazy load and the inference, so only one
        # generation runs at a time. The blocking call runs in a worker
        # thread to keep the event loop responsive.
        async with model_lock:
            await load_model()
            loop = asyncio.get_running_loop()
            caption = await loop.run_in_executor(None, run_caption, image)

        return {"caption": caption}

    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
87
 
88
 
89
+ # ---------------------------------------------------
90
+ # Simple HTML UI
91
+ # ---------------------------------------------------
92
@app.get("/", response_class=HTMLResponse)
def ui():
    """Serve the single-page upload form that calls ``POST /img2caption``.

    The page previews the chosen image, posts it as multipart form field
    ``file``, and shows the returned ``caption`` or ``error``. Server text is
    written with ``textContent`` so it is displayed literally and never
    parsed as HTML.
    """
    return """
<!DOCTYPE html>
<html>
<head>
<title>Image Caption Generator</title>
<style>
body { font-family: Arial; max-width: 650px; margin: 40px auto; }
h2 { text-align: center; }
#preview {
    width: 100%; margin-top: 15px; display: none;
    border-radius: 8px;
}
#captionBox {
    margin-top: 20px; padding: 15px;
    background: #eee; border-radius: 6px; display: none;
}
button {
    padding: 12px; width: 100%; margin-top: 10px;
    background: #4A90E2; color: white; border: none;
    border-radius: 6px; cursor: pointer; font-size: 16px;
}
button:hover { background: #357ABD; }
</style>
</head>
<body>
<h2>Image Caption Generator</h2>
<input type="file" id="imageInput" accept="image/*">
<img id="preview">
<button onclick="generateCaption()">Generate Caption</button>
<div id="captionBox"></div>
<script>
const imageInput = document.getElementById("imageInput");
const preview = document.getElementById("preview");
const captionBox = document.getElementById("captionBox");
imageInput.onchange = () => {
    const f = imageInput.files[0];
    if (f) {
        preview.src = URL.createObjectURL(f);
        preview.style.display = "block";
    }
};
async function generateCaption() {
    const f = imageInput.files[0];
    if (!f) {
        alert("Upload an image first");
        return;
    }
    const form = new FormData();
    form.append("file", f);
    captionBox.style.display = "block";
    captionBox.textContent = "Generating caption...";
    try {
        const res = await fetch("/img2caption", {
            method: "POST",
            body: form
        });
        const data = await res.json();
        captionBox.textContent =
            data.caption || data.error || "No caption returned";
    } catch (err) {
        captionBox.textContent = "Request failed: " + err;
    }
}
</script>
</body>
</html>
"""
156
+
157
+
158
def keep_alive():
    """Do nothing; placeholder hook called once before the server starts."""
    return None
160
+
161
if __name__ == "__main__":
    # Script entry point: announce start-up, call the keep-alive hook, then
    # serve the app on every interface at port 7860.
    import uvicorn

    print("🚀 Launching Fast img2caption API")
    keep_alive()
    uvicorn.run(app, host="0.0.0.0", port=7860)