videopix committed on
Commit
0ecd601
·
verified ·
1 Parent(s): 8775421

Update app_working_api.py

Browse files
Files changed (1) hide show
  1. app_working_api.py +50 -239
app_working_api.py CHANGED
@@ -1,264 +1,75 @@
1
- import uvicorn
2
- import base64
3
  import io
4
- import numpy as np
5
- from fastapi import FastAPI
6
- from fastapi.responses import HTMLResponse
7
- from pydantic import BaseModel
8
- from PIL import Image, ImageOps, ImageEnhance
 
9
  import torch
10
- from transformers import BlipProcessor, BlipForConditionalGeneration
11
- import easyocr
12
- import os
13
 
14
# ------------------------
# HF Token
# ------------------------
# Hugging Face access token read from the environment; may be None, in which
# case the model download falls back to unauthenticated (public) access.
HF_TOKEN = os.getenv("HF_TOKEN")

# ------------------------
# Load BLIP model
# ------------------------
# CPU-only inference: no CUDA device is requested anywhere in this file.
device = torch.device("cpu")

processor = BlipProcessor.from_pretrained(
    "Salesforce/blip-image-captioning-large",
    use_auth_token=HF_TOKEN
)

model = BlipForConditionalGeneration.from_pretrained(
    "Salesforce/blip-image-captioning-large",
    use_auth_token=HF_TOKEN
).to(device)

model.eval()  # inference only: disable dropout / training-mode layers

# ------------------------
# Load OCR Reader
# ------------------------
ocr_reader = easyocr.Reader(
    ["en"],
    gpu=False,
    recog_network="english_g2"  # BEST for mixed fonts / stylized text
)

# ------------------------
# FastAPI App
# ------------------------
app = FastAPI()


class ImageRequest(BaseModel):
    # Request body shared by the POST endpoints: a base64-encoded image.
    image_base64: str
53
-
54
-
55
# ------------------------
# Improve OCR by preprocessing image
# ------------------------
def preprocess_for_ocr(img: Image.Image) -> np.ndarray:
    """Prepare *img* for OCR: grayscale, boosted contrast, slight brightness lift.

    Returns the processed image as a numpy array, the input form that
    easyocr's ``readtext`` accepts directly.
    """
    processed = ImageOps.grayscale(img)
    processed = ImageEnhance.Contrast(processed).enhance(2.0)
    processed = ImageEnhance.Brightness(processed).enhance(1.1)
    return np.array(processed)
 
 
 
 
72
 
73
-
74
# ------------------------
# OCR Function (improved)
# ------------------------
def extract_text(img: Image.Image) -> str:
    """Run OCR on *img* and return detected text, one paragraph per line."""
    lines = ocr_reader.readtext(
        preprocess_for_ocr(img),
        detail=0,
        paragraph=True,
    )
    if not lines:
        return "No text detected."
    return "\n".join(lines)
87
-
88
-
89
# ------------------------
# Caption Function (clean output)
# ------------------------
def create_caption(img: Image.Image) -> str:
    """Generate a detailed caption for *img* with BLIP beam search.

    Returns the decoded caption with any echoed prompt phrases stripped.
    """
    inputs = processor(img, return_tensors="pt").to(device)

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        out = model.generate(
            **inputs,
            max_length=150,
            min_length=30,
            num_beams=5,
            repetition_penalty=1.1,
            length_penalty=1.0,
            # FIX: dropped temperature=0.7 — sampling parameters are ignored
            # (and warned about by transformers) when do_sample is False,
            # which it is under plain beam search.
        )

    caption = processor.decode(out[0], skip_special_tokens=True)

    # REMOVE prompt words if BLIP inserted them
    caption = caption.replace("describe this image", "").strip()
    caption = caption.replace("describe the image", "").strip()

    return caption
113
 
114
 
115
# ------------------------
# API Endpoint: /img2caption
# ------------------------
@app.post("/img2caption")
async def img2caption(payload: ImageRequest):
    """Decode the base64 payload and return a BLIP caption for it."""
    try:
        raw = base64.b64decode(payload.image_base64)
        image = Image.open(io.BytesIO(raw)).convert("RGB")
        return {"caption": create_caption(image)}
    except Exception as e:
        return {"error": str(e)}
129
 
 
 
 
130
 
131
# ------------------------
# API Endpoint: /ocr
# ------------------------
# FIX: this endpoint was previously defined twice, verbatim, registering the
# same "/ocr" route on the app two times. The duplicate has been removed.
@app.post("/ocr")
async def ocr_endpoint(payload: ImageRequest):
    """Decode the base64 payload and return the OCR-extracted text."""
    try:
        img_bytes = base64.b64decode(payload.image_base64)
        img = Image.open(io.BytesIO(img_bytes)).convert("RGB")

        text = extract_text(img)
        return {"ocr_text": text}

    except Exception as e:
        return {"error": str(e)}
