videopix commited on
Commit
b4b4755
·
verified ·
1 Parent(s): 866a0cc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +160 -17
app.py CHANGED
@@ -1,12 +1,22 @@
1
- import gradio as gr
2
- import torch
 
 
 
 
 
3
  from PIL import Image
 
4
  from transformers import AutoProcessor, AutoModelForCausalLM
 
5
 
6
- # Choose device
 
 
 
 
7
  device = "cuda" if torch.cuda.is_available() else "cpu"
8
 
9
- # Load Florence-2 Base model
10
  processor = AutoProcessor.from_pretrained(
11
  "microsoft/Florence-2-base",
12
  trust_remote_code=True
@@ -17,10 +27,10 @@ model = AutoModelForCausalLM.from_pretrained(
17
  trust_remote_code=True
18
  ).to(device).eval()
19
 
20
- def generate_caption(image):
21
- if not isinstance(image, Image.Image):
22
- image = Image.fromarray(image)
23
 
 
24
  inputs = processor(
25
  text="<MORE_DETAILED_CAPTION>",
26
  images=image,
@@ -31,7 +41,7 @@ def generate_caption(image):
31
  input_ids=inputs["input_ids"],
32
  pixel_values=inputs["pixel_values"],
33
  max_new_tokens=256,
34
- num_beams=3,
35
  )
36
 
37
  decoded = processor.batch_decode(output_ids, skip_special_tokens=False)[0]
@@ -44,13 +54,146 @@ def generate_caption(image):
44
 
45
  return parsed["<MORE_DETAILED_CAPTION>"]
46
 
47
- # Gradio interface
48
- io = gr.Interface(
49
- fn=generate_caption,
50
- inputs=gr.Image(label="Upload Image"),
51
- outputs=gr.Textbox(label="Generated Caption", lines=3),
52
- title="Image to Caption Generator",
53
- description="Upload an image and get a detailed AI-generated caption."
54
- )
55
 
56
- io.launch(debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import asyncio
3
+ import threading
4
+ import time
5
+ from fastapi import FastAPI, File, UploadFile
6
+ from fastapi.responses import JSONResponse, HTMLResponse
7
+ from fastapi.staticfiles import StaticFiles
8
  from PIL import Image
9
+ import torch
10
  from transformers import AutoProcessor, AutoModelForCausalLM
11
+ import requests
12
 
13
+ app = FastAPI(title="Image Caption API")
14
+
15
+ # -------------------------
16
+ # Load Model
17
+ # -------------------------
18
  device = "cuda" if torch.cuda.is_available() else "cpu"
19
 
 
20
  processor = AutoProcessor.from_pretrained(
21
  "microsoft/Florence-2-base",
22
  trust_remote_code=True
 
27
  trust_remote_code=True
28
  ).to(device).eval()
29
 
30
+ inference_lock = asyncio.Lock()
31
+
 
32
 
33
+ def caption_image(image: Image.Image) -> str:
34
  inputs = processor(
35
  text="<MORE_DETAILED_CAPTION>",
36
  images=image,
 
41
  input_ids=inputs["input_ids"],
42
  pixel_values=inputs["pixel_values"],
43
  max_new_tokens=256,
44
+ num_beams=3
45
  )
46
 
47
  decoded = processor.batch_decode(output_ids, skip_special_tokens=False)[0]
 
54
 
55
  return parsed["<MORE_DETAILED_CAPTION>"]
56
 
 
 
 
 
 
 
 
 
57
 
58
+ # -------------------------
59
+ # API Endpoint
60
+ # -------------------------
61
+ @app.post("/img2caption")
62
+ async def img2caption(file: UploadFile = File(...)):
63
+ try:
64
+ data = await file.read()
65
+ image = Image.open(io.BytesIO(data)).convert("RGB")
66
+
67
+ async with inference_lock:
68
+ caption = caption_image(image)
69
+
70
+ return {"caption": caption}
71
+
72
+ except Exception as e:
73
+ return JSONResponse({"error": str(e)}, status_code=500)
74
+
75
+
76
+ # -------------------------
77
+ # HTML UI
78
+ # -------------------------
79
+ @app.get("/", response_class=HTMLResponse)
80
+ def ui():
81
+ return """
82
+ <!DOCTYPE html>
83
+ <html>
84
+ <head>
85
+ <title>Image Caption Generator</title>
86
+ <style>
87
+ body {
88
+ font-family: Arial, sans-serif;
89
+ max-width: 650px;
90
+ margin: 40px auto;
91
+ padding: 20px;
92
+ background: #fafafa;
93
+ }
94
+ h2 {
95
+ text-align: center;
96
+ }
97
+ #preview {
98
+ width: 100%;
99
+ margin-top: 15px;
100
+ display: none;
101
+ border-radius: 8px;
102
+ }
103
+ #captionBox {
104
+ margin-top: 20px;
105
+ padding: 15px;
106
+ background: #eee;
107
+ border-radius: 6px;
108
+ display: none;
109
+ }
110
+ button {
111
+ padding: 12px 20px;
112
+ margin-top: 10px;
113
+ width: 100%;
114
+ background: #4A90E2;
115
+ color: white;
116
+ font-size: 16px;
117
+ border: none;
118
+ border-radius: 6px;
119
+ cursor: pointer;
120
+ }
121
+ button:hover {
122
+ background: #357ABD;
123
+ }
124
+ </style>
125
+ </head>
126
+
127
+ <body>
128
+ <h2>Image to Caption Generator</h2>
129
+
130
+ <input type="file" id="imageInput" accept="image/*">
131
+
132
+ <img id="preview">
133
+
134
+ <button onclick="generateCaption()">Generate Caption</button>
135
+
136
+ <div id="captionBox"></div>
137
+
138
+ <script>
139
+ const imgInput = document.getElementById("imageInput");
140
+ const preview = document.getElementById("preview");
141
+ const captionBox = document.getElementById("captionBox");
142
+
143
+ imgInput.onchange = () => {
144
+ const file = imgInput.files[0];
145
+ if (file) {
146
+ preview.src = URL.createObjectURL(file);
147
+ preview.style.display = "block";
148
+ }
149
+ };
150
+
151
+ async function generateCaption() {
152
+ const file = imgInput.files[0];
153
+ if (!file) {
154
+ alert("Please upload an image.");
155
+ return;
156
+ }
157
+
158
+ const formData = new FormData();
159
+ formData.append("file", file);
160
+
161
+ captionBox.style.display = "block";
162
+ captionBox.innerHTML = "Generating caption...";
163
+
164
+ const response = await fetch("/img2caption", {
165
+ method: "POST",
166
+ body: formData
167
+ });
168
+
169
+ const result = await response.json();
170
+
171
+ captionBox.innerHTML = result.caption || result.error;
172
+ }
173
+ </script>
174
+
175
+ </body>
176
+ </html>
177
+ """
178
+
179
+
180
+ # -------------------------
181
+ # Keep HF Space alive
182
+ # -------------------------
183
+
184
+ SPACE_URL = "https://YOUR-SPACE-NAME.hf.space/health"
185
+
186
+ def keep_alive():
187
+ while True:
188
+ try:
189
+ requests.get(SPACE_URL, timeout=5)
190
+ except:
191
+ pass
192
+ time.sleep(240)
193
+
194
+ threading.Thread(target=keep_alive, daemon=True).start()
195
+
196
+
197
+ @app.get("/health")
198
+ def health():
199
+ return {"status": "ok"}