reygml committed on
Commit
8b95053
·
1 Parent(s): 69f3c04

initial commit

Browse files
Files changed (4) hide show
  1. Dockerfile +19 -0
  2. app.py +98 -0
  3. requirements.txt +23 -0
  4. util.py +115 -0
Dockerfile ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Read the doc: https://huggingface.co/docs/hub/spaces-sdks-docker
# you will also find guides on how best to write your Dockerfile

# FIX: the pre-built flash-attn wheel below is tagged cp310 (CPython 3.10 only),
# and util.py uses PEP 604 unions (str | None) which require Python >= 3.10.
# python:3.9 would fail both the wheel install and the app import.
FROM python:3.10

# Run as a non-root user (required by HF Spaces Docker SDK).
RUN useradd -m -u 1000 user
USER user
ENV PATH="/home/user/.local/bin:$PATH"

WORKDIR /app

# Install the prebuilt flash-attn wheel without deps so it does not drag in a
# conflicting torch; requirements.txt pins torch==2.4.0 (matching cu12torch2.4).
RUN pip install --no-deps "https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.0.post2/flash_attn-2.7.0.post2+cu12torch2.4cxx11abiFALSE-cp310-cp310-linux_x86_64.whl"

# Copy requirements first so the dependency layer is cached across code edits.
COPY --chown=user ./requirements.txt requirements.txt
RUN pip install --no-cache-dir --upgrade -r requirements.txt

COPY --chown=user . /app
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import asyncio
3
+ from typing import List, Optional
4
+
5
+ from fastapi import FastAPI, UploadFile, File, Form, HTTPException
6
+ from pydantic import BaseModel, Field, HttpUrl
7
+ import uvicorn
8
+
9
+ from util import get_runner, SmolVLMRunner
10
+
11
+
12
+ app = FastAPI(title="SmolVLM Inference API", version="1.0.0")
13
+ _runner: Optional[SmolVLMRunner] = None
14
+
15
+
16
class URLRequest(BaseModel):
    """Request body for POST /generate_urls: a text prompt plus remote image URLs."""

    prompt: str = Field(..., description="Text prompt to accompany the images.")
    image_urls: List[HttpUrl] = Field(..., description="List of image URLs.")
    # Generation controls; validation bounds keep obviously bad values out
    # before they ever reach the model.
    max_new_tokens: int = Field(300, ge=1, le=1024)
    temperature: Optional[float] = Field(None, ge=0.0, le=2.0)
    top_p: Optional[float] = Field(None, gt=0.0, le=1.0)
22
+
23
+
24
@app.on_event("startup")
async def _load_model_on_startup():
    """Create the shared SmolVLM runner once, when the server boots."""
    # Loading happens here (not at import time) so the ASGI app object can be
    # created cheaply; the module-level _runner is populated for the endpoints.
    global _runner
    _runner = get_runner()
28
+
29
+
30
@app.get("/")
def health():
    """Liveness probe: reports whether the model runner has been created yet."""
    loaded_model = _runner.model_id if _runner else None
    return {"status": "ok", "model": loaded_model}
