LLDDWW committed on
Commit
e94e117
ยท
1 Parent(s): 92bb45b

Replace Granite Docling OCR (Stage 1) with Qwen2-VL-7B-Instruct

Browse files
Files changed (1) hide show
  1. app.py +42 -21
app.py CHANGED
@@ -6,34 +6,31 @@ import gradio as gr
6
  import spaces
7
  import torch
8
  from PIL import Image
9
- from transformers import AutoModel, AutoProcessor, AutoTokenizer, AutoModelForCausalLM
10
 
11
- # Stage 1: OCR ๋ชจ๋ธ (๋ฌธ์„œ์—์„œ ํ…์ŠคํŠธ ์ถ”์ถœ)
12
- OCR_MODEL_ID = "ibm-granite/granite-docling-258M"
13
 
14
  # Stage 2: LLM ๋ชจ๋ธ (ํ…์ŠคํŠธ์—์„œ ์•ฝ ์ด๋ฆ„ ์ถ”์ถœ)
15
  LLM_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
16
 
17
 
18
  def _load_ocr_model():
19
- """Granite Docling OCR ๋ชจ๋ธ ๋กœ๋“œ"""
20
- device = "cuda" if torch.cuda.is_available() else "cpu"
21
-
22
- model = AutoModel.from_pretrained(
23
- OCR_MODEL_ID,
24
- trust_remote_code=True
25
- ).to(device)
26
-
27
- processor = AutoProcessor.from_pretrained(
28
  OCR_MODEL_ID,
29
- trust_remote_code=True
 
 
 
30
  )
31
 
 
32
  return model, processor
33
 
34
 
35
  def _load_llm_model():
