Geraldine committed on
Commit
28c56d5
·
verified ·
1 Parent(s): ccde86d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -2
app.py CHANGED
@@ -51,12 +51,12 @@ model_v = Qwen2_5_VLForConditionalGeneration.from_pretrained(
51
  ).to(device).eval()
52
 
53
  MODEL_ID_Y = "rednote-hilab/dots.ocr"
54
- processor_y = AutoProcessor.from_pretrained(MODEL_ID_V, trust_remote_code=True)
55
  model_y = AutoModelForCausalLM.from_pretrained(
56
  MODEL_ID_Y,
57
  attn_implementation="kernels-community/flash-attn2",
58
  trust_remote_code=True,
59
- torch_dtype=torch.float16
60
  ).to(device).eval()
61
 
62
  MODEL_ID_X = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
@@ -319,6 +319,22 @@ def calc_timeout_duration(*args, **kwargs):
319
  return 60
320
 
321
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
322
  @spaces.GPU(duration=calc_timeout_duration)
323
  def generate_image(model_name, text, image, max_new_tokens, temperature, top_p, top_k, repetition_penalty, gpu_timeout):
324
  try:
@@ -359,6 +375,7 @@ def generate_image(model_name, text, image, max_new_tokens, temperature, top_p,
359
  truncation=True,
360
  max_length=MAX_INPUT_TOKEN_LENGTH
361
  ).to(device)
 
362
 
363
  streamer = TextIteratorStreamer(
364
  processor.tokenizer if hasattr(processor, "tokenizer") else processor,
 
51
  ).to(device).eval()
52
 
53
  MODEL_ID_Y = "rednote-hilab/dots.ocr"
54
+ processor_y = AutoProcessor.from_pretrained(MODEL_ID_Y, trust_remote_code=True)
55
  model_y = AutoModelForCausalLM.from_pretrained(
56
  MODEL_ID_Y,
57
  attn_implementation="kernels-community/flash-attn2",
58
  trust_remote_code=True,
59
+ torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
60
  ).to(device).eval()
61
 
62
  MODEL_ID_X = "prithivMLmods/Qwen2-VL-OCR-2B-Instruct"
 
319
  return 60
320
 
321
 
322
+ def align_inputs_to_model_dtype(inputs, model):
323
+ model_dtype = getattr(model, "dtype", None)
324
+ if model_dtype is None:
325
+ try:
326
+ model_dtype = next(model.parameters()).dtype
327
+ except StopIteration:
328
+ model_dtype = None
329
+ if model_dtype is None:
330
+ return inputs
331
+
332
+ for key, value in list(inputs.items()):
333
+ if torch.is_tensor(value) and value.is_floating_point():
334
+ inputs[key] = value.to(dtype=model_dtype)
335
+ return inputs
336
+
337
+
338
  @spaces.GPU(duration=calc_timeout_duration)
339
  def generate_image(model_name, text, image, max_new_tokens, temperature, top_p, top_k, repetition_penalty, gpu_timeout):
340
  try:
 
375
  truncation=True,
376
  max_length=MAX_INPUT_TOKEN_LENGTH
377
  ).to(device)
378
+ inputs = align_inputs_to_model_dtype(inputs, model)
379
 
380
  streamer = TextIteratorStreamer(
381
  processor.tokenizer if hasattr(processor, "tokenizer") else processor,