Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -4,17 +4,19 @@ import threading
|
|
| 4 |
import time
|
| 5 |
from fastapi import FastAPI, File, UploadFile
|
| 6 |
from fastapi.responses import JSONResponse, HTMLResponse
|
| 7 |
-
from fastapi.staticfiles import StaticFiles
|
| 8 |
from PIL import Image
|
| 9 |
import torch
|
| 10 |
from transformers import AutoProcessor, AutoModelForCausalLM
|
| 11 |
import requests
|
| 12 |
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
| 14 |
|
| 15 |
-
#
|
| 16 |
-
# Load
|
| 17 |
-
#
|
| 18 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 19 |
|
| 20 |
processor = AutoProcessor.from_pretrained(
|
|
@@ -27,10 +29,12 @@ model = AutoModelForCausalLM.from_pretrained(
|
|
| 27 |
trust_remote_code=True
|
| 28 |
).to(device).eval()
|
| 29 |
|
|
|
|
| 30 |
inference_lock = asyncio.Lock()
|
| 31 |
|
| 32 |
|
| 33 |
def caption_image(image: Image.Image) -> str:
|
|
|
|
| 34 |
inputs = processor(
|
| 35 |
text="<MORE_DETAILED_CAPTION>",
|
| 36 |
images=image,
|
|
@@ -44,7 +48,9 @@ def caption_image(image: Image.Image) -> str:
|
|
| 44 |
num_beams=3
|
| 45 |
)
|
| 46 |
|
| 47 |
-
decoded = processor.batch_decode(
|
|
|
|
|
|
|
| 48 |
|
| 49 |
parsed = processor.post_process_generation(
|
| 50 |
decoded,
|
|
@@ -55,45 +61,47 @@ def caption_image(image: Image.Image) -> str:
|
|
| 55 |
return parsed["<MORE_DETAILED_CAPTION>"]
|
| 56 |
|
| 57 |
|
| 58 |
-
#
|
| 59 |
-
# API
|
| 60 |
-
#
|
| 61 |
@app.post("/img2caption")
|
| 62 |
async def img2caption(file: UploadFile = File(...)):
|
| 63 |
try:
|
| 64 |
data = await file.read()
|
| 65 |
image = Image.open(io.BytesIO(data)).convert("RGB")
|
| 66 |
|
|
|
|
| 67 |
async with inference_lock:
|
| 68 |
caption = caption_image(image)
|
| 69 |
|
| 70 |
return {"caption": caption}
|
| 71 |
|
| 72 |
except Exception as e:
|
| 73 |
-
return JSONResponse(
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
|
| 76 |
-
#
|
| 77 |
-
# HTML
|
| 78 |
-
#
|
| 79 |
@app.get("/", response_class=HTMLResponse)
|
| 80 |
def ui():
|
| 81 |
return """
|
| 82 |
<!DOCTYPE html>
|
| 83 |
<html>
|
| 84 |
<head>
|
| 85 |
-
<title>Image
|
| 86 |
<style>
|
| 87 |
body {
|
| 88 |
-
font-family: Arial
|
| 89 |
max-width: 650px;
|
| 90 |
margin: 40px auto;
|
| 91 |
padding: 20px;
|
| 92 |
background: #fafafa;
|
| 93 |
}
|
| 94 |
-
h2 {
|
| 95 |
-
text-align: center;
|
| 96 |
-
}
|
| 97 |
#preview {
|
| 98 |
width: 100%;
|
| 99 |
margin-top: 15px;
|
|
@@ -108,27 +116,26 @@ def ui():
|
|
| 108 |
display: none;
|
| 109 |
}
|
| 110 |
button {
|
| 111 |
-
padding: 12px
|
| 112 |
margin-top: 10px;
|
| 113 |
width: 100%;
|
| 114 |
-
background: #
|
| 115 |
color: white;
|
| 116 |
-
font-size: 16px;
|
| 117 |
border: none;
|
| 118 |
border-radius: 6px;
|
| 119 |
cursor: pointer;
|
|
|
|
| 120 |
}
|
| 121 |
button:hover {
|
| 122 |
-
background: #
|
| 123 |
}
|
| 124 |
</style>
|
| 125 |
</head>
|
| 126 |
|
| 127 |
<body>
|
| 128 |
-
<h2>Image
|
| 129 |
|
| 130 |
<input type="file" id="imageInput" accept="image/*">
|
| 131 |
-
|
| 132 |
<img id="preview">
|
| 133 |
|
| 134 |
<button onclick="generateCaption()">Generate Caption</button>
|
|
@@ -151,24 +158,23 @@ def ui():
|
|
| 151 |
async function generateCaption() {
|
| 152 |
const file = imgInput.files[0];
|
| 153 |
if (!file) {
|
| 154 |
-
alert("
|
| 155 |
return;
|
| 156 |
}
|
| 157 |
|
| 158 |
-
const
|
| 159 |
-
|
| 160 |
|
| 161 |
captionBox.style.display = "block";
|
| 162 |
captionBox.innerHTML = "Generating caption...";
|
| 163 |
|
| 164 |
-
const
|
| 165 |
method: "POST",
|
| 166 |
-
body:
|
| 167 |
});
|
| 168 |
|
| 169 |
-
const
|
| 170 |
-
|
| 171 |
-
captionBox.innerHTML = result.caption || result.error;
|
| 172 |
}
|
| 173 |
</script>
|
| 174 |
|
|
@@ -177,17 +183,15 @@ def ui():
|
|
| 177 |
"""
|
| 178 |
|
| 179 |
|
| 180 |
-
#
|
| 181 |
-
# Keep HF
|
| 182 |
-
#
|
| 183 |
-
|
| 184 |
-
SPACE_URL = "https://YOUR-SPACE-NAME.hf.space/health"
|
| 185 |
|
| 186 |
def keep_alive():
|
| 187 |
while True:
|
| 188 |
try:
|
| 189 |
requests.get(SPACE_URL, timeout=5)
|
| 190 |
-
except:
|
| 191 |
pass
|
| 192 |
time.sleep(240)
|
| 193 |
|
|
|
|
| 4 |
import time
|
| 5 |
from fastapi import FastAPI, File, UploadFile
|
| 6 |
from fastapi.responses import JSONResponse, HTMLResponse
|
|
|
|
| 7 |
from PIL import Image
|
| 8 |
import torch
|
| 9 |
from transformers import AutoProcessor, AutoModelForCausalLM
|
| 10 |
import requests
|
| 11 |
|
| 12 |
+
# ---------------------------------------------------
|
| 13 |
+
# FastAPI application
|
| 14 |
+
# ---------------------------------------------------
|
| 15 |
+
app = FastAPI(title="Florence Image Caption API")
|
| 16 |
|
| 17 |
+
# ---------------------------------------------------
|
| 18 |
+
# Load model once
|
| 19 |
+
# ---------------------------------------------------
|
| 20 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 21 |
|
| 22 |
processor = AutoProcessor.from_pretrained(
|
|
|
|
| 29 |
trust_remote_code=True
|
| 30 |
).to(device).eval()
|
| 31 |
|
| 32 |
+
# Concurrency lock so HF Spaces doesn't crash under load
|
| 33 |
inference_lock = asyncio.Lock()
|
| 34 |
|
| 35 |
|
| 36 |
def caption_image(image: Image.Image) -> str:
|
| 37 |
+
"""Generate caption for a single image."""
|
| 38 |
inputs = processor(
|
| 39 |
text="<MORE_DETAILED_CAPTION>",
|
| 40 |
images=image,
|
|
|
|
| 48 |
num_beams=3
|
| 49 |
)
|
| 50 |
|
| 51 |
+
decoded = processor.batch_decode(
|
| 52 |
+
output_ids, skip_special_tokens=False
|
| 53 |
+
)[0]
|
| 54 |
|
| 55 |
parsed = processor.post_process_generation(
|
| 56 |
decoded,
|
|
|
|
| 61 |
return parsed["<MORE_DETAILED_CAPTION>"]
|
| 62 |
|
| 63 |
|
| 64 |
+
# ---------------------------------------------------
|
| 65 |
+
# API endpoint
|
| 66 |
+
# ---------------------------------------------------
|
| 67 |
@app.post("/img2caption")
async def img2caption(file: UploadFile = File(...)):
    """Generate a caption for an uploaded image.

    Reads the uploaded file, decodes it into an RGB PIL image and runs
    ``caption_image`` on it while holding ``inference_lock``, so only one
    inference runs at a time.

    Args:
        file: The uploaded image, sent as multipart form field ``file``.

    Returns:
        ``{"caption": <str>}`` on success. If anything raises (unreadable
        upload, inference failure, ...), a ``JSONResponse`` with
        ``{"error": <message>}`` and HTTP status 500.
    """
    try:
        data = await file.read()
        image = Image.open(io.BytesIO(data)).convert("RGB")

        # Protect GPU inference: only one request may run the model at a time.
        async with inference_lock:
            # caption_image is synchronous; calling it directly here would
            # block the event loop (and every other request) for the whole
            # inference. Run it in the default thread pool instead.
            loop = asyncio.get_running_loop()
            caption = await loop.run_in_executor(None, caption_image, image)

        return {"caption": caption}

    except Exception as e:
        return JSONResponse(
            {"error": str(e)},
            status_code=500
        )
|
| 84 |
|
| 85 |
|
| 86 |
+
# ---------------------------------------------------
|
| 87 |
+
# Custom UI (HTML + CSS + JS)
|
| 88 |
+
# ---------------------------------------------------
|
| 89 |
@app.get("/", response_class=HTMLResponse)
|
| 90 |
def ui():
|
| 91 |
return """
|
| 92 |
<!DOCTYPE html>
|
| 93 |
<html>
|
| 94 |
<head>
|
| 95 |
+
<title>Florence Image Captioning</title>
|
| 96 |
<style>
|
| 97 |
body {
|
| 98 |
+
font-family: Arial;
|
| 99 |
max-width: 650px;
|
| 100 |
margin: 40px auto;
|
| 101 |
padding: 20px;
|
| 102 |
background: #fafafa;
|
| 103 |
}
|
| 104 |
+
h2 { text-align: center; }
|
|
|
|
|
|
|
| 105 |
#preview {
|
| 106 |
width: 100%;
|
| 107 |
margin-top: 15px;
|
|
|
|
| 116 |
display: none;
|
| 117 |
}
|
| 118 |
button {
|
| 119 |
+
padding: 12px;
|
| 120 |
margin-top: 10px;
|
| 121 |
width: 100%;
|
| 122 |
+
background: #4a90e2;
|
| 123 |
color: white;
|
|
|
|
| 124 |
border: none;
|
| 125 |
border-radius: 6px;
|
| 126 |
cursor: pointer;
|
| 127 |
+
font-size: 16px;
|
| 128 |
}
|
| 129 |
button:hover {
|
| 130 |
+
background: #357abd;
|
| 131 |
}
|
| 132 |
</style>
|
| 133 |
</head>
|
| 134 |
|
| 135 |
<body>
|
| 136 |
+
<h2>Image Caption Generator</h2>
|
| 137 |
|
| 138 |
<input type="file" id="imageInput" accept="image/*">
|
|
|
|
| 139 |
<img id="preview">
|
| 140 |
|
| 141 |
<button onclick="generateCaption()">Generate Caption</button>
|
|
|
|
| 158 |
async function generateCaption() {
|
| 159 |
const file = imgInput.files[0];
|
| 160 |
if (!file) {
|
| 161 |
+
alert("Upload an image first");
|
| 162 |
return;
|
| 163 |
}
|
| 164 |
|
| 165 |
+
const form = new FormData();
|
| 166 |
+
form.append("file", file);
|
| 167 |
|
| 168 |
captionBox.style.display = "block";
|
| 169 |
captionBox.innerHTML = "Generating caption...";
|
| 170 |
|
| 171 |
+
const res = await fetch("/img2caption", {
|
| 172 |
method: "POST",
|
| 173 |
+
body: form
|
| 174 |
});
|
| 175 |
|
| 176 |
+
const data = await res.json();
|
| 177 |
+
captionBox.innerHTML = data.caption || data.error;
|
|
|
|
| 178 |
}
|
| 179 |
</script>
|
| 180 |
|
|
|
|
| 183 |
"""
|
| 184 |
|
| 185 |
|
| 186 |
+
# ---------------------------------------------------
|
| 187 |
+
# Keep-alive system to prevent HF auto-sleep
|
| 188 |
+
# ---------------------------------------------------
|
|
|
|
|
|
|
| 189 |
|
| 190 |
def keep_alive():
    """Ping ``SPACE_URL`` forever, once every 240 seconds.

    Each iteration issues a GET request with a 5-second timeout. A failed
    ping is deliberately non-fatal (this is a best-effort keep-alive), but
    it is logged instead of being silently discarded so that a wrong URL
    or an outage is visible in the logs.

    This function never returns.
    NOTE(review): presumably meant to run in a background thread; the
    caller is not visible here -- confirm.
    """
    import logging  # local import: the file's import block is not in view

    log = logging.getLogger(__name__)

    while True:
        try:
            requests.get(SPACE_URL, timeout=5)
        except Exception:
            # Best-effort: a failed ping must never end the loop.
            log.warning("keep-alive ping failed", exc_info=True)
        time.sleep(240)
|
| 197 |
|