samuelolubukun commited on
Commit
406e5bf
·
verified ·
1 Parent(s): cd13369

Upload 3 files

Browse files
Files changed (3) hide show
  1. Dockerfile +35 -0
  2. app.py +215 -0
  3. requirements.txt +8 -0
Dockerfile ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# ── Base ──────────────────────────────────────────────────────────────────────
# SmolVLM is small enough to run on CPU, but a GPU Space is faster.
# HuggingFace Spaces requires the app to listen on port 7860.
FROM python:3.11-slim

# ── System deps ───────────────────────────────────────────────────────────────
# NOTE(review): git is presumably for pip/HF git-based fetches; libgl1 and
# libglib2.0-0 are shared libraries commonly required by imaging stacks —
# confirm against the actual runtime requirements before removing.
RUN apt-get update && apt-get install -y --no-install-recommends \
    git \
    libgl1 \
    libglib2.0-0 \
    && rm -rf /var/lib/apt/lists/*

# ── Non-root user (HF Spaces runs as UID 1000) ────────────────────────────────
RUN useradd -m -u 1000 appuser
WORKDIR /app

# ── Install Python deps ───────────────────────────────────────────────────────
# requirements.txt is copied alone first so this expensive layer is cached
# across edits to app.py.
COPY requirements.txt .
RUN pip install --no-cache-dir --upgrade pip \
    && pip install --no-cache-dir -r requirements.txt

# ── Copy app source ───────────────────────────────────────────────────────────
COPY app.py .

# ── HuggingFace cache (model weights downloaded at first startup) ─────────────
# HF_HOME points the hub cache inside /app so the non-root user can write it.
ENV HF_HOME=/app/.cache/huggingface
RUN mkdir -p /app/.cache/huggingface && chown -R appuser:appuser /app

USER appuser

# ── Port ──────────────────────────────────────────────────────────────────────
EXPOSE 7860

# ── Start ─────────────────────────────────────────────────────────────────────
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ FastAPI app for HuggingFaceTB/SmolVLM-Instruct
3
+ Supports: text-only prompts, single image, and multi-image inputs.
4
+ Port: 7860 (HuggingFace Spaces default)
5
+ """
6
+
7
+ import io
8
+ import base64
9
+ import logging
10
+ from contextlib import asynccontextmanager
11
+ from typing import Optional
12
+
13
+ import torch
14
+ from PIL import Image
15
+ from fastapi import FastAPI, File, Form, UploadFile, HTTPException
16
+ from pydantic import BaseModel
17
+ from transformers import AutoProcessor, AutoModelForVision2Seq
18
+
19
+ # ── Logging ───────────────────────────────────────────────────────────────────
20
+ logging.basicConfig(level=logging.INFO)
21
+ logger = logging.getLogger(__name__)
22
+
23
+ # ── Config ────────────────────────────────────────────────────────────────────
24
+ MODEL_ID = "HuggingFaceTB/SmolVLM-Instruct"
25
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
26
+ DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
27
+
28
+ # ── Globals ───────────────────────────────────────────────────────────────────
29
+ model = None
30
+ processor = None
31
+
32
+
33
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load SmolVLM once at startup and release it at shutdown.

    Sets the module-level ``model`` and ``processor`` globals that the
    route handlers read; they remain ``None`` until loading completes.
    """
    global model, processor
    # Lazy %-style args: the message is only formatted if actually emitted.
    logger.info("Loading %s on %s …", MODEL_ID, DEVICE)
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = AutoModelForVision2Seq.from_pretrained(
        MODEL_ID,
        torch_dtype=DTYPE,
        _attn_implementation="eager",  # swap to "flash_attention_2" on supported GPUs
    ).to(DEVICE)
    model.eval()
    logger.info("SmolVLM ready ✓")  # fixed mojibake ("βœ“") in the log text
    yield
    # Reset the globals (rather than `del`) so a late /health call during
    # shutdown sees `model is None` instead of raising NameError.
    model = None
    processor = None
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
50
+
51
# ── App ───────────────────────────────────────────────────────────────────────
# `lifespan` loads the model before the first request is served.
app = FastAPI(
    title="SmolVLM API",
    description="Multimodal inference with HuggingFaceTB/SmolVLM-Instruct",
    version="1.0.0",
    lifespan=lifespan,
)
59
+
60
# ── Helpers ───────────────────────────────────────────────────────────────────

def run_inference(
    prompt: str,
    images: Optional[list[Image.Image]] = None,
    max_new_tokens: int = 512,
    temperature: float = 0.0,
) -> str:
    """Run one SmolVLM generation and return the decoded new text.

    Emits one ``{"type": "image"}`` placeholder per supplied image,
    followed by the text prompt, in the standard chat-messages format.
    Greedy decoding when ``temperature`` is 0; sampling otherwise.
    """
    pil_images = list(images) if images else []

    # Image placeholders first, then the prompt text.
    content = [{"type": "image"} for _ in pil_images]
    content.append({"type": "text", "text": prompt})
    messages = [{"role": "user", "content": content}]

    # Render the chat template into the model's prompt string.
    rendered = processor.apply_chat_template(messages, add_generation_prompt=True)

    inputs = processor(
        text=rendered,
        images=pil_images or None,
        return_tensors="pt",
    ).to(DEVICE)

    gen_kwargs = {
        **inputs,
        "max_new_tokens": max_new_tokens,
        "do_sample": temperature > 0,
    }
    if temperature > 0:
        gen_kwargs["temperature"] = temperature

    with torch.no_grad():
        output_ids = model.generate(**gen_kwargs)

    # Strip the prompt tokens; decode only what the model produced.
    prompt_len = inputs["input_ids"].shape[1]
    return processor.decode(output_ids[0][prompt_len:], skip_special_tokens=True).strip()
103
+
104
# ── Routes ────────────────────────────────────────────────────────────────────

@app.get("/", tags=["Health"])
def root():
    """Liveness probe: report service status, model id, and device."""
    payload = {"status": "ok"}
    payload["model"] = MODEL_ID
    payload["device"] = DEVICE
    return payload
110
+
111
+ @app.get("/health", tags=["Health"])
112
+ def health():
113
+ return {"model_loaded": model is not None}
114
+
115
+
116
# ── 1. Text-only ──────────────────────────────────────────────────────────────

class TextRequest(BaseModel):
    # JSON body for POST /generate/text.
    prompt: str                 # user prompt (required)
    max_new_tokens: int = 512   # generation length cap
    temperature: float = 0.0    # 0 → greedy decoding; >0 enables sampling
123
+
124
+ @app.post("/generate/text", tags=["Inference"])
125
+ def generate_text(req: TextRequest):
126
+ """Plain text prompt β€” no image required."""
127
+ if model is None:
128
+ raise HTTPException(503, "Model not loaded yet")
129
+ try:
130
+ return {"prompt": req.prompt, "response": run_inference(
131
+ req.prompt,
132
+ max_new_tokens=req.max_new_tokens,
133
+ temperature=req.temperature,
134
+ )}
135
+ except Exception as e:
136
+ logger.exception("Inference error")
137
+ raise HTTPException(500, str(e))
138
+
139
+
140
# ── 2. Image upload (multipart/form-data) ─────────────────────────────────────

@app.post("/generate/vision", tags=["Inference"])
async def generate_vision(
    prompt: str = Form("Describe the image(s) in detail."),
    max_new_tokens: int = Form(512),
    temperature: float = Form(0.0),
    images: list[UploadFile] = File(default=[]),
):
    """Upload one or more images with an optional text prompt."""
    if model is None:
        raise HTTPException(503, "Model not loaded yet")

    # Decode every upload to RGB before any inference work starts.
    decoded: list[Image.Image] = []
    for upload in images:
        payload = await upload.read()
        try:
            decoded.append(Image.open(io.BytesIO(payload)).convert("RGB"))
        except Exception:
            raise HTTPException(400, f"Could not decode image: {upload.filename}")

    try:
        answer = run_inference(
            prompt,
            images=decoded or None,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )
    except Exception as e:
        logger.exception("Inference error")
        raise HTTPException(500, str(e))
    return {"prompt": prompt, "num_images": len(decoded), "response": answer}
173
+
174
# ── 3. Base64 images via JSON ─────────────────────────────────────────────────

class VisionB64Request(BaseModel):
    # JSON body for POST /generate/vision/base64.
    prompt: str = "Describe the image(s) in detail."
    images_b64: list[str] = []  # base64 strings; an optional "data:...;base64," prefix is stripped
    max_new_tokens: int = 512   # generation length cap
    temperature: float = 0.0    # 0 → greedy decoding; >0 enables sampling
182
+
183
+ @app.post("/generate/vision/base64", tags=["Inference"])
184
+ def generate_vision_b64(req: VisionB64Request):
185
+ """Send base64-encoded images inside a JSON body."""
186
+ if model is None:
187
+ raise HTTPException(503, "Model not loaded yet")
188
+
189
+ pil_images: list[Image.Image] = []
190
+ for idx, b64str in enumerate(req.images_b64):
191
+ if "," in b64str:
192
+ b64str = b64str.split(",", 1)[1]
193
+ try:
194
+ raw = base64.b64decode(b64str)
195
+ pil_images.append(Image.open(io.BytesIO(raw)).convert("RGB"))
196
+ except Exception:
197
+ raise HTTPException(400, f"Could not decode base64 image at index {idx}")
198
+
199
+ try:
200
+ response = run_inference(
201
+ req.prompt,
202
+ images=pil_images or None,
203
+ max_new_tokens=req.max_new_tokens,
204
+ temperature=req.temperature,
205
+ )
206
+ return {"prompt": req.prompt, "num_images": len(pil_images), "response": response}
207
+ except Exception as e:
208
+ logger.exception("Inference error")
209
+ raise HTTPException(500, str(e))
210
+
211
+
212
# ── Entry point ───────────────────────────────────────────────────────────────
if __name__ == "__main__":
    # Local dev entry point; in Docker the CMD runs uvicorn directly.
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=False)
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
# Web framework + ASGI server
fastapi==0.115.0
uvicorn[standard]==0.30.6
python-multipart==0.0.9
# Imaging
pillow==10.4.0
# Model runtime
torch==2.4.1
torchvision==0.19.1
transformers==4.45.0
accelerate==0.34.2