36
- """Llama 3.1 8B ๋ชจ๋ธ ๋กœ๋“œ (8bit ์–‘์žํ™”)"""
37
  model = AutoModelForCausalLM.from_pretrained(
38
  LLM_MODEL_ID,
39
  device_map="auto",
@@ -46,7 +43,7 @@ def _load_llm_model():
46
  return model, tokenizer
47
 
48
 
49
- print("๐Ÿ”„ Loading Granite Docling OCR model...")
50
  OCR_MODEL, OCR_PROCESSOR = _load_ocr_model()
51
  print("โœ… OCR model loaded!")
52
 
@@ -73,15 +70,39 @@ def _extract_json_block(text: str) -> Optional[str]:
73
 
74
 
75
  def extract_text_from_image(image: Image.Image) -> str:
76
- """Stage 1: Granite Docling์œผ๋กœ ์ด๋ฏธ์ง€์—์„œ ํ…์ŠคํŠธ ์ถ”์ถœ"""
77
  try:
78
- inputs = OCR_PROCESSOR(images=image, return_tensors="pt").to(OCR_MODEL.device)
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  with torch.no_grad():
81
- outputs = OCR_MODEL(**inputs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- extracted_text = OCR_PROCESSOR.batch_decode(outputs, skip_special_tokens=True)[0]
84
- return extracted_text.strip()
85
 
86
  except Exception as e:
87
  raise Exception(f"OCR ์˜ค๋ฅ˜: {str(e)}")
@@ -308,7 +329,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
308
  ---
309
 
310
  **โ„น๏ธ 2๋‹จ๊ณ„ ํŒŒ์ดํ”„๋ผ์ธ**
311
- - **Stage 1**: Granite Docling (OCR) - ์ด๋ฏธ์ง€์—์„œ ๋ชจ๋“  ํ…์ŠคํŠธ ์ถ”์ถœ
312
  - **Stage 2**: Qwen2.5 7B (LLM) - ์ถ”์ถœ๋œ ํ…์ŠคํŠธ์—์„œ ์•ฝ ์ด๋ฆ„๋งŒ ์‹๋ณ„
313
 
314
  ์‹ค์ œ ๋ณต์•ฝ์€ ์˜์‚ฌยท์•ฝ์‚ฌ์˜ ์ง€์‹œ๋ฅผ ๋”ฐ๋ฅด์„ธ์š”.
 
6
  import spaces
7
  import torch
8
  from PIL import Image
9
+ from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoTokenizer, AutoModelForCausalLM
10
 
11
+ # Stage 1: OCR ๋ชจ๋ธ (Qwen2-VL๋กœ ๋ฌธ์„œ์—์„œ ํ…์ŠคํŠธ ์ถ”์ถœ)
12
+ OCR_MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"
13
 
14
  # Stage 2: LLM ๋ชจ๋ธ (ํ…์ŠคํŠธ์—์„œ ์•ฝ ์ด๋ฆ„ ์ถ”์ถœ)
15
  LLM_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"
16
 
17
 
18
def _load_ocr_model():
    """Load the Qwen2-VL model and processor used for Stage-1 OCR.

    Returns:
        tuple: ``(model, processor)`` — the 8-bit quantized
        ``Qwen2VLForConditionalGeneration`` model (device placement handled
        by accelerate via ``device_map="auto"``) and its ``AutoProcessor``.
    """
    # NOTE: a bare `load_in_8bit=True` kwarg to from_pretrained is deprecated
    # in recent transformers releases — pass an explicit BitsAndBytesConfig.
    # Imported locally to keep the top-of-file import line untouched.
    from transformers import BitsAndBytesConfig

    model = Qwen2VLForConditionalGeneration.from_pretrained(
        OCR_MODEL_ID,
        device_map="auto",  # let accelerate shard/place layers automatically
        quantization_config=BitsAndBytesConfig(load_in_8bit=True),
        torch_dtype=torch.float16,  # fp16 for the layers kept out of int8
        trust_remote_code=True,
    )
    processor = AutoProcessor.from_pretrained(OCR_MODEL_ID, trust_remote_code=True)
    return model, processor
30
 
31
 
32
  def _load_llm_model():
33
+ """Qwen2.5 7B ๋ชจ๋ธ ๋กœ๋“œ (8bit ์–‘์žํ™”)"""
34
  model = AutoModelForCausalLM.from_pretrained(
35
  LLM_MODEL_ID,
36
  device_map="auto",
 
43
  return model, tokenizer
44
 
45
 
46
+ print("๐Ÿ”„ Loading Qwen2-VL OCR model...")
47
  OCR_MODEL, OCR_PROCESSOR = _load_ocr_model()
48
  print("โœ… OCR model loaded!")
49
 
 
70
 
71
 
72
def extract_text_from_image(image: Image.Image) -> str:
    """Stage 1: extract all text from an image with Qwen2-VL (OCR).

    Args:
        image: The document image to transcribe.

    Returns:
        The text the model read off the image, whitespace-stripped.

    Raises:
        Exception: any underlying failure, re-raised with an "OCR ์˜ค๋ฅ˜"
            prefix (message kept in Korean — it is surfaced in the UI)
            and the original exception chained as the cause.
    """
    try:
        # Chat-style prompt asking the model to transcribe the image verbatim.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": "์ด ์ด๋ฏธ์ง€์˜ ๋ชจ๋“  ํ…์ŠคํŠธ๋ฅผ ์ •ํ™•ํžˆ ์ถ”์ถœํ•ด์„œ ๊ทธ๋Œ€๋กœ ์ถœ๋ ฅํ•ด์ฃผ์„ธ์š”. OCR ๊ฒฐ๊ณผ๋งŒ ์ถœ๋ ฅํ•˜์„ธ์š”."},
                    {"type": "image"},
                ],
            }
        ]

        chat_text = OCR_PROCESSOR.apply_chat_template(messages, add_generation_prompt=True)
        inputs = OCR_PROCESSOR(text=[chat_text], images=[image], return_tensors="pt").to(OCR_MODEL.device)

        with torch.no_grad():
            # Greedy decoding for deterministic OCR output. Do NOT pass
            # `temperature` here: with do_sample=False it is ignored and
            # transformers emits a warning about the unused sampling flag.
            output_ids = OCR_MODEL.generate(
                **inputs,
                max_new_tokens=1024,
                do_sample=False,
            )

        # Slice off the prompt tokens so only the newly generated answer is
        # decoded — more robust than string-splitting on chat-template
        # markers like "<|im_start|>assistant".
        generated_ids = output_ids[:, inputs["input_ids"].shape[1]:]
        extracted_text = OCR_PROCESSOR.batch_decode(
            generated_ids, skip_special_tokens=True
        )[0]
        return extracted_text.strip()

    except Exception as e:
        # Chain the original exception so the real cause stays debuggable.
        raise Exception(f"OCR ์˜ค๋ฅ˜: {str(e)}") from e
 
329
  ---
330
 
331
  **โ„น๏ธ 2๋‹จ๊ณ„ ํŒŒ์ดํ”„๋ผ์ธ**
332
+ - **Stage 1**: Qwen2-VL 7B (OCR) - ์ด๋ฏธ์ง€์—์„œ ๋ชจ๋“  ํ…์ŠคํŠธ ์ถ”์ถœ
333
  - **Stage 2**: Qwen2.5 7B (LLM) - ์ถ”์ถœ๋œ ํ…์ŠคํŠธ์—์„œ ์•ฝ ์ด๋ฆ„๋งŒ ์‹๋ณ„
334
 
335
  ์‹ค์ œ ๋ณต์•ฝ์€ ์˜์‚ฌยท์•ฝ์‚ฌ์˜ ์ง€์‹œ๋ฅผ ๋”ฐ๋ฅด์„ธ์š”.