Luis J Camargo commited on
Commit
cbab00e
·
1 Parent(s): db7023c
Files changed (1) hide show
  1. app.py +54 -17
app.py CHANGED
@@ -1,15 +1,30 @@
 
1
  import torch
2
  from transformers import AutoModelForCausalLM, AutoProcessor
3
  from PIL import Image
4
  import gradio as gr
5
- from queue import Queue
6
  from threading import Event, Thread
7
  import atexit
8
 
9
  CONCURRENCY_LIMIT = 1
10
 
11
 
 
 
 
 
 
 
 
 
 
 
 
12
  DEVICE = "cpu"
 
 
 
13
  PROMPTS = {
14
  "ocr": "OCR:",
15
  "table": "Table Recognition:",
@@ -34,11 +49,21 @@ class OCRModelManager(object):
34
  def infer(self, *args, **kwargs):
35
  result_queue = Queue(maxsize=1)
36
  self._queue.put((args, kwargs, result_queue))
37
- success, payload = result_queue.get()
38
- if success:
39
- return payload
40
- else:
41
- raise payload
 
 
 
 
 
 
 
 
 
 
42
 
43
  def close(self):
44
  for _ in self._workers:
@@ -58,9 +83,12 @@ class OCRModelManager(object):
58
  img_path = args[0]
59
  task = kwargs.get("task", "ocr")
60
  min_new_tokens = kwargs.get("min_new_tokens", 3)
61
- #max_new_tokens = kwargs.get("max_new_tokens", 2048)
62
  temperature = kwargs.get("temperature", 0.2)
63
 
 
 
 
64
  image = Image.open(img_path).convert("RGB")
65
 
66
  messages = [
@@ -80,18 +108,23 @@ class OCRModelManager(object):
80
  return_tensors="pt"
81
  ).to(DEVICE)
82
 
83
- with torch.no_grad():
 
 
 
84
  outputs = model.generate(
85
  **inputs,
86
- #max_new_tokens=max_new_tokens,
87
  min_new_tokens=min_new_tokens,
88
- use_cache=False,
89
- do_sample=False,
 
90
  min_p=0.1,
91
- temperature=temperature if temperature > 0 else 1.0,
92
  )
93
 
 
94
  decoded_outputs = processor.batch_decode(outputs, skip_special_tokens=True)[0]
 
95
 
96
  result_queue.put((True, decoded_outputs))
97
  except Exception as e:
@@ -103,7 +136,7 @@ class OCRModelManager(object):
103
  def create_model():
104
  """Initialize PaddleOCR-VL with the fine-tuned Tachiwin model using transformers"""
105
  model_path = "tachiwin/PaddleOCR-VL-Tachiwin-BF16"
106
- print(f"Loading model and processor from {model_path}...")
107
 
108
  # Use bfloat16 for CPU if supported, else float32
109
  # Hugging Face spaces CPUs often support bfloat16
@@ -113,22 +146,25 @@ def create_model():
113
  trust_remote_code=True,
114
  torch_dtype=torch.bfloat16
115
  ).to(DEVICE).eval()
 
116
  except Exception as e:
117
- print(f"Failed to load in bfloat16, falling back to float32: {e}")
118
  model = AutoModelForCausalLM.from_pretrained(
119
  model_path,
120
  trust_remote_code=True,
121
  torch_dtype=torch.float32
122
  ).to(DEVICE).eval()
 
123
 
124
  processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
 
125
  return model, processor
126
 
127
 
128
  # Initialize model manager with 1 worker to save memory on CPU space
129
- print("Initializing Tachiwin Indigenous Languages OCR...")
130
  model_manager = OCRModelManager(1, create_model)
131
- print("Model ready!")
132
 
133
 
134
  def close_model_manager():
@@ -143,12 +179,13 @@ def inference(img):
143
  if img is None:
144
  return "Please upload an image."
145
 
 
146
  try:
147
  return model_manager.infer(
148
  img,
149
  task="ocr",
150
  min_new_tokens=3,
151
- temperature=1.5,
152
  )
153
 
154
  # # Now extract text from the serialized structure
 
1
+ import os
2
  import torch
3
  from transformers import AutoModelForCausalLM, AutoProcessor
4
  from PIL import Image
5
  import gradio as gr
6
+ from queue import Queue, Empty
7
  from threading import Event, Thread
8
  import atexit
9
 
10
  CONCURRENCY_LIMIT = 1
11
 
12
 
13
+ import logging
14
+ import sys
15
+
16
+ # Configure logging to sys.stderr which is often more reliable in containerized environments
17
+ logging.basicConfig(
18
+ level=logging.INFO,
19
+ format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
20
+ handlers=[logging.StreamHandler(sys.stderr)]
21
+ )
22
+ logger = logging.getLogger("TachiwinOCR")
23
+
24
  DEVICE = "cpu"
