# NOTE: Hugging Face Spaces page banner ("Spaces: Sleeping") stripped from this extract.
# util.py
import os
from pathlib import Path

# Put every cache under /tmp (always writable in Spaces).
CACHE_DIR = os.getenv("HF_CACHE_DIR", "/tmp/hf-cache")
Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)

# Make sure libraries don't fall back to "~/.cache" -> "/.cache".
# setdefault keeps any value the operator already exported.
for _cache_var in (
    "HF_HOME",
    "TRANSFORMERS_CACHE",
    "HUGGINGFACE_HUB_CACHE",
    "XDG_CACHE_HOME",
    "TORCH_HOME",
):
    os.environ.setdefault(_cache_var, CACHE_DIR)
| from time import perf_counter | |
| import threading | |
| from io import BytesIO | |
| from typing import List, Sequence, Tuple, Dict, Any | |
| import io | |
| import base64 | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoProcessor, AutoModelForVision2Seq | |
| from transformers.image_utils import load_image as hf_load_image | |
| from grounding_dino2 import get_runner as get_gdino_runner, visualize_detections | |
| def _has_flash_attn() -> bool: | |
| try: | |
| import flash_attn # noqa: F401 | |
| return True | |
| except Exception: | |
| return False | |
| def _pick_backend_and_dtype(): | |
| if not torch.cuda.is_available(): | |
| return "eager", torch.float32, "cpu" | |
| major, _ = torch.cuda.get_device_capability() | |
| dev = "cuda" | |
| bf16_ok = torch.cuda.is_bf16_supported() | |
| dtype = torch.bfloat16 if bf16_ok else torch.float16 | |
| if major >= 8: # Ampere+ | |
| attn = "flash_attention_2" if _has_flash_attn() else "eager" | |
| else: | |
| attn = "eager" | |
| return attn, dtype, dev | |
class SmolVLMRunner:
    """Portable, thread-safe wrapper around a SmolVLM vision-language model.

    Loads the processor and model once, keeps them on a single device, and
    exposes grounded detection + description and free-form generation with
    per-call metrics. ``generate()`` calls are serialized with a lock.
    """

    def __init__(self, model_id: str | None = None, device: str | None = None):
        """Load the processor and model and move the model to the device.

        Args:
            model_id: Hugging Face model id. Defaults to $SMOLVLM_MODEL_ID,
                then "HuggingFaceTB/SmolVLM-Instruct".
            device: Target device string ("cuda"/"cpu"). Defaults to whatever
                the hardware probe in ``_pick_backend_and_dtype`` selected.
        """
        self.model_id = model_id or os.getenv("SMOLVLM_MODEL_ID", "HuggingFaceTB/SmolVLM-Instruct")
        attn_impl, dtype, dev = _pick_backend_and_dtype()
        attn_impl = os.getenv("SMOLVLM_ATTN", attn_impl)  # optional override
        self.device = device or dev
        self.dtype = dtype
        self.attn_impl = attn_impl
        # If SDPA was forced via $SMOLVLM_ATTN, steer torch away from the
        # flash kernel toward mem-efficient/math kernels (best effort only;
        # the sdp_kernel API is version-dependent, hence the broad except).
        if self.device == "cuda" and self.attn_impl == "sdpa":
            try:
                from torch.backends.cuda import sdp_kernel
                sdp_kernel(enable_flash=False, enable_mem_efficient=True, enable_math=True)
            except Exception:
                pass
        self.processor = AutoProcessor.from_pretrained(self.model_id, cache_dir=CACHE_DIR)
        self.model = AutoModelForVision2Seq.from_pretrained(
            self.model_id,
            torch_dtype=self.dtype,
            _attn_implementation=self.attn_impl,
            cache_dir=CACHE_DIR,
        ).to(self.device)
        try:
            # Some transformers versions read the backend from the config
            # attribute rather than the load-time kwarg; set both.
            self.model.config._attn_implementation = self.attn_impl
        except Exception:
            pass
        self.model.eval()
        # Serializes generate() calls; HF generation is not reentrant.
        self._lock = threading.Lock()

    # ---------- Image utils ----------
    @staticmethod
    def _ensure_rgb(img: Image.Image) -> Image.Image:
        """Return *img* in RGB mode (no-op when it already is).

        Fix: was an undecorated instance method, so ``cls._ensure_rgb(x)``
        passed two positional args to a one-parameter function (TypeError).
        """
        return img.convert("RGB") if img.mode != "RGB" else img

    @classmethod
    def load_pil_from_urls(cls, urls: Sequence[str]) -> List[Image.Image]:
        """Load each URL via transformers' image loader, normalized to RGB.

        Fix: decorated as @classmethod — the first parameter was named ``cls``
        but the decorator was missing, so instance calls raised TypeError.
        """
        return [cls._ensure_rgb(hf_load_image(u)) for u in urls]

    @classmethod
    def load_pil_from_bytes(cls, blobs: Sequence[bytes]) -> List[Image.Image]:
        """Decode each raw image blob with PIL, normalized to RGB.

        Fix: decorated as @classmethod for the same reason as
        ``load_pil_from_urls``.
        """
        return [cls._ensure_rgb(Image.open(BytesIO(b))) for b in blobs]

    # ---------- Generation helpers ----------
    @staticmethod
    def _build_gen_kwargs(
        max_new_tokens: int,
        temperature: float | None,
        top_p: float | None,
    ) -> Dict[str, Any]:
        """Assemble kwargs for model.generate(), omitting unset sampling knobs."""
        gen_kwargs: Dict[str, Any] = dict(max_new_tokens=max_new_tokens)
        if temperature is not None:
            gen_kwargs["temperature"] = float(temperature)
        if top_p is not None:
            gen_kwargs["top_p"] = float(top_p)
        return gen_kwargs

    def _decode(self, out_ids) -> str:
        """Decode generated ids and strip the chat-template 'Assistant:' prefix."""
        text = self.processor.batch_decode(out_ids, skip_special_tokens=True)[0].strip()
        if text.startswith("Assistant:"):
            text = text[len("Assistant:"):].strip()
        return text

    # ---------- Inference ----------
    def detect_and_describe(
        self,
        image: Image.Image,
        labels: list[str] | str,
        *,
        box_threshold: float = 0.4,
        text_threshold: float = 0.3,
        pad_frac: float = 0.06,
        max_new_tokens: int = 160,
        temperature: float | None = None,
        top_p: float | None = None,
        return_overlay: bool = False,
    ) -> list[dict] | dict:
        """
        Uses Grounding DINO to detect boxes for `labels`, then asks SmolVLM to
        describe each cropped box.
        If return_overlay=False (default): returns a list of dicts:
            [{ 'label','score','box_xyxy','description' }, ...]
        If return_overlay=True: returns a dict:
            { 'detections': [...], 'overlay_png_b64': '<base64 PNG>' }
        """
        gdino = get_gdino_runner()
        detections = gdino.detect(
            image=image,
            labels=labels,
            box_threshold=box_threshold,
            text_threshold=text_threshold,
            pad_frac=pad_frac,
        )
        if not detections:
            # Preserve the return shape the caller asked for, even when empty.
            return [] if not return_overlay else {"detections": [], "overlay_png_b64": None}
        gen_kwargs = self._build_gen_kwargs(max_new_tokens, temperature, top_p)
        results: list[dict] = []
        for det in detections:
            crop = det["crop"]
            prompt_txt = f"The image gets the label: '{det['label']}'. Describe the object inside this crop in detail."
            content = [{"type": "image"}, {"type": "text", "text": prompt_txt}]
            messages = [{"role": "user", "content": content}]
            chat_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
            inputs = self.processor(text=chat_prompt, images=[crop], return_tensors="pt")
            # Move tensors to the model device; pass non-tensor entries through.
            inputs = {k: (v.to(self.device) if hasattr(v, "to") else v) for k, v in inputs.items()}
            with self._lock, torch.inference_mode():
                out_ids = self.model.generate(**inputs, **gen_kwargs)
            results.append({
                "label": det["label"],
                "score": det["score"],
                "box_xyxy": det["box_xyxy"],
                "description": self._decode(out_ids),
            })
        if not return_overlay:
            return results
        # Build overlay image (PNG -> base64 string)
        overlay = visualize_detections(image, detections)
        buf = io.BytesIO()
        overlay.save(buf, format="PNG")
        b64 = base64.b64encode(buf.getvalue()).decode("ascii")
        return {"detections": results, "overlay_png_b64": b64}

    def generate(
        self,
        prompt: str,
        images: Sequence[Image.Image],
        max_new_tokens: int = 300,
        temperature: float | None = None,
        top_p: float | None = None,
        return_stats: bool = False,
    ) -> str | Tuple[str, Dict[str, Any]]:
        """Run one multimodal chat turn over *images* with *prompt*.

        Returns str by default.
        If return_stats=True, returns (text, metrics_dict) where the metrics
        cover timings, token counts, throughput, and (on CUDA) peak memory.
        """
        meta = {
            "model_id": self.model_id,
            "device": self.device,
            "dtype": str(self.dtype).replace("torch.", ""),
            "attn_backend": self.attn_impl,
            "image_count": len(images),
            "max_new_tokens": int(max_new_tokens),
            "temperature": None if temperature is None else float(temperature),
            "top_p": None if top_p is None else float(top_p),
        }
        t0 = perf_counter()
        # One image placeholder per input image, text last (chat-template order).
        content = [{"type": "image"} for _ in images] + [{"type": "text", "text": prompt}]
        messages = [{"role": "user", "content": content}]
        chat_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
        # Preprocess (tokenize + vision)
        inputs = self.processor(text=chat_prompt, images=list(images), return_tensors="pt")
        inputs = {k: (v.to(self.device) if hasattr(v, "to") else v) for k, v in inputs.items()}
        t_pre_end = perf_counter()
        # Inference (generate)
        gen_kwargs = self._build_gen_kwargs(max_new_tokens, temperature, top_p)
        if self.device == "cuda":
            # Sync so the timed region measures only generate(), and reset the
            # peak-memory counters so the stats reflect this call alone.
            torch.cuda.synchronize()
            torch.cuda.reset_peak_memory_stats()
        with self._lock, torch.inference_mode():
            t_inf_start = perf_counter()
            out_ids = self.model.generate(**inputs, **gen_kwargs)
            if self.device == "cuda":
                torch.cuda.synchronize()
            t_inf_end = perf_counter()
        # Decode
        text = self._decode(out_ids)
        t_dec_end = perf_counter()
        # Stats
        input_tokens = int(inputs["input_ids"].shape[-1]) if "input_ids" in inputs else None
        total_tokens = int(out_ids.shape[-1])  # includes prompt + generated
        output_tokens = int(total_tokens - input_tokens) if input_tokens is not None else None
        pre_ms = (t_pre_end - t0) * 1000.0
        infer_ms = (t_inf_end - t_inf_start) * 1000.0
        decode_ms = (t_dec_end - t_inf_end) * 1000.0
        total_ms = (t_dec_end - t0) * 1000.0
        # Throughput is None when output_tokens is unknown or zero.
        tps_infer = (output_tokens / ((t_inf_end - t_inf_start) + 1e-9)) if output_tokens else None
        tps_total = (output_tokens / ((t_dec_end - t0) + 1e-9)) if output_tokens else None
        gpu_mem_alloc_mb = gpu_mem_resv_mb = None
        gpu_name = None
        if self.device == "cuda":
            try:
                gpu_mem_alloc_mb = round(torch.cuda.max_memory_allocated() / (1024**2), 2)
                gpu_mem_resv_mb = round(torch.cuda.max_memory_reserved() / (1024**2), 2)
                gpu_name = torch.cuda.get_device_name(torch.cuda.current_device())
            except Exception:
                pass
        metrics: Dict[str, Any] = {
            **meta,
            "gpu_name": gpu_name,
            "timings_ms": {
                "preprocess": round(pre_ms, 2),
                "inference": round(infer_ms, 2),
                "decode": round(decode_ms, 2),
                "total": round(total_ms, 2),
            },
            "tokens": {
                "input": input_tokens,
                "output": output_tokens,
                "total": total_tokens,
            },
            "throughput": {
                "tokens_per_sec_inference": None if tps_infer is None else round(tps_infer, 2),
                "tokens_per_sec_end_to_end": None if tps_total is None else round(tps_total, 2),
            },
            "gpu_memory_mb": {
                "max_allocated": gpu_mem_alloc_mb,
                "max_reserved": gpu_mem_resv_mb,
            },
        }
        return (text, metrics) if return_stats else text
# Process-wide lazily-created runner instance.
_runner_singleton = None


def get_runner():
    """Return the shared SmolVLMRunner, constructing it on first use."""
    global _runner_singleton
    if _runner_singleton is not None:
        return _runner_singleton
    _runner_singleton = SmolVLMRunner()
    return _runner_singleton