reygml committed on
Commit 43e6aa2 · 1 Parent(s): 71fd7ae
Files changed (1)
  1. util.py +31 -48
util.py CHANGED
@@ -1,9 +1,22 @@
-# util.py
+
+# util.py (patched cache handling for HF Spaces)
 import os
+from pathlib import Path
+
+# Put every cache under /tmp (always writable in Spaces)
+CACHE_DIR = os.getenv("HF_CACHE_DIR", "/tmp/hf-cache")
+Path(CACHE_DIR).mkdir(parents=True, exist_ok=True)
+
+# Make sure libraries don't fall back to "~/.cache" -> "/.cache"
+os.environ.setdefault("HF_HOME", CACHE_DIR)
+os.environ.setdefault("TRANSFORMERS_CACHE", CACHE_DIR)
+os.environ.setdefault("HUGGINGFACE_HUB_CACHE", CACHE_DIR)
+os.environ.setdefault("XDG_CACHE_HOME", CACHE_DIR)
+os.environ.setdefault("TORCH_HOME", CACHE_DIR)
+
 import threading
 from io import BytesIO
-from typing import List, Sequence, Union
-
+from typing import List, Sequence
 import torch
 from PIL import Image
 from transformers import AutoProcessor, AutoModelForVision2Seq
@@ -11,17 +24,13 @@ from transformers.image_utils import load_image as hf_load_image
 
 
 class SmolVLMRunner:
-    """
-    Thin wrapper around HuggingFaceTB/SmolVLM-Instruct for single/multi-image VQA or captioning.
-    Reuses a single model instance across calls and serializes inference with a lock (GPU friendly).
-    """
-
     def __init__(self, model_id: str | None = None, device: str | None = None):
         self.model_id = model_id or os.getenv("SMOLVLM_MODEL_ID", "HuggingFaceTB/SmolVLM-Instruct")
         self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
         self.dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
 
-        self.processor = AutoProcessor.from_pretrained(self.model_id)
+        # Use the writable cache dir explicitly
+        self.processor = AutoProcessor.from_pretrained(self.model_id, cache_dir=CACHE_DIR)
 
         attn_impl = "flash_attention_2" if self.device == "cuda" else "eager"
         try:
@@ -29,63 +38,40 @@ class SmolVLMRunner:
                 self.model_id,
                 torch_dtype=self.dtype,
                 _attn_implementation=attn_impl,
+                cache_dir=CACHE_DIR,
             ).to(self.device)
         except Exception:
-            # Fallback if flash-attn isn't available
+            # Fallback if flash-attn isn't available in the environment
             self.model = AutoModelForVision2Seq.from_pretrained(
                 self.model_id,
                 torch_dtype=self.dtype,
                 _attn_implementation="eager",
+                cache_dir=CACHE_DIR,
             ).to(self.device)
 
         self.model.eval()
         self._lock = threading.Lock()
 
-    # ---------- Image loading helpers ----------
-
     @staticmethod
     def _ensure_rgb(img: Image.Image) -> Image.Image:
        return img.convert("RGB") if img.mode != "RGB" else img
 
     @classmethod
     def load_pil_from_urls(cls, urls: Sequence[str]) -> List[Image.Image]:
-        """Load images from HTTP/HTTPS URLs using HF's helper."""
-        images: List[Image.Image] = []
-        for u in urls:
-            img = hf_load_image(u)
-            images.append(cls._ensure_rgb(img))
-        return images
+        return [cls._ensure_rgb(hf_load_image(u)) for u in urls]
 
     @classmethod
     def load_pil_from_bytes(cls, blobs: Sequence[bytes]) -> List[Image.Image]:
-        """Load images from raw bytes (e.g., FastAPI uploads)."""
-        images: List[Image.Image] = []
-        for b in blobs:
-            img = Image.open(BytesIO(b))
-            images.append(cls._ensure_rgb(img))
-        return images
-
-    # ---------- Core inference ----------
-
-    def generate(
-        self,
-        prompt: str,
-        images: Sequence[Image.Image],
-        max_new_tokens: int = 300,
-        temperature: float | None = None,
-        top_p: float | None = None,
-    ) -> str:
-        """
-        Run generation with 0+ images (text-only works too).
-        """
-        # Build chat template: one "image" token per provided image, then the text.
+        return [cls._ensure_rgb(Image.open(BytesIO(b))) for b in blobs]
+
+    def generate(self, prompt: str, images: Sequence[Image.Image], max_new_tokens: int = 300,
+                 temperature: float | None = None, top_p: float | None = None) -> str:
         content = [{"type": "image"} for _ in images] + [{"type": "text", "text": prompt}]
         messages = [{"role": "user", "content": content}]
-
         chat_prompt = self.processor.apply_chat_template(messages, add_generation_prompt=True)
 
         inputs = self.processor(text=chat_prompt, images=list(images), return_tensors="pt")
-        inputs = {k: v.to(self.device) if hasattr(v, "to") else v for k, v in inputs.items()}
+        inputs = {k: (v.to(self.device) if hasattr(v, "to") else v) for k, v in inputs.items()}
 
         gen_kwargs = dict(max_new_tokens=max_new_tokens)
         if temperature is not None:
@@ -97,19 +83,16 @@
         generated_ids = self.model.generate(**inputs, **gen_kwargs)
 
         text = self.processor.batch_decode(generated_ids, skip_special_tokens=True)[0].strip()
-        # Many chat templates prepend "Assistant: "
         if text.startswith("Assistant:"):
-            text = text[len("Assistant:") :].strip()
+            text = text[len("Assistant:"):].strip()
         return text
 
 
-# Convenience singleton (optional import path)
-_runner_singleton: SmolVLMRunner | None = None
-
-
-def get_runner() -> SmolVLMRunner:
+_runner_singleton = None
+def get_runner():
     global _runner_singleton
     if _runner_singleton is None:
         _runner_singleton = SmolVLMRunner()
     return _runner_singleton
 
+
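
Note on ordering: the new env-var block has to run before transformers and torch are imported, since those libraries read HF_HOME / TRANSFORMERS_CACHE / TORCH_HOME when they first load, which is why the patch places it above the other imports. Because the patch uses os.environ.setdefault, any cache paths the Space already exports take precedence. A quick way to sanity-check the redirection (a sketch; check_cache.py and the override value are hypothetical, not part of this commit):

# check_cache.py -- hypothetical sanity check, run next to the patched util.py
import os

# util.py reads HF_CACHE_DIR before falling back to /tmp/hf-cache
os.environ["HF_CACHE_DIR"] = "/tmp/hf-cache-test"  # hypothetical override

import util  # importing util sets the cache env vars before transformers loads

print("CACHE_DIR:", util.CACHE_DIR)
for var in ("HF_HOME", "TRANSFORMERS_CACHE", "HUGGINGFACE_HUB_CACHE",
            "XDG_CACHE_HOME", "TORCH_HOME"):
    # setdefault in util.py means values already exported by the Space win
    print(var, "->", os.environ.get(var))
print("cache dir exists:", os.path.isdir(util.CACHE_DIR))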
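
For reference, a minimal usage sketch of the patched module (example.py and the image URL are placeholders, not from this commit):

# example.py -- minimal usage sketch; file name and URL are placeholders
from util import SmolVLMRunner, get_runner

runner = get_runner()  # builds the SmolVLMRunner singleton on first call

# The classmethod loaders return RGB PIL images ready for the processor
images = SmolVLMRunner.load_pil_from_urls(["https://example.com/cat.jpg"])

answer = runner.generate("What is in this picture?", images, max_new_tokens=100)
print(answer)

get_runner() reuses one model instance across calls; temperature and top_p are optional and are only forwarded to model.generate when set (see the `if temperature is not None:` guard in generate).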