25
+ # Speed up CPU inference
26
+ torch.set_num_threads(os.cpu_count() or 4)
27
+
28
  PROMPTS = {
29
  "ocr": "OCR:",
30
  "table": "Table Recognition:",
 
49
  def infer(self, *args, **kwargs):
50
  result_queue = Queue(maxsize=1)
51
  self._queue.put((args, kwargs, result_queue))
52
+
53
+ # Increased timeout to 20 minutes for CPU inference
54
+ timeout = 1200
55
+ try:
56
+ success, payload = result_queue.get(timeout=timeout)
57
+ if success:
58
+ return payload
59
+ else:
60
+ raise payload
61
+ except Empty:
62
+ # Check if workers are still alive
63
+ alive = any(w.is_alive() for w in self._workers)
64
+ if not alive:
65
+ raise RuntimeError("OCR workers have crashed.")
66
+ raise RuntimeError(f"OCR inference timed out after {timeout} seconds.")
67
 
68
  def close(self):
69
  for _ in self._workers:
 
83
  img_path = args[0]
84
  task = kwargs.get("task", "ocr")
85
  min_new_tokens = kwargs.get("min_new_tokens", 3)
86
+ max_new_tokens = kwargs.get("max_new_tokens", 1024)
87
  temperature = kwargs.get("temperature", 0.2)
88
 
89
+ logger.info(f"--- Starting inference process ---")
90
+ logger.info(f"Task: {task}, Min New Tokens: {min_new_tokens}, Temperature: {temperature}")
91
+
92
  image = Image.open(img_path).convert("RGB")
93
 
94
  messages = [
 
108
  return_tensors="pt"
109
  ).to(DEVICE)
110
 
111
+ logger.info(f"Inputs prepared (shape: {inputs['input_ids'].shape}). Running model.generate...")
112
+ with torch.inference_mode():
113
+ # Restoring sampling params as requested
114
+ # use_cache=False as requested because it's known to be unstable on some setups
115
  outputs = model.generate(
116
  **inputs,
117
+ max_new_tokens=max_new_tokens,
118
  min_new_tokens=min_new_tokens,
119
+ use_cache=False,
120
+ do_sample=True,
121
+ temperature=max(temperature, 0.01),
122
  min_p=0.1,
 
123
  )
124
 
125
+ logger.info("Generation complete. Decoding results...")
126
  decoded_outputs = processor.batch_decode(outputs, skip_special_tokens=True)[0]
127
+ logger.info(f"Inference finished successfully.")
128
 
129
  result_queue.put((True, decoded_outputs))
130
  except Exception as e:
 
136
  def create_model():
137
  """Initialize PaddleOCR-VL with the fine-tuned Tachiwin model using transformers"""
138
  model_path = "tachiwin/PaddleOCR-VL-Tachiwin-BF16"
139
+ logger.info(f"Loading model and processor from {model_path}...")
140
 
141
  # Use bfloat16 for CPU if supported, else float32
142
  # Hugging Face spaces CPUs often support bfloat16
 
146
  trust_remote_code=True,
147
  torch_dtype=torch.bfloat16
148
  ).to(DEVICE).eval()
149
+ logger.info(f"Model loaded on {DEVICE} with bfloat16")
150
  except Exception as e:
151
+ logger.warning(f"Failed to load in bfloat16, falling back to float32: {e}")
152
  model = AutoModelForCausalLM.from_pretrained(
153
  model_path,
154
  trust_remote_code=True,
155
  torch_dtype=torch.float32
156
  ).to(DEVICE).eval()
157
+ logger.info(f"Model loaded on {DEVICE} with float32")
158
 
159
  processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
160
+ logger.info(f"Processor loaded successfully.")
161
  return model, processor
162
 
163
 
164
  # Initialize model manager with 1 worker to save memory on CPU space
165
+ logger.info("Initializing Tachiwin Indigenous Languages OCR model manager...")
166
  model_manager = OCRModelManager(1, create_model)
167
+ logger.info("Model manager is ready and listening for tasks!")
168
 
169
 
170
  def close_model_manager():
 
179
  if img is None:
180
  return "Please upload an image."
181
 
182
+ gr.Info("Inference started. On CPU, this may take 2-10 minutes depending on image complexity.")
183
  try:
184
  return model_manager.infer(
185
  img,
186
  task="ocr",
187
  min_new_tokens=3,
188
+ max_new_tokens=1024,
189
  )
190
 
191
  # # Now extract text from the serialized structure