161
-
162
-
163
# ------------------------
# UI Endpoint: /
# ------------------------
@app.get("/", response_class=HTMLResponse)
async def ui_page():
    # Serves a single self-contained HTML page (Bootstrap CSS + inline JS)
    # that reads an image file as base64 and POSTs it to /img2caption or /ocr.
    # NOTE(review): original string indentation was lost in transit; the
    # whitespace below is cosmetic only for the rendered HTML.
    return """
<!DOCTYPE html>
<html>
<head>
<title>Image Caption + OCR</title>
<link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet">
<style>
body { background: #f5f7fa; }
.container { max-width: 650px; margin-top: 60px; }
#preview {
width: 100%; border-radius: 10px; margin-top: 20px; display: none;
}
#caption-box {
font-size: 18px; margin-top: 20px; padding: 15px;
border-radius: 8px; background: #e3f2fd; display: none;
}
</style>
</head>
<body>
<div class="container">
<div class="card shadow-sm">
<div class="card-body">
<h3 class="text-center mb-3">Image Caption + OCR Extractor</h3>
<input type="file" class="form-control" id="imageInput" accept="image/*">
<img id="preview">
<div class="d-grid gap-2 mt-3">
<button class="btn btn-primary btn-lg" onclick="sendCaption()">
Generate Detailed Caption
</button>
<button class="btn btn-success btn-lg" onclick="sendOCR()">
Extract Text (OCR)
</button>
</div>
<div id="caption-box"></div>
</div>
</div>
</div>
<script>
let base64Image = "";
document.getElementById("imageInput").addEventListener("change", function(event){
const file = event.target.files[0];
const reader = new FileReader();
reader.onload = function(e){
base64Image = e.target.result.split(",")[1];
const preview = document.getElementById("preview");
preview.src = e.target.result;
preview.style.display = "block";
};
reader.readAsDataURL(file);
});
async function sendCaption() {
if (!base64Image) {
alert("Please upload an image first.");
return;
}
const box = document.getElementById("caption-box");
box.style.display = "block";
box.innerHTML = "Generating caption...";
const res = await fetch("/img2caption", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ image_base64: base64Image })
});
const data = await res.json();
box.innerHTML = data.caption
? "<strong>Caption:</strong> " + data.caption
: "<strong>Error:</strong> " + data.error;
}
async function sendOCR() {
if (!base64Image) {
alert("Please upload an image first.");
return;
}
const box = document.getElementById("caption-box");
box.style.display = "block";
box.innerHTML = "Extracting text...";
const res = await fetch("/ocr", {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ image_base64: base64Image })
});
const data = await res.json();
box.innerHTML = data.ocr_text
? "<strong>OCR Result:</strong><br>" + data.ocr_text.replaceAll("\\n", "<br>")
: "<strong>Error:</strong> " + data.error;
}
</script>
</body>
</html>
"""
258
 
259
 
260
# -------------------------
# Run App
# -------------------------
if __name__ == "__main__":
    # Bind on all interfaces, port 7860 (the Hugging Face Spaces default).
    uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
1
  import io
2
+ import asyncio
3
+ import threading
4
+ import time
5
+ from fastapi import FastAPI, File, UploadFile
6
+ from fastapi.responses import JSONResponse
7
+ from PIL import Image
8
  import torch
9
+ from transformers import AutoProcessor, AutoModelForCausalLM
10
+ import requests
 
11
 
12
app = FastAPI(title="Image Caption API")

# Load model once at startup
# NOTE(review): requests, threading and time are imported above but unused
# in this version — candidates for removal.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Florence-2 ships custom modeling code, hence trust_remote_code=True.
processor = AutoProcessor.from_pretrained(
    "microsoft/Florence-2-base",
    trust_remote_code=True
)

model = AutoModelForCausalLM.from_pretrained(
    "microsoft/Florence-2-base",
    trust_remote_code=True
).to(device).eval()

# A lock to allow multiple requests safely
# (serializes access to the single shared model instance).
inference_lock = asyncio.Lock()
29
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
def caption_image(image: Image.Image) -> str:
    """Generate a detailed caption for *image* using Florence-2.

    Runs the "<MORE_DETAILED_CAPTION>" task prompt, decodes the raw model
    output and post-processes it into the final caption string.
    """
    inputs = processor(
        text="<MORE_DETAILED_CAPTION>",
        images=image,
        return_tensors="pt",
    ).to(device)

    # FIX: inference only — wrap generation in no_grad so no autograd graph
    # is built (the previous BLIP version did this; it was lost in the
    # Florence-2 rewrite and needlessly increases memory use).
    with torch.no_grad():
        output_ids = model.generate(
            input_ids=inputs["input_ids"],
            pixel_values=inputs["pixel_values"],
            max_new_tokens=256,
            num_beams=3,
        )

    # Keep special tokens: Florence-2's post-processor uses them to locate
    # the task output in the decoded text.
    decoded = processor.batch_decode(output_ids, skip_special_tokens=False)[0]

    parsed = processor.post_process_generation(
        decoded,
        task="<MORE_DETAILED_CAPTION>",
        image_size=(image.width, image.height),
    )

    return parsed["<MORE_DETAILED_CAPTION>"]
54
 
55
 
 
 
 
56
@app.post("/img2caption")
async def img2caption(file: UploadFile = File(...)):
    """Accept an uploaded image file and return its generated caption.

    Returns {"caption": ...} on success, or a JSON error payload with
    HTTP 500 on any failure.
    """
    try:
        # Read image
        data = await file.read()
        image = Image.open(io.BytesIO(data)).convert("RGB")

        # Serialize access to the shared model, and run the blocking,
        # compute-heavy inference in a worker thread. FIX: calling
        # caption_image() directly inside this async endpoint blocks the
        # event loop for the whole generation, stalling every other
        # request (including /health) for its duration.
        async with inference_lock:
            caption = await asyncio.to_thread(caption_image, image)

        return {"caption": caption}

    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
 
73
@app.get("/health")
async def health():
    """Liveness probe: always reports that the service is up."""
    payload = {"status": "ok"}
    return payload