LLDDWW Claude committed on
Commit
6a2327c
·
1 Parent(s): e94e117

chore: switch OCR model to TrOCR

Browse files

- Replace Qwen2-VL with microsoft/trocr-large-printed for OCR
- Update model loading and inference code for TrOCR architecture
- Simplify OCR processing logic

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +15 -40
app.py CHANGED
@@ -6,26 +6,24 @@ import gradio as gr
6
  import spaces
7
  import torch
8
  from PIL import Image
9
- from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoTokenizer, AutoModelForCausalLM
10
 
11
- # Stage 1: OCR 모델 (Qwen2-VL로 문서에서 텍스트 추출)
12
- OCR_MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"
13
 
14
  # Stage 2: LLM 모델 (텍스트에서 약 이름 추출)
15
  LLM_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
16
 
17
 
18
  def _load_ocr_model():
19
- """Qwen2-VL OCR 모델 로드"""
20
- model = Qwen2VLForConditionalGeneration.from_pretrained(
21
  OCR_MODEL_ID,
22
  device_map="auto",
23
- load_in_8bit=True,
24
  torch_dtype=torch.float16,
25
- trust_remote_code=True,
26
  )
27
 
28
- processor = AutoProcessor.from_pretrained(OCR_MODEL_ID, trust_remote_code=True)
29
  return model, processor
30
 
31
 
@@ -43,9 +41,9 @@ def _load_llm_model():
43
  return model, tokenizer
44
 
45
 
46
- print("🔄 Loading Qwen2-VL OCR model...")
47
  OCR_MODEL, OCR_PROCESSOR = _load_ocr_model()
48
- print("✅ OCR model loaded!")
49
 
50
  print("🔄 Loading Qwen2.5-7B-Instruct...")
51
  LLM_MODEL, LLM_TOKENIZER = _load_llm_model()
@@ -70,39 +68,16 @@ def _extract_json_block(text: str) -> Optional[str]:
70
 
71
 
72
  def extract_text_from_image(image: Image.Image) -> str:
73
- """Stage 1: Qwen2-VL로 이미지에서 텍스트 추출 (OCR)"""
74
  try:
75
- messages = [
76
- {
77
- "role": "user",
78
- "content": [
79
- {"type": "text", "text": "이 이미지의 모든 텍스트를 정확히 추출해서 그대로 출력해주세요. OCR 결과만 출력하세요."},
80
- {"type": "image"},
81
- ],
82
- }
83
- ]
84
-
85
- chat_text = OCR_PROCESSOR.apply_chat_template(messages, add_generation_prompt=True)
86
- inputs = OCR_PROCESSOR(text=[chat_text], images=[image], return_tensors="pt").to(OCR_MODEL.device)
87
 
88
  with torch.no_grad():
89
- output_ids = OCR_MODEL.generate(
90
- **inputs,
91
- max_new_tokens=1024,
92
- temperature=0.1, # 정확한 OCR을 위해 낮은 temperature
93
- do_sample=False, # 결정적 출력
94
- )
95
-
96
- output_text = OCR_PROCESSOR.batch_decode(output_ids, skip_special_tokens=False)[0]
97
-
98
- # Extract assistant response
99
- if "<|im_start|>assistant" in output_text:
100
- extracted_text = output_text.split("<|im_start|>assistant")[-1]
101
- extracted_text = extracted_text.replace("<|im_end|>", "").strip()
102
- else:
103
- extracted_text = output_text.strip()
104
 
105
- return extracted_text
 
106
 
107
  except Exception as e:
108
  raise Exception(f"OCR 오류: {str(e)}")
@@ -329,7 +304,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
329
  ---
330
 
331
  **โ„น๏ธ 2๋‹จ๊ณ„ ํŒŒ์ดํ”„๋ผ์ธ**
332
- - **Stage 1**: Qwen2-VL 7B (OCR) - 이미지에서 모든 텍스트 추출
333
  - **Stage 2**: Qwen2.5 7B (LLM) - 추출된 텍스트에서 약 이름만 식별
334
 
335
  ์‹ค์ œ ๋ณต์•ฝ์€ ์˜์‚ฌยท์•ฝ์‚ฌ์˜ ์ง€์‹œ๋ฅผ ๋”ฐ๋ฅด์„ธ์š”.
 
6
  import spaces
7
  import torch
8
  from PIL import Image
9
+ from transformers import VisionEncoderDecoderModel, TrOCRProcessor, AutoTokenizer, AutoModelForCausalLM
10
 
11
+ # Stage 1: OCR 모델 (TrOCR로 문서에서 텍스트 추출)
12
+ OCR_MODEL_ID = "microsoft/trocr-large-printed"
13
 
14
  # Stage 2: LLM 모델 (텍스트에서 약 이름 추출)
15
  LLM_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
16
 
17
 
18
  def _load_ocr_model():
19
+ """TrOCR ๋ชจ๋ธ ๋กœ๋“œ"""
20
+ model = VisionEncoderDecoderModel.from_pretrained(
21
  OCR_MODEL_ID,
22
  device_map="auto",
 
23
  torch_dtype=torch.float16,
 
24
  )
25
 
26
+ processor = TrOCRProcessor.from_pretrained(OCR_MODEL_ID)
27
  return model, processor
28
 
29
 
 
41
  return model, tokenizer
42
 
43
 
44
+ print("🔄 Loading TrOCR model...")
45
  OCR_MODEL, OCR_PROCESSOR = _load_ocr_model()
46
+ print("✅ TrOCR model loaded!")
47
 
48
  print("🔄 Loading Qwen2.5-7B-Instruct...")
49
  LLM_MODEL, LLM_TOKENIZER = _load_llm_model()
 
68
 
69
 
70
  def extract_text_from_image(image: Image.Image) -> str:
71
+ """Stage 1: TrOCR로 이미지에서 텍스트 추출 (OCR)"""
72
  try:
73
+ # TrOCR์€ ์ด๋ฏธ์ง€ ์ „์ฒด๋ฅผ ํ•œ ๋ฒˆ์— ์ฒ˜๋ฆฌ
74
+ pixel_values = OCR_PROCESSOR(image, return_tensors="pt").pixel_values.to(OCR_MODEL.device)
 
 
 
 
 
 
 
 
 
 
75
 
76
  with torch.no_grad():
77
+ generated_ids = OCR_MODEL.generate(pixel_values)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
+ extracted_text = OCR_PROCESSOR.batch_decode(generated_ids, skip_special_tokens=True)[0]
80
+ return extracted_text.strip()
81
 
82
  except Exception as e:
83
  raise Exception(f"OCR 오류: {str(e)}")
 
304
  ---
305
 
306
  **โ„น๏ธ 2๋‹จ๊ณ„ ํŒŒ์ดํ”„๋ผ์ธ**
307
+ - **Stage 1**: TrOCR (OCR) - 이미지에서 모든 텍스트 추출
308
  - **Stage 2**: Qwen2.5 7B (LLM) - 추출된 텍스트에서 약 이름만 식별
309
 
310
  ์‹ค์ œ ๋ณต์•ฝ์€ ์˜์‚ฌยท์•ฝ์‚ฌ์˜ ์ง€์‹œ๋ฅผ ๋”ฐ๋ฅด์„ธ์š”.