DocUA committed on
Commit
9efb9c8
·
1 Parent(s): 3537ca8

feat: Improve Hugging Face cache management and enable mixed-precision inference for GPU models.

Browse files
Files changed (1) hide show
  1. app_hf.py +38 -5
app_hf.py CHANGED
@@ -19,6 +19,8 @@ import datetime
19
  import fitz # PyMuPDF
20
  import io
21
  import gc
 
 
22
 
23
  try:
24
  from transformers.models.llama import modeling_llama as _modeling_llama
@@ -52,6 +54,29 @@ warnings.filterwarnings("ignore", message="You are using a model of type .* to i
52
  DEEPSEEK_MODEL = 'deepseek-ai/DeepSeek-OCR-2'
53
  MEDGEMMA_MODEL = 'google/medgemma-1.5-4b-it'
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  # --- Device Setup ---
56
  # For HF Spaces with ZeroGPU, we'll use cuda if available
57
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -66,12 +91,13 @@ class ModelManager:
66
  if model_name not in self.models:
67
  print(f"Loading {model_name} to CPU...")
68
  if model_name == DEEPSEEK_MODEL:
69
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
70
  model = AutoModel.from_pretrained(
71
  model_name,
72
  trust_remote_code=True,
73
  use_safetensors=True,
74
  attn_implementation="eager",
 
75
  torch_dtype=dtype
76
  )
77
  model.eval()
@@ -79,10 +105,11 @@ class ModelManager:
79
  self.processors[model_name] = tokenizer
80
 
81
  elif model_name == MEDGEMMA_MODEL:
82
- processor = AutoProcessor.from_pretrained(model_name)
83
  model = AutoModelForImageTextToText.from_pretrained(
84
  model_name,
85
  trust_remote_code=True,
 
86
  torch_dtype=dtype
87
  )
88
  model.eval()
@@ -134,7 +161,7 @@ def run_ocr(input_image, input_file, model_choice, custom_prompt):
134
  model, processor_or_tokenizer = manager.get_model(model_choice)
135
  # Move to GPU only inside the decorated function
136
  print(f"Moving {model_choice} to GPU...")
137
- model.to("cuda")
138
  except Exception as e:
139
  return f"Помилка завантаження чи переміщення моделі: {str(e)}\nЯкщо це MedGemma, переконайтеся, що ви надали HF_TOKEN."
140
 
@@ -144,6 +171,12 @@ def run_ocr(input_image, input_file, model_choice, custom_prompt):
144
  all_results = []
145
 
146
  try:
 
 
 
 
 
 
147
  for i, img in enumerate(images_to_process):
148
  img = img.convert("RGB")
149
  try:
@@ -154,7 +187,7 @@ def run_ocr(input_image, input_file, model_choice, custom_prompt):
154
  tmp_path = tmp.name
155
 
156
  try:
157
- with torch.no_grad():
158
  res = model.infer(
159
  processor_or_tokenizer,
160
  prompt=custom_prompt if custom_prompt else "<image>\nFree OCR. ",
@@ -190,7 +223,7 @@ def run_ocr(input_image, input_file, model_choice, custom_prompt):
190
  return_tensors="pt"
191
  ).to("cuda") # Ensure inputs are on cuda
192
 
193
- with torch.no_grad():
194
  output = model.generate(**inputs, max_new_tokens=4096, do_sample=False)
195
 
196
  input_len = inputs["input_ids"].shape[-1]
 
19
  import fitz # PyMuPDF
20
  import io
21
  import gc
22
+ import threading
23
+ import contextlib
24
 
25
  try:
26
  from transformers.models.llama import modeling_llama as _modeling_llama
 
54
  DEEPSEEK_MODEL = 'deepseek-ai/DeepSeek-OCR-2'
55
  MEDGEMMA_MODEL = 'google/medgemma-1.5-4b-it'
56
 
57
+ _default_hf_home = "/data/.huggingface" if os.path.isdir("/data") else os.path.join(os.path.expanduser("~"), ".cache", "huggingface")
58
+ os.environ.setdefault("HF_HOME", _default_hf_home)
59
+ _hf_cache_dir = os.environ.get("HF_HUB_CACHE") or os.path.join(os.environ["HF_HOME"], "hub")
60
+ os.environ.setdefault("HF_HUB_CACHE", _hf_cache_dir)
61
+ os.environ.setdefault("TRANSFORMERS_CACHE", _hf_cache_dir)
62
+
63
+
64
+ def _warmup_hf_cache():
65
+ try:
66
+ from huggingface_hub import snapshot_download
67
+ except Exception as e:
68
+ print(f"Warmup cache failed: {e}")
69
+ return
70
+
71
+ for _repo_id in (DEEPSEEK_MODEL, MEDGEMMA_MODEL):
72
+ try:
73
+ snapshot_download(repo_id=_repo_id, cache_dir=_hf_cache_dir)
74
+ except Exception as e:
75
+ print(f"Warmup cache failed for {_repo_id}: {e}")
76
+
77
+
78
+ threading.Thread(target=_warmup_hf_cache, daemon=True).start()
79
+
80
  # --- Device Setup ---
81
  # For HF Spaces with ZeroGPU, we'll use cuda if available
82
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
91
  if model_name not in self.models:
92
  print(f"Loading {model_name} to CPU...")
93
  if model_name == DEEPSEEK_MODEL:
94
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, cache_dir=_hf_cache_dir)
95
  model = AutoModel.from_pretrained(
96
  model_name,
97
  trust_remote_code=True,
98
  use_safetensors=True,
99
  attn_implementation="eager",
100
+ cache_dir=_hf_cache_dir,
101
  torch_dtype=dtype
102
  )
103
  model.eval()
 
105
  self.processors[model_name] = tokenizer
106
 
107
  elif model_name == MEDGEMMA_MODEL:
108
+ processor = AutoProcessor.from_pretrained(model_name, cache_dir=_hf_cache_dir)
109
  model = AutoModelForImageTextToText.from_pretrained(
110
  model_name,
111
  trust_remote_code=True,
112
+ cache_dir=_hf_cache_dir,
113
  torch_dtype=dtype
114
  )
115
  model.eval()
 
161
  model, processor_or_tokenizer = manager.get_model(model_choice)
162
  # Move to GPU only inside the decorated function
163
  print(f"Moving {model_choice} to GPU...")
164
+ model.to(device="cuda", dtype=torch.float16)
165
  except Exception as e:
166
  return f"Помилка завантаження чи переміщення моделі: {str(e)}\nЯкщо це MedGemma, переконайтеся, що ви надали HF_TOKEN."
167
 
 
171
  all_results = []
172
 
173
  try:
174
+ _autocast_ctx = (
175
+ torch.autocast(device_type="cuda", dtype=torch.float16)
176
+ if torch.cuda.is_available()
177
+ else contextlib.nullcontext()
178
+ )
179
+
180
  for i, img in enumerate(images_to_process):
181
  img = img.convert("RGB")
182
  try:
 
187
  tmp_path = tmp.name
188
 
189
  try:
190
+ with torch.no_grad(), _autocast_ctx:
191
  res = model.infer(
192
  processor_or_tokenizer,
193
  prompt=custom_prompt if custom_prompt else "<image>\nFree OCR. ",
 
223
  return_tensors="pt"
224
  ).to("cuda") # Ensure inputs are on cuda
225
 
226
+ with torch.no_grad(), _autocast_ctx:
227
  output = model.generate(**inputs, max_new_tokens=4096, do_sample=False)
228
 
229
  input_len = inputs["input_ids"].shape[-1]