Luis J Camargo committed on
Commit
2886d21
·
1 Parent(s): 0c82e96

refactor 2

Browse files
Files changed (1) hide show
  1. app.py +104 -161
app.py CHANGED
@@ -1,19 +1,18 @@
1
  import os
2
  import torch
3
- from transformers import AutoModelForCausalLM, AutoProcessor
4
  from PIL import Image
5
  import gradio as gr
6
- from queue import Queue, Empty
7
- from threading import Event, Thread
8
- import atexit
9
-
10
- CONCURRENCY_LIMIT = 1
11
-
12
-
13
  import logging
14
  import sys
15
 
16
- # Configure logging to sys.stderr which is often more reliable in containerized environments
 
 
 
 
 
17
  logging.basicConfig(
18
  level=logging.INFO,
19
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
@@ -21,7 +20,7 @@ logging.basicConfig(
21
  )
22
  logger = logging.getLogger("TachiwinOCR")
23
 
24
- DEVICE = "cpu"
25
  torch.set_num_threads(os.cpu_count() or 4)
26
 
27
  PROMPTS = {
@@ -31,165 +30,106 @@ PROMPTS = {
31
  "chart": "Chart Recognition:",
32
  }
33
 
34
- class OCRModelManager(object):
35
- def __init__(self, num_workers, model_factory):
36
- super().__init__()
37
- self._model_factory = model_factory
38
- self._queue = Queue()
39
- self._workers = []
40
- self._model_initialized_event = Event()
41
- for _ in range(num_workers):
42
- worker = Thread(target=self._worker, daemon=True)
43
- worker.start()
44
- self._model_initialized_event.wait()
45
- self._model_initialized_event.clear()
46
- self._workers.append(worker)
47
-
48
- def infer(self, *args, **kwargs):
49
- result_queue = Queue(maxsize=1)
50
- self._queue.put((args, kwargs, result_queue))
51
-
52
- # Increased timeout to 20 minutes for CPU inference
53
- timeout = 1200
54
- try:
55
- success, payload = result_queue.get(timeout=timeout)
56
- if success:
57
- return payload
58
- else:
59
- raise payload
60
- except Empty:
61
- # Check if workers are still alive
62
- alive = any(w.is_alive() for w in self._workers)
63
- if not alive:
64
- raise RuntimeError("OCR workers have crashed.")
65
- raise RuntimeError(f"OCR inference timed out after {timeout} seconds.")
66
-
67
- def close(self):
68
- for _ in self._workers:
69
- self._queue.put(None)
70
- for worker in self._workers:
71
- worker.join()
72
-
73
- def _worker(self):
74
- model, processor = self._model_factory()
75
- self._model_initialized_event.set()
76
- while True:
77
- item = self._queue.get()
78
- if item is None:
79
- break
80
- args, kwargs, result_queue = item
81
- try:
82
- img_path = args[0]
83
- task = kwargs.get("task", "ocr")
84
- min_new_tokens = kwargs.get("min_new_tokens", 1)
85
- max_new_tokens = kwargs.get("max_new_tokens", 128)
86
- temperature = kwargs.get("temperature", 0.2)
87
- min_p = kwargs.get("min_p", 0.1)
88
- logger.info(f"--- Starting inference process ---")
89
- logger.info(f"Task: {task}, Min New Tokens: {min_new_tokens}, Temperature: {temperature}")
90
-
91
- image = Image.open(img_path).convert("RGB")
92
 
93
- messages = [
94
- {"role": "user",
95
- "content": [
96
- {"type": "image"},
97
- {"type": "text", "text": PROMPTS[task]},
98
- ]
99
- }
100
- ]
101
-
102
- text_prompt = processor.tokenizer.apply_chat_template(
103
- messages,
104
- tokenize=False,
105
- add_generation_prompt=True
106
- )
107
-
108
- logger.info(f"Text prompt: {text_prompt}")
109
-
110
- inputs = processor(
111
- image,
112
- text_prompt,
113
- add_special_tokens=False,
114
- return_tensors="pt",
115
- ).to(DEVICE)
116
-
117
- logger.info(f"Inputs prepared (shape: {inputs['input_ids'].shape}). Running model.generate...")
118
- logger.info(inputs)
119
- outputs = model.generate(
120
- **inputs,
121
- max_new_tokens=max_new_tokens,
122
- min_new_tokens=min_new_tokens,
123
- use_cache=False,
124
- do_sample=True,
125
- temperature=temperature,
126
- min_p=min_p,
127
- )
128
-
129
- logger.info("Generation complete. Decoding results...")
130
- decoded_outputs = processor.batch_decode(outputs, skip_special_tokens=True)[0]
131
- logger.info(f"Inference finished successfully.")
132
-
133
- result_queue.put((True, decoded_outputs))
134
- except Exception as e:
135
- result_queue.put((False, e))
136
- finally:
137
- self._queue.task_done()
138
-
139
-
140
- def create_model():
141
- """Initialize PaddleOCR-VL with the fine-tuned Tachiwin model using transformers"""
142
- model_path = "tachiwin/PaddleOCR-VL-Tachiwin-BF16"
143
- logger.info(f"Loading model and processor from {model_path}...")
144
 
 
145
  model = AutoModelForCausalLM.from_pretrained(
146
  model_path,
147
  trust_remote_code=True,
148
- torch_dtype=torch.bfloat16
149
  ).to(DEVICE).eval()
150
- logger.info(f"Model loaded on {DEVICE} with bfloat16")
151
-
152
- processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
153
- logger.info(f"Processor loaded successfully.")
154
- return model, processor
155
-
156
-
157
- # Initialize model manager with 1 worker to save memory on CPU space
158
- logger.info("Initializing Tachiwin Indigenous Languages OCR model manager...")
159
- model_manager = OCRModelManager(1, create_model)
160
- logger.info("Model manager is ready and listening for tasks!")
161
-
162
-
163
- def close_model_manager():
164
- model_manager.close()
165
-
166
-
167
- atexit.register(close_model_manager)
168
-
169
 
170
  def inference(img):
171
- """Process image with OCR and return extracted text in markdown format"""
 
 
172
  if img is None:
173
- return "Please upload an image."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
 
175
- gr.Info("Inference started. On CPU, this may take 2-10 minutes depending on image complexity.")
176
  try:
177
- return model_manager.infer(
178
- img,
179
- task="ocr",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  min_new_tokens=1,
181
- max_new_tokens=128,
182
- temperature=1.5,
183
  min_p=0.1,
 
184
  )
185
-
 
 
 
 
 
 
 
 
 
 
 
 
186
  except Exception as e:
187
  import traceback
188
  error_detail = traceback.format_exc()
189
- print(e)
190
- return f"Error during OCR processing:\n\n```\n{error_detail}\n```"
191
 
192
 
 
 
193
  title = '🌎 Tachiwin OCR for the Indigenous Languages of Mexico'
194
 
195
  description = '''
@@ -198,10 +138,8 @@ description = '''
198
  This model represents a **world first in tech access and linguistic rights**, specifically trained to recognize
199
  the diverse character and glyph repertoire of Mexico's 68 indigenous languages.
200
 
201
- **How to use:** Simply upload an image containing text in any Mexican indigenous language, and the model will
202
- detect and recognize the text.
203
-
204
- ### Warning: as this free demonstrator space uses only CPU, a small image could take up to 5 minutes, so be patient.
205
 
206
  🔗 [PaddleOCR Documentation](https://github.com/PaddlePaddle/PaddleOCR)
207
  '''
@@ -230,12 +168,14 @@ example_labels = """
230
 
231
  css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;} .output_markdown {min-height: 30rem !important;}"
232
 
233
- gr.Interface(
234
- inference,
235
- [
236
- gr.Image(type='filepath', label='Input'),
237
- ],
238
- gr.Markdown(label='Output', elem_classes="output_markdown"),
 
 
239
  title=title,
240
  description=description,
241
  examples=examples,
@@ -265,4 +205,7 @@ gr.Interface(
265
 
266
  Made with ❤️ for linguistic diversity and indigenous rights
267
  """
268
- ).launch(debug=True)
 
 
 
 
1
  import os
2
  import torch
3
+ from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer
4
  from PIL import Image
5
  import gradio as gr
6
+ from threading import Thread
 
 
 
 
 
 
7
  import logging
8
  import sys
9
 
10
+ # --- Configuration ---
11
+ CONCURRENCY_LIMIT = 1
12
+ DEVICE = "cpu"
13
+ DTYPE = torch.float32
14
+
15
+ # Configure logging
16
  logging.basicConfig(
17
  level=logging.INFO,
18
  format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
 
20
  )
21
  logger = logging.getLogger("TachiwinOCR")
22
 
23
+ # Set CPU threads
24
  torch.set_num_threads(os.cpu_count() or 4)
25
 
26
  PROMPTS = {
 
30
  "chart": "Chart Recognition:",
31
  }
32
 
33
+ # --- Global Model Loading ---
34
+ # We load the model globally so it persists across requests.
35
+ # No need for a custom Manager class.
36
+ model_path = "tachiwin/PaddleOCR-VL-Tachiwin-BF16"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
+ try:
39
+ logger.info(f"Loading processor from {model_path}...")
40
+ processor = AutoProcessor.from_pretrained(model_path, trust_remote_code=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ logger.info(f"Loading model from {model_path}...")
43
  model = AutoModelForCausalLM.from_pretrained(
44
  model_path,
45
  trust_remote_code=True,
46
+ torch_dtype=DTYPE
47
  ).to(DEVICE).eval()
48
+ logger.info("Model loaded successfully.")
49
+ except Exception as e:
50
+ logger.error(f"Failed to load model: {e}")
51
+ raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
  def inference(img):
54
+ """
55
+ Process an image with OCR and stream the extracted text.
56
+ """
57
  if img is None:
58
+ yield "Please upload an image."
59
+ return
60
+
61
+ # Basic cleanup
62
+ if isinstance(img, str):
63
+ image = Image.open(img).convert("RGB")
64
+ else:
65
+ image = Image.fromarray(img).convert("RGB")
66
+
67
+ task = "ocr"
68
+
69
+ # Prepare inputs
70
+ messages = [
71
+ {"role": "user",
72
+ "content": [
73
+ {"type": "image"},
74
+ {"type": "text", "text": PROMPTS[task]},
75
+ ]
76
+ }
77
+ ]
78
 
 
79
  try:
80
+ text_prompt = processor.tokenizer.apply_chat_template(
81
+ messages,
82
+ tokenize=False,
83
+ add_generation_prompt=True
84
+ )
85
+
86
+ inputs = processor(
87
+ image,
88
+ text_prompt,
89
+ add_special_tokens=False,
90
+ return_tensors="pt",
91
+ ).to(DEVICE)
92
+
93
+ # Initialize Streamer
94
+ streamer = TextIteratorStreamer(
95
+ processor.tokenizer,
96
+ skip_prompt=True,
97
+ skip_special_tokens=True
98
+ )
99
+
100
+ # Generation Arguments
101
+ generation_kwargs = dict(
102
+ **inputs,
103
+ streamer=streamer,
104
+ max_new_tokens=256, # Doubled from the previous 128
105
  min_new_tokens=1,
106
+ do_sample=True,
107
+ temperature=1.5, # Unchanged from the previous version (1.5)
108
  min_p=0.1,
109
+ use_cache=False # NOTE: KV cache is disabled here, although enabling it would normally speed up CPU generation
110
  )
111
+
112
+ # Threading is REQUIRED for streaming
113
+ # The model generates in a separate thread, while the main thread
114
+ # yields from the streamer iterator.
115
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
116
+ thread.start()
117
+
118
+ generated_text = ""
119
+ for new_text in streamer:
120
+ generated_text += new_text
121
+ # Yielding here updates the Gradio textbox in real-time
122
+ yield generated_text
123
+
124
  except Exception as e:
125
  import traceback
126
  error_detail = traceback.format_exc()
127
+ logger.error(error_detail)
128
+ yield f"Error during OCR processing:\n\n```\n{error_detail}\n```"
129
 
130
 
131
+ # --- Interface Setup ---
132
+
133
  title = '🌎 Tachiwin OCR for the Indigenous Languages of Mexico'
134
 
135
  description = '''
 
138
  This model represents a **world first in tech access and linguistic rights**, specifically trained to recognize
139
  the diverse character and glyph repertoire of Mexico's 68 indigenous languages.
140
 
141
+ **How to use:** Simply upload an image containing text in any Mexican indigenous language.
142
+ **Note:** Running on CPU. Streaming is enabled so you can see progress immediately.
 
 
143
 
144
  🔗 [PaddleOCR Documentation](https://github.com/PaddlePaddle/PaddleOCR)
145
  '''
 
168
 
169
  css = ".output_image, .input_image {height: 40rem !important; width: 100% !important;} .output_markdown {min-height: 30rem !important;}"
170
 
171
+ # Note: gr.Interface is kept (not replaced by gr.Blocks); in newer Gradio versions it
172
+ # supports generator functions, so inference() can stream partial output as it yields.
173
+ # Just ensuring concurrency_limit is set.
174
+
175
+ demo = gr.Interface(
176
+ fn=inference,
177
+ inputs=gr.Image(type='filepath', label='Input'),
178
+ outputs=gr.Markdown(label='Output', elem_classes="output_markdown"),
179
  title=title,
180
  description=description,
181
  examples=examples,
 
205
 
206
  Made with ❤️ for linguistic diversity and indigenous rights
207
  """
208
+ )
209
+
210
+ if __name__ == "__main__":
211
+ demo.queue().launch()