videopix commited on
Commit
8dfd66b
·
verified ·
1 Parent(s): 9a1b65b

Upload app_working_api.py

Browse files
Files changed (1) hide show
  1. app_working_api.py +264 -0
app_working_api.py ADDED
@@ -0,0 +1,264 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uvicorn
2
+ import base64
3
+ import io
4
+ import numpy as np
5
+ from fastapi import FastAPI
6
+ from fastapi.responses import HTMLResponse
7
+ from pydantic import BaseModel
8
+ from PIL import Image, ImageOps, ImageEnhance
9
+ import torch
10
+ from transformers import BlipProcessor, BlipForConditionalGeneration
11
+ import easyocr
12
+ import os
13
+
14
+ # ------------------------
15
+ # HF Token
16
+ # ------------------------
17
+ HF_TOKEN = os.getenv("HF_TOKEN")
18
+
19
+ # ------------------------
20
+ # Load BLIP model
21
+ # ------------------------
22
+ device = torch.device("cpu")
23
+
24
+ processor = BlipProcessor.from_pretrained(
25
+ "Salesforce/blip-image-captioning-large",
26
+ use_auth_token=HF_TOKEN
27
+ )
28
+
29
+ model = BlipForConditionalGeneration.from_pretrained(
30
+ "Salesforce/blip-image-captioning-large",
31
+ use_auth_token=HF_TOKEN
32
+ ).to(device)
33
+
34
+ model.eval()
35
+
36
+ # ------------------------
37
+ # Load OCR Reader
38
+ # ------------------------
39
+ ocr_reader = easyocr.Reader(
40
+ ["en"],
41
+ gpu=False,
42
+ recog_network="english_g2" # BEST for mixed fonts / stylized text
43
+ )
44
+
45
+ # ------------------------
46
+ # FastAPI App
47
+ # ------------------------
48
+ app = FastAPI()
49
+
50
+
51
+ class ImageRequest(BaseModel):
52
+ image_base64: str
53
+
54
+
55
+ # ------------------------
56
+ # Improve OCR by preprocessing image
57
+ # ------------------------
58
+ def preprocess_for_ocr(img: Image.Image) -> np.ndarray:
59
+ # Convert to grayscale
60
+ gray = ImageOps.grayscale(img)
61
+
62
+ # Increase contrast
63
+ enhancer = ImageEnhance.Contrast(gray)
64
+ gray = enhancer.enhance(2.0)
65
+
66
+ # Increase brightness slightly
67
+ enhancer = ImageEnhance.Brightness(gray)
68
+ gray = enhancer.enhance(1.1)
69
+
70
+ # Convert to numpy
71
+ return np.array(gray)
72
+
73
+
74
+ # ------------------------
75
+ # OCR Function (improved)
76
+ # ------------------------
77
+ def extract_text(img: Image.Image) -> str:
78
+ pre_img = preprocess_for_ocr(img)
79
+
80
+ result = ocr_reader.readtext(
81
+ pre_img,
82
+ detail=0,
83
+ paragraph=True
84
+ )
85
+
86
+ return "\n".join(result) if result else "No text detected."
87
+
88
+
89
+ # ------------------------
90
+ # Caption Function (clean output)
91
+ # ------------------------
92
+ def create_caption(img: Image.Image) -> str:
93
+ inputs = processor(img, return_tensors="pt").to(device)
94
+
95
+ with torch.no_grad():
96
+ out = model.generate(
97
+ **inputs,
98
+ max_length=150,
99
+ min_length=30,
100
+ num_beams=5,
101
+ repetition_penalty=1.1,
102
+ length_penalty=1.0,
103
+ temperature=0.7
104
+ )
105
+
106
+ caption = processor.decode(out[0], skip_special_tokens=True)
107
+
108
+ # REMOVE prompt words if BLIP inserted them
109
+ caption = caption.replace("describe this image", "").strip()
110
+ caption = caption.replace("describe the image", "").strip()
111
+
112
+ return caption
113
+
114
+
115
+ # ------------------------
116
+ # API Endpoint: /img2caption
117
+ # ------------------------
118
+ @app.post("/img2caption")
119
+ async def img2caption(payload: ImageRequest):
120
+ try:
121
+ img_bytes = base64.b64decode(payload.image_base64)
122
+ img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
123
+
124
+ caption = create_caption(img)
125
+ return {"caption": caption}
126
+
127
+ except Exception as e:
128
+ return {"error": str(e)}
129
+
130
+
131
+ # ------------------------
132
+ # API Endpoint: /ocr
133
+ # ------------------------
134
+ @app.post("/ocr")
135
+ async def ocr_endpoint(payload: ImageRequest):
136
+ try:
137
+ img_bytes = base64.b64decode(payload.image_base64)
138
+ img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
139
+
140
+ text = extract_text(img)
141
+ return {"ocr_text": text}
142
+
143
+ except Exception as e:
144
+ return {"error": str(e)}
145
+
146
+
147
+ # ------------------------
148
+ # API Endpoint: /ocr
149
+ # ------------------------
150
+ @app.post("/ocr")
151
+ async def ocr_endpoint(payload: ImageRequest):
152
+ try:
153
+ img_bytes = base64.b64decode(payload.image_base64)
154
+ img = Image.open(io.BytesIO(img_bytes)).convert("RGB")
155
+
156
+ text = extract_text(img)
157
+ return {"ocr_text": text}
158
+
159
+ except Exception as e:
160
+ return {"error": str(e)}
161
+
162
+
163
+ # ------------------------
164
+ # UI Endpoint: /
165
+ # ------------------------
166
+ @app.get("/", response_class=HTMLResponse)
167
+ async def ui_page():
168
+ return """
169
+ <!DOCTYPE html>
170
+ <html>
171
+ <head>
172
+ <title>Image Caption + OCR</title>
173
+ <link href="https://cdn.jsdelivr.net/npm/bootstrap@5.3.2/dist/css/bootstrap.min.css" rel="stylesheet">
174
+ <style>
175
+ body { background: #f5f7fa; }
176
+ .container { max-width: 650px; margin-top: 60px; }
177
+ #preview {
178
+ width: 100%; border-radius: 10px; margin-top: 20px; display: none;
179
+ }
180
+ #caption-box {
181
+ font-size: 18px; margin-top: 20px; padding: 15px;
182
+ border-radius: 8px; background: #e3f2fd; display: none;
183
+ }
184
+ </style>
185
+ </head>
186
+ <body>
187
+ <div class="container">
188
+ <div class="card shadow-sm">
189
+ <div class="card-body">
190
+ <h3 class="text-center mb-3">Image Caption + OCR Extractor</h3>
191
+ <input type="file" class="form-control" id="imageInput" accept="image/*">
192
+ <img id="preview">
193
+ <div class="d-grid gap-2 mt-3">
194
+ <button class="btn btn-primary btn-lg" onclick="sendCaption()">
195
+ Generate Detailed Caption
196
+ </button>
197
+ <button class="btn btn-success btn-lg" onclick="sendOCR()">
198
+ Extract Text (OCR)
199
+ </button>
200
+ </div>
201
+ <div id="caption-box"></div>
202
+ </div>
203
+ </div>
204
+ </div>
205
+ <script>
206
+ let base64Image = "";
207
+ document.getElementById("imageInput").addEventListener("change", function(event){
208
+ const file = event.target.files[0];
209
+ const reader = new FileReader();
210
+ reader.onload = function(e){
211
+ base64Image = e.target.result.split(",")[1];
212
+ const preview = document.getElementById("preview");
213
+ preview.src = e.target.result;
214
+ preview.style.display = "block";
215
+ };
216
+ reader.readAsDataURL(file);
217
+ });
218
+ async function sendCaption() {
219
+ if (!base64Image) {
220
+ alert("Please upload an image first.");
221
+ return;
222
+ }
223
+ const box = document.getElementById("caption-box");
224
+ box.style.display = "block";
225
+ box.innerHTML = "Generating caption...";
226
+ const res = await fetch("/img2caption", {
227
+ method: "POST",
228
+ headers: { "Content-Type": "application/json" },
229
+ body: JSON.stringify({ image_base64: base64Image })
230
+ });
231
+ const data = await res.json();
232
+ box.innerHTML = data.caption
233
+ ? "<strong>Caption:</strong> " + data.caption
234
+ : "<strong>Error:</strong> " + data.error;
235
+ }
236
+ async function sendOCR() {
237
+ if (!base64Image) {
238
+ alert("Please upload an image first.");
239
+ return;
240
+ }
241
+ const box = document.getElementById("caption-box");
242
+ box.style.display = "block";
243
+ box.innerHTML = "Extracting text...";
244
+ const res = await fetch("/ocr", {
245
+ method: "POST",
246
+ headers: { "Content-Type": "application/json" },
247
+ body: JSON.stringify({ image_base64: base64Image })
248
+ });
249
+ const data = await res.json();
250
+ box.innerHTML = data.ocr_text
251
+ ? "<strong>OCR Result:</strong><br>" + data.ocr_text.replaceAll("\\n", "<br>")
252
+ : "<strong>Error:</strong> " + data.error;
253
+ }
254
+ </script>
255
+ </body>
256
+ </html>
257
+ """
258
+
259
+
260
+ # -------------------------
261
+ # Run App
262
+ # -------------------------
263
+ if __name__ == "__main__":
264
+ uvicorn.run(app, host="0.0.0.0", port=7860)