longjava2024 committed on
Commit
1edb9b1
·
verified ·
1 Parent(s): ddd807c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -6
app.py CHANGED
@@ -28,8 +28,9 @@ app = FastAPI(title="CCCD OCR with Vintern-1B-v2")
28
 
29
 
30
  MODEL_NAME = "5CD-AI/Vintern-1B-v2"
31
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
32
- DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32
 
33
 
34
 
35
  print(f"Loading model `{MODEL_NAME}` on {DEVICE} ...")
@@ -237,10 +238,7 @@ def normalize_base64(image_base64: str) -> str:
237
 
238
  def ocr_by_llm(image_base64: str, prompt: str) -> str:
239
  pixel_values = load_image_from_base64(image_base64, max_num=6)
240
- if DEVICE == "cuda":
241
- pixel_values = pixel_values.to(dtype=torch.bfloat16, device=DEVICE)
242
- else:
243
- pixel_values = pixel_values.to(dtype=torch.float32, device=DEVICE)
244
  with torch.no_grad():
245
  response_message = model.chat(
246
  tokenizer,
 
28
 
29
 
30
  MODEL_NAME = "5CD-AI/Vintern-1B-v2"
31
+ # Force CPU-only to avoid NVIDIA driver / CUDA issues on Spaces
32
+ DEVICE = "cpu"
33
+ DTYPE = torch.float32
34
 
35
 
36
  print(f"Loading model `{MODEL_NAME}` on {DEVICE} ...")
 
238
 
239
  def ocr_by_llm(image_base64: str, prompt: str) -> str:
240
  pixel_values = load_image_from_base64(image_base64, max_num=6)
241
+ pixel_values = pixel_values.to(dtype=torch.float32, device=DEVICE)
 
 
 
242
  with torch.no_grad():
243
  response_message = model.chat(
244
  tokenizer,