33
+
34
+
35
@app.post("/generate")
async def generate_from_files(
    prompt: str = Form(...),
    images: List[UploadFile] = File(..., description="One or more image files."),
    max_new_tokens: int = Form(300),
    temperature: Optional[float] = Form(None),
    top_p: Optional[float] = Form(None),
):
    """
    Multipart form endpoint:
      - prompt: str
      - images: one or more image files (image/*)

    Raises 503 if the model has not finished loading, 400 for no images,
    415 for non-image uploads.
    """
    if _runner is None:
        # Startup hook has not populated the runner yet (or it failed).
        raise HTTPException(status_code=503, detail="Model is not loaded yet.")
    if not images:
        raise HTTPException(status_code=400, detail="At least one image must be provided.")

    # Read all files into memory (simple & fine for moderate sizes)
    blobs = []
    for f in images:
        if not f.content_type or not f.content_type.startswith("image/"):
            raise HTTPException(status_code=415, detail=f"Unsupported file type: {f.content_type}")
        blobs.append(await f.read())

    pil_images = _runner.load_pil_from_bytes(blobs)
    # BUGFIX: model inference is synchronous and slow; calling it directly in
    # an async endpoint would block the event loop for every other request
    # (including the health check). Run it in a worker thread instead.
    text = await asyncio.to_thread(
        _runner.generate,
        prompt=prompt,
        images=pil_images,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    return {"text": text}
67
+
68
+
69
@app.post("/generate_urls")
async def generate_from_urls(req: URLRequest):
    """
    JSON endpoint:
    {
      "prompt": "...",
      "image_urls": ["https://...","https://..."],
      "max_new_tokens": 300,
      "temperature": 0.2,
      "top_p": 0.95
    }

    Raises 503 if the model has not finished loading, 400 for an empty URL list.
    """
    if _runner is None:
        # Startup hook has not populated the runner yet (or it failed).
        raise HTTPException(status_code=503, detail="Model is not loaded yet.")
    if len(req.image_urls) == 0:
        raise HTTPException(status_code=400, detail="At least one image URL is required.")

    # BUGFIX: both the URL downloads and the model forward pass are blocking;
    # run them in worker threads so the event loop stays responsive.
    pil_images = await asyncio.to_thread(
        _runner.load_pil_from_urls, [str(u) for u in req.image_urls]
    )
    text = await asyncio.to_thread(
        _runner.generate,
        prompt=req.prompt,
        images=pil_images,
        max_new_tokens=req.max_new_tokens,
        temperature=req.temperature,
        top_p=req.top_p,
    )
    return {"text": text}
93
+
94
+
95
if __name__ == "__main__":
    # Local development entry point; the container launches uvicorn via CMD.
    # Equivalent CLI: uvicorn app:app --host 0.0.0.0 --port 8000
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=False)
98
+
requirements.txt ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu121
2
+
3
+ fastapi
4
+ uvicorn[standard]
5
+ torch==2.4.0
6
+ torchvision==0.19.0
7
+ pillow==10.4.0
8
+ imageio==2.36.1
9
+ imageio-ffmpeg==0.5.1
10
+ accelerate
11
+ diffusers
12
+ peft
13
+ sentencepiece
14
+ bitsandbytes
15
+ gguf
16
+ pypdfium2
17
+ icecream
18
+ einops
19
# Pillow  (duplicate: already pinned above as pillow==10.4.0)
20
+ gradio
21
+ xformers==0.0.27.post2
22
+ spconv-cu120==2.3.6
23
+ transformers==4.46.3
util.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # util.py
2
+ import os
3
+ import threading
4
+ from io import BytesIO
5
+ from typing import List, Sequence, Union
6
+
7
+ import torch
8
+ from PIL import Image
9
+ from transformers import AutoProcessor, AutoModelForVision2Seq
10
+ from transformers.image_utils import load_image as hf_load_image
11
+
12
+
13
class SmolVLMRunner:
    """
    Thin wrapper around HuggingFaceTB/SmolVLM-Instruct for single/multi-image VQA or captioning.

    Reuses a single model instance across calls and serializes inference with a
    lock so concurrent requests do not race on the GPU.
    """

    def __init__(self, model_id: Union[str, None] = None, device: Union[str, None] = None):
        """
        Load processor and model once.

        Args:
            model_id: HF model id; defaults to $SMOLVLM_MODEL_ID or SmolVLM-Instruct.
            device: "cuda" or "cpu"; auto-detected when omitted.
        """
        self.model_id = model_id or os.getenv("SMOLVLM_MODEL_ID", "HuggingFaceTB/SmolVLM-Instruct")
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        # bf16 on GPU for memory/speed; fp32 on CPU where bf16 is slow/unsupported.
        self.dtype = torch.bfloat16 if self.device == "cuda" else torch.float32

        self.processor = AutoProcessor.from_pretrained(self.model_id)

        attn_impl = "flash_attention_2" if self.device == "cuda" else "eager"
        try:
            self.model = AutoModelForVision2Seq.from_pretrained(
                self.model_id,
                torch_dtype=self.dtype,
                _attn_implementation=attn_impl,
            ).to(self.device)
        except Exception:
            # Fallback if flash-attn isn't available
            self.model = AutoModelForVision2Seq.from_pretrained(
                self.model_id,
                torch_dtype=self.dtype,
                _attn_implementation="eager",
            ).to(self.device)

        self.model.eval()
        self._lock = threading.Lock()

    # ---------- Image loading helpers ----------

    @staticmethod
    def _ensure_rgb(img: Image.Image) -> Image.Image:
        """Convert to RGB only when needed (e.g. RGBA/P/L inputs)."""
        return img.convert("RGB") if img.mode != "RGB" else img

    @classmethod
    def load_pil_from_urls(cls, urls: Sequence[str]) -> List[Image.Image]:
        """Load images from HTTP/HTTPS URLs using HF's helper."""
        return [cls._ensure_rgb(hf_load_image(u)) for u in urls]

    @classmethod
    def load_pil_from_bytes(cls, blobs: Sequence[bytes]) -> List[Image.Image]:
        """Load images from raw bytes (e.g., FastAPI uploads)."""
        return [cls._ensure_rgb(Image.open(BytesIO(b))) for b in blobs]

    # ---------- Core inference ----------

    def generate(
        self,
        prompt: str,
        images: Sequence[Image.Image],
        max_new_tokens: int = 300,
        temperature: Union[float, None] = None,
        top_p: Union[float, None] = None,
    ) -> str:
        """
        Run generation with 0+ images (text-only works too).

        Returns only the newly generated assistant text; the echoed prompt
        tokens are stripped before decoding.
        """
        # Build chat template: one "image" token per provided image, then the text.
        content = [{"type": "image"} for _ in images] + [{"type": "text", "text": prompt}]
        messages = [{"role": "user", "content": content}]

        chat_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)

        inputs = self.processor(text=chat_prompt, images=list(images), return_tensors="pt")
        inputs = {k: v.to(self.device) if hasattr(v, "to") else v for k, v in inputs.items()}

        gen_kwargs = dict(max_new_tokens=max_new_tokens)
        if temperature is not None:
            gen_kwargs["temperature"] = float(temperature)
        if top_p is not None:
            gen_kwargs["top_p"] = float(top_p)
        if temperature is not None or top_p is not None:
            # BUGFIX: temperature/top_p are ignored by HF generate() under the
            # default greedy decoding; enable sampling so they take effect.
            gen_kwargs["do_sample"] = True

        with self._lock, torch.inference_mode():
            generated_ids = self.model.generate(**inputs, **gen_kwargs)

        # BUGFIX: the model returns prompt + completion in one sequence, so
        # decoding generated_ids wholesale echoed the user's prompt back to the
        # caller. Slice off the input tokens and decode only the new ones.
        prompt_len = inputs["input_ids"].shape[-1]
        new_tokens = generated_ids[:, prompt_len:]
        text = self.processor.batch_decode(new_tokens, skip_special_tokens=True)[0].strip()
        # Some chat templates still prepend "Assistant: " to the completion.
        if text.startswith("Assistant:"):
            text = text[len("Assistant:") :].strip()
        return text
104
+
105
+
106
# Convenience singleton (optional import path)
_runner_singleton: Union[SmolVLMRunner, None] = None


def get_runner() -> SmolVLMRunner:
    """Return the process-wide SmolVLMRunner, creating it lazily on first use."""
    global _runner_singleton
    runner = _runner_singleton
    if runner is None:
        runner = SmolVLMRunner()
        _runner_singleton = runner
    return runner
115
+