Luis J Camargo committed
Commit 0c82e96 · 1 Parent(s): cbab00e

new setts2

Files changed (1): app.py (+40 -61)
app.py CHANGED
@@ -22,7 +22,6 @@ logging.basicConfig(
 logger = logging.getLogger("TachiwinOCR")
 
 DEVICE = "cpu"
-# Speed up CPU inference
 torch.set_num_threads(os.cpu_count() or 4)
 
 PROMPTS = {
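
The retained line above still deserves a note: torch.set_num_threads(os.cpu_count() or 4) pins PyTorch's intra-op thread pool to the visible core count, with 4 as a fallback because os.cpu_count() can return None. A minimal standalone sketch of the same guard:

import os
import torch

# os.cpu_count() may return None in some containers; fall back to 4.
torch.set_num_threads(os.cpu_count() or 4)
print(torch.get_num_threads())  # verify the intra-op thread pool size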
@@ -82,45 +81,50 @@ class OCRModelManager(object):
         try:
             img_path = args[0]
             task = kwargs.get("task", "ocr")
-            min_new_tokens = kwargs.get("min_new_tokens", 3)
-            max_new_tokens = kwargs.get("max_new_tokens", 1024)
+            min_new_tokens = kwargs.get("min_new_tokens", 1)
+            max_new_tokens = kwargs.get("max_new_tokens", 128)
             temperature = kwargs.get("temperature", 0.2)
-
+            min_p = kwargs.get("min_p", 0.1)
             logger.info(f"--- Starting inference process ---")
             logger.info(f"Task: {task}, Min New Tokens: {min_new_tokens}, Temperature: {temperature}")
 
             image = Image.open(img_path).convert("RGB")
 
             messages = [
                 {"role": "user",
                  "content": [
-                    {"type": "image", "image": image},
+                    {"type": "image"},
                     {"type": "text", "text": PROMPTS[task]},
                 ]
                 }
             ]
 
-            inputs = processor.apply_chat_template(
-                messages,
-                tokenize=True,
-                add_generation_prompt=True,
-                return_dict=True,
-                return_tensors="pt"
+            text_prompt = processor.tokenizer.apply_chat_template(
+                messages,
+                tokenize=False,
+                add_generation_prompt=True
+            )
+
+            logger.info(f"Text prompt: {text_prompt}")
+
+            inputs = processor(
+                image,
+                text_prompt,
+                add_special_tokens=False,
+                return_tensors="pt",
             ).to(DEVICE)
 
             logger.info(f"Inputs prepared (shape: {inputs['input_ids'].shape}). Running model.generate...")
-            with torch.inference_mode():
-                # Restoring sampling params as requested
-                # use_cache=False as requested because it's known to be unstable on some setups
-                outputs = model.generate(
-                    **inputs,
-                    max_new_tokens=max_new_tokens,
-                    min_new_tokens=min_new_tokens,
-                    use_cache=False,
-                    do_sample=True,
-                    temperature=max(temperature, 0.01),
-                    min_p=0.1,
-                )
+            logger.info(inputs)
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=max_new_tokens,
+                min_new_tokens=min_new_tokens,
+                use_cache=False,
+                do_sample=True,
+                temperature=temperature,
+                min_p=min_p,
+            )
 
             logger.info("Generation complete. Decoding results...")
             decoded_outputs = processor.batch_decode(outputs, skip_special_tokens=True)[0]
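
The new path renders the chat template to a string first (note the image entry is now a bare {"type": "image"} placeholder rather than embedding the PIL object) and then hands text and image to the processor together. A minimal sketch of that two-step flow, assuming the processor from this repo and a hypothetical local file sample.png:

from PIL import Image
from transformers import AutoProcessor

processor = AutoProcessor.from_pretrained(
    "tachiwin/PaddleOCR-VL-Tachiwin-BF16", trust_remote_code=True
)
image = Image.open("sample.png").convert("RGB")  # hypothetical input file

messages = [
    {"role": "user",
     "content": [
         {"type": "image"},  # placeholder; the actual image goes to the processor below
         {"type": "text", "text": "OCR:"},  # stand-in for PROMPTS[task]
     ]}
]

# Step 1: render the template to a plain string instead of tokenizing it directly.
text_prompt = processor.tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

# Step 2: let the processor pair the image with the rendered prompt.
inputs = processor(image, text_prompt, add_special_tokens=False, return_tensors="pt")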
@@ -138,23 +142,12 @@ def create_model():
     model_path = "tachiwin/PaddleOCR-VL-Tachiwin-BF16"
     logger.info(f"Loading model and processor from {model_path}...")
 
-    # Use bfloat16 for CPU if supported, else float32
-    # Hugging Face spaces CPUs often support bfloat16
-    try:
-        model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            trust_remote_code=True,
-            torch_dtype=torch.bfloat16
-        ).to(DEVICE).eval()
-        logger.info(f"Model loaded on {DEVICE} with bfloat16")
-    except Exception as e:
-        logger.warning(f"Failed to load in bfloat16, falling back to float32: {e}")
-        model = AutoModelForCausalLM.from_pretrained(
-            model_path,
-            trust_remote_code=True,
-            torch_dtype=torch.float32
-        ).to(DEVICE).eval()
-        logger.info(f"Model loaded on {DEVICE} with float32")
+    model = AutoModelForCausalLM.from_pretrained(
+        model_path,
+        trust_remote_code=True,
+        torch_dtype=torch.bfloat16
+    ).to(DEVICE).eval()
+    logger.info(f"Model loaded on {DEVICE} with bfloat16")
 
     processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
     logger.info(f"Processor loaded successfully.")
@@ -184,26 +177,12 @@ def inference(img):
         return model_manager.infer(
             img,
             task="ocr",
-            min_new_tokens=3,
-            max_new_tokens=1024,
+            min_new_tokens=1,
+            max_new_tokens=128,
+            temperature=1.5,
+            min_p=0.1,
         )
 
-        # # Now extract text from the serialized structure
-        # extracted_texts = []
-
-        # for page in serialized_result:
-        #     if isinstance(page, dict) and 'parsing_res_list' in page:
-        #         for block in page['parsing_res_list']:
-        #             if isinstance(block, dict) and 'content' in block and block['content']:
-        #                 extracted_texts.append(block['content'])
-
-        # if not extracted_texts:
-        #     # return json as string
-        #     return json.dumps(serialized_result, indent=4)
-
-        # # Join all text blocks with double newlines
-        # return "\n\n".join(extracted_texts)
-
     except Exception as e:
         import traceback
         error_detail = traceback.format_exc()
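
The new call site raises temperature to 1.5 and leans on min-p to prune the tail: a token survives only if its probability is at least min_p times the top token's probability, which keeps high-temperature sampling from emitting junk. A standalone illustration of the filtering rule (plain PyTorch, not the transformers internals):

import torch

def min_p_filter(logits, temperature=1.5, min_p=0.1):
    # Apply temperature, then mask tokens below min_p * top probability.
    probs = torch.softmax(logits / temperature, dim=-1)
    threshold = min_p * probs.max(dim=-1, keepdim=True).values
    kept = torch.where(probs >= threshold, probs, torch.zeros_like(probs))
    return kept / kept.sum(dim=-1, keepdim=True)  # renormalize survivors

print(min_p_filter(torch.tensor([[3.0, 1.0, -2.0, -6.0]])))  # tail is zeroed out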
 