aal-hawa committed on
Commit
f328abe
·
1 Parent(s): 6aefbe1
Files changed (1) hide show
  1. app.py +14 -10
app.py CHANGED
@@ -79,18 +79,22 @@ def ocr_process(image):
79
  )
80
 
81
  # The processor outputs bfloat16 tensors, but model is float32.
82
- # Convert all floating-point input tensors to float32.
83
- for key in inputs:
84
- if isinstance(inputs[key], torch.Tensor):
85
- if inputs[key].is_floating_point():
86
- inputs[key] = inputs[key].float()
87
-
88
- inputs = inputs.to("cpu")
 
 
 
 
89
 
90
  with torch.no_grad():
91
- generated_ids = model.generate(**inputs, max_new_tokens=16384, do_sample=False)
92
 
93
- input_ids = inputs["input_ids"]
94
  generated_ids_trimmed = [
95
  out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)
96
  ]
@@ -124,4 +128,4 @@ with gr.Blocks(title="HunyuanOCR") as demo:
124
  image_input.change(ocr_process, image_input, ocr_output)
125
 
126
  if __name__ == "__main__":
127
- demo.launch(server_name="0.0.0.0")
 
79
  )
80
 
81
  # The processor outputs bfloat16 tensors, but model is float32.
82
+ # BatchFeature doesn't support in-place modification well,
83
+ # so rebuild as a plain dict with float32 tensors.
84
+ clean_inputs = {}
85
+ for k, v in inputs.items():
86
+ if isinstance(v, torch.Tensor):
87
+ if v.dtype == torch.bfloat16:
88
+ clean_inputs[k] = v.to(torch.float32)
89
+ else:
90
+ clean_inputs[k] = v
91
+ else:
92
+ clean_inputs[k] = v
93
 
94
  with torch.no_grad():
95
+ generated_ids = model.generate(**clean_inputs, max_new_tokens=16384, do_sample=False)
96
 
97
+ input_ids = clean_inputs["input_ids"]
98
  generated_ids_trimmed = [
99
  out_ids[len(in_ids):] for in_ids, out_ids in zip(input_ids, generated_ids)
100
  ]
 
128
  image_input.change(ocr_process, image_input, ocr_output)
129
 
130
  if __name__ == "__main__":
131
+ demo.launch(server_name="0.0.0.0")