LLDDWW commited on
Commit
64c57fb
ยท
1 Parent(s): 52bda02

sdfdsfads2333

Browse files
Files changed (1) hide show
  1. app.py +148 -85
app.py CHANGED
@@ -1,39 +1,58 @@
1
  import json
2
  import re
3
- from typing import List, Optional
4
 
5
  import gradio as gr
6
  import spaces
7
  import torch
8
  from PIL import Image
9
- from transformers import (
10
- Qwen2VLForConditionalGeneration,
11
- AutoProcessor,
12
- )
13
 
14
- # ์ตœ๊ณ  ํ’ˆ์งˆ ๊ณต๊ฐœ ๋ชจ๋ธ + 8๋น„ํŠธ ์–‘์žํ™” (ZeroGPU ์ตœ์ ํ™”)
15
- VL_MODEL_ID = "Qwen/Qwen2-VL-7B-Instruct"
16
 
 
 
17
 
18
- def _load_vl_model():
19
- """VL ๋ชจ๋ธ ๋กœ๋“œ - 8๋น„ํŠธ ์–‘์žํ™” + FP16"""
20
- device_map = "auto" if torch.cuda.is_available() else None
21
 
22
- model = Qwen2VLForConditionalGeneration.from_pretrained(
23
- VL_MODEL_ID,
24
- device_map=device_map,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  load_in_8bit=True,
26
  torch_dtype=torch.float16,
27
  trust_remote_code=True,
28
  )
29
 
30
- processor = AutoProcessor.from_pretrained(VL_MODEL_ID, trust_remote_code=True)
31
- return model, processor
32
 
33
 
34
- print("๐Ÿ”„ Loading Qwen2-VL-7B model...")
35
- VL_MODEL, VL_PROCESSOR = _load_vl_model()
36
- print("โœ… Model loaded successfully!")
 
 
 
 
37
 
38
 
39
  def _extract_assistant_content(decoded: str) -> str:
@@ -53,84 +72,120 @@ def _extract_json_block(text: str) -> Optional[str]:
53
  return match.group(0)
54
 
55
 
56
- @spaces.GPU(duration=120)
57
- def extract_medication_names(image: Image.Image) -> List[str]:
58
- """์ด๋ฏธ์ง€์—์„œ ์•ฝ ์ด๋ฆ„๋งŒ ์ถ”์ถœ"""
59
  try:
60
- instructions = """์ด ์‚ฌ์ง„ ์† ์•ฝ๋ด‰ํˆฌ/์ฒ˜๋ฐฉ์ „์—์„œ ์•ฝ ์ด๋ฆ„๋งŒ ๋ชจ๋‘ ์ฐพ์•„์„œ JSON ํ˜•์‹์œผ๋กœ ๋‹ต๋ณ€ํ•˜์„ธ์š”."""
61
-
62
- schema = """{
63
- "medications": ["์•ฝ ์ด๋ฆ„ 1", "์•ฝ ์ด๋ฆ„ 2", "์•ฝ ์ด๋ฆ„ 3"]
64
- }"""
65
-
66
- messages = [
67
- {
68
- "role": "system",
69
- "content": "๋‹น์‹ ์€ ์•ฝ ์ด๋ฆ„์„ ์ •ํ™•ํžˆ ์ฝ๋Š” OCR ์ „๋ฌธ๊ฐ€์ž…๋‹ˆ๋‹ค. ์•ฝ๋ด‰ํˆฌ๋‚˜ ์ฒ˜๋ฐฉ์ „์—์„œ ์•ฝ ์ด๋ฆ„๋งŒ ์ถ”์ถœํ•ฉ๋‹ˆ๋‹ค.",
70
- },
71
- {
72
- "role": "user",
73
- "content": [
74
- {"type": "text", "text": instructions},
75
- {"type": "text", "text": schema},
76
- {"type": "image"},
77
- ],
78
- },
79
- ]
80
-
81
- chat_text = VL_PROCESSOR.apply_chat_template(messages, add_generation_prompt=True)
82
- inputs = VL_PROCESSOR(text=[chat_text], images=[image], return_tensors="pt").to(VL_MODEL.device)
83
-
84
- output_ids = VL_MODEL.generate(
85
- **inputs,
86
- max_new_tokens=1024,
87
- temperature=0.2, # ๋งค์šฐ ์ •ํ™•ํ•˜๊ฒŒ
88
- top_p=0.85,
89
- do_sample=True,
90
- )
91
-
92
- decoded = VL_PROCESSOR.batch_decode(output_ids, skip_special_tokens=False)[0]
93
- assistant_text = _extract_assistant_content(decoded)
94
-
95
- # JSON ํŒŒ์‹ฑ
96
- json_block = _extract_json_block(assistant_text)
97
- if json_block:
98
- data = json.loads(json_block)
99
- meds = data.get("medications", [])
100
- if isinstance(meds, list):
101
- return [str(m).strip() for m in meds if str(m).strip()]
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  return ["์•ฝ ์ด๋ฆ„์„ ์ฐพ์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค."]
104
 
105
  except Exception as e:
106
- return [f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
107
 
108
 
109
- def format_medication_list(medications: List[str]) -> str:
110
- """์•ฝ ์ด๋ฆ„ ๋ฆฌ์ŠคํŠธ๋ฅผ ๋งˆํฌ๋‹ค์šด์œผ๋กœ ํฌ๋งท"""
111
- if not medications or medications[0].startswith("์˜ค๋ฅ˜") or medications[0].startswith("์•ฝ ์ด๋ฆ„์„ ์ฐพ์ง€"):
112
- return f"### โš ๏ธ {medications[0] if medications else '์•ฝ ์ด๋ฆ„์„ ์ฐพ์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค.'}"
113
 
114
- output = f"### ๐Ÿ’Š ๊ฒ€์ถœ๋œ ์•ฝ๋ฌผ ({len(medications)}๊ฐœ)\n\n"
115
- for idx, med_name in enumerate(medications, 1):
116
- output += f"{idx}. **{med_name}**\n"
 
 
 
 
117
 
118
- return output
119
 
120
 
121
  def run_analysis(image: Optional[Image.Image], progress=gr.Progress()):
122
- """๋ฉ”์ธ ๋ถ„์„ ํŒŒ์ดํ”„๋ผ์ธ"""
123
  if image is None:
124
- return "๐Ÿ“ท ์•ฝ ๋ด‰ํˆฌ๋‚˜ ์ฒ˜๋ฐฉ์ „ ์‚ฌ์ง„์„ ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”."
 
 
 
125
 
126
- progress(0.3, desc="๐Ÿ” ์ด๋ฏธ์ง€ ๋ถ„์„ ์ค‘...")
127
- medications = extract_medication_names(image)
128
 
129
  progress(0.9, desc="๐Ÿ“ ๊ฒฐ๊ณผ ์ •๋ฆฌ ์ค‘...")
130
- result_md = format_medication_list(medications)
131
 
132
  progress(1.0, desc="โœ… ์™„๋ฃŒ!")
133
- return result_md
134
 
135
 
136
  # ์‹ฌํ”Œํ•œ CSS
@@ -228,23 +283,31 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
228
  with gr.Column(elem_classes=["upload-section"]):
229
  gr.Markdown("### ๐Ÿ“ธ ์‚ฌ์ง„ ์—…๋กœ๋“œ")
230
  image_input = gr.Image(type="pil", label="์•ฝ๋ด‰ํˆฌ ๋˜๋Š” ์ฒ˜๋ฐฉ์ „ ์‚ฌ์ง„", height=350)
231
- analyze_button = gr.Button("๐Ÿ” ์•ฝ ์ด๋ฆ„ ์ถ”์ถœ", elem_classes=["analyze-btn"], size="lg")
232
 
233
- with gr.Column(elem_classes=["result-section"]):
234
- gr.Markdown("### ๐Ÿ“‹ ์ถ”์ถœ ๊ฒฐ๊ณผ")
235
- result_output = gr.Markdown("๋ถ„์„์„ ์‹œ์ž‘ํ•˜๋ฉด ์—ฌ๊ธฐ์— ์•ฝ ์ด๋ฆ„ ๋ฆฌ์ŠคํŠธ๊ฐ€ ํ‘œ์‹œ๋ฉ๋‹ˆ๋‹ค.")
 
 
 
 
 
236
 
237
  analyze_button.click(
238
  run_analysis,
239
  inputs=image_input,
240
- outputs=result_output,
241
  )
242
 
243
  gr.Markdown("""
244
  ---
245
 
246
- **โ„น๏ธ ์ฐธ๊ณ ์‚ฌํ•ญ**
247
- ์ด ๋„๊ตฌ๋Š” OCR ๊ธฐ๋ฐ˜์œผ๋กœ ์•ฝ ์ด๋ฆ„๋งŒ ์ถ”์ถœํ•ฉ๋‹ˆ๋‹ค. ์‹ค์ œ ๋ณต์•ฝ์€ ์˜์‚ฌยท์•ฝ์‚ฌ์˜ ์ง€์‹œ๋ฅผ ๋”ฐ๋ฅด์„ธ์š”.
 
 
 
248
  """)
249
 
250
  if __name__ == "__main__":
 
1
import json
import re
from typing import List, Optional, Tuple

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import (
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoProcessor,
    AutoTokenizer,
    BitsAndBytesConfig,
)
 
 
 
10
 
11
+ # Stage 1: OCR ๋ชจ๋ธ (๋ฌธ์„œ์—์„œ ํ…์ŠคํŠธ ์ถ”์ถœ)
12
+ OCR_MODEL_ID = "ibm-granite/granite-docling-258M"
13
 
14
+ # Stage 2: LLM ๋ชจ๋ธ (ํ…์ŠคํŠธ์—์„œ ์•ฝ ์ด๋ฆ„ ์ถ”์ถœ)
15
+ LLM_MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
16
 
 
 
 
17
 
18
def _load_ocr_model():
    """Load the Granite Docling OCR model and its processor.

    Returns:
        (model, processor): a generate-capable vision-to-text model, moved to
        GPU when available, plus its matching processor.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # granite-docling is an image-text-to-text checkpoint. AutoModel loads a
    # bare backbone with no language-modeling head, so the OCR stage (which
    # decodes generated token ids) could never work with it.
    # AutoModelForVision2Seq exposes .generate().
    model = AutoModelForVision2Seq.from_pretrained(
        OCR_MODEL_ID,
        trust_remote_code=True,
    ).to(device)

    processor = AutoProcessor.from_pretrained(
        OCR_MODEL_ID,
        trust_remote_code=True,
    )

    return model, processor
33
+
34
+
35
def _load_llm_model():
    """Load Llama-3.1-8B-Instruct with 8-bit quantization.

    Returns:
        (model, tokenizer)
    """
    # Passing load_in_8bit=True straight to from_pretrained is deprecated and
    # removed in recent transformers releases; the supported path is a
    # BitsAndBytesConfig passed as quantization_config.
    quant_config = BitsAndBytesConfig(load_in_8bit=True)

    model = AutoModelForCausalLM.from_pretrained(
        LLM_MODEL_ID,
        device_map="auto",
        quantization_config=quant_config,
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )

    tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID, trust_remote_code=True)
    return model, tokenizer
47
 
48
 
49
# Load both pipeline stages once at import time so every request served by the
# Gradio app reuses the same in-memory models instead of reloading them.
print("๐Ÿ”„ Loading Granite Docling OCR model...")
OCR_MODEL, OCR_PROCESSOR = _load_ocr_model()
print("โœ… OCR model loaded!")

print("๐Ÿ”„ Loading Llama-3.1-8B-Instruct...")
LLM_MODEL, LLM_TOKENIZER = _load_llm_model()
print("โœ… LLM model loaded!")
56
 
57
 
58
  def _extract_assistant_content(decoded: str) -> str:
 
72
  return match.group(0)
73
 
74
 
75
def extract_text_from_image(image: Image.Image) -> str:
    """Stage 1: extract raw text from the image with Granite Docling.

    Args:
        image: PIL image of a medication pouch or prescription.

    Returns:
        The OCR'd text, stripped of surrounding whitespace.

    Raises:
        RuntimeError: if OCR fails for any reason (original cause chained).
    """
    try:
        inputs = OCR_PROCESSOR(images=image, return_tensors="pt").to(OCR_MODEL.device)

        # BUG FIX: a plain forward pass returns a ModelOutput (logits/hidden
        # states), not token ids, so batch_decode on it is meaningless.
        # Autoregressive generation produces decodable token ids.
        with torch.no_grad():
            generated_ids = OCR_MODEL.generate(**inputs, max_new_tokens=1024)

        extracted_text = OCR_PROCESSOR.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return extracted_text.strip()

    except Exception as e:
        # RuntimeError is an Exception subclass, so existing callers that
        # catch Exception are unaffected; chaining preserves the traceback.
        raise RuntimeError(f"OCR ์˜ค๋ฅ˜: {str(e)}") from e
88
+
89
+
90
def extract_medications_from_text(text: str) -> List[str]:
    """Stage 2: extract medication names from OCR text with Llama 3.1.

    Args:
        text: raw text produced by the OCR stage.

    Returns:
        A list of medication names, or a single Korean "not found" message
        when the model yields no parseable result.

    Raises:
        RuntimeError: if tokenization, generation, or JSON parsing crashes.
    """
    try:
        prompt = f"""<|begin_of_text|><|start_header_id|>system<|end_header_id|>

You are a medical text analyzer. Extract only medication names from the given text and return them as a JSON array.
Return ONLY valid JSON format: {{"medications": ["name1", "name2"]}}
<|eot_id|><|start_header_id|>user<|end_header_id|>

Extract all medication names from this text:

{text}

Return only the JSON array of medication names.<|eot_id|><|start_header_id|>assistant<|end_header_id|>

"""

        inputs = LLM_TOKENIZER(prompt, return_tensors="pt").to(LLM_MODEL.device)
        prompt_len = inputs["input_ids"].shape[-1]

        with torch.no_grad():
            outputs = LLM_MODEL.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.3,
                top_p=0.9,
                do_sample=True,
                pad_token_id=LLM_TOKENIZER.eos_token_id,
            )

        # BUG FIX: decoding the FULL sequence with skip_special_tokens=True
        # strips the <|start_header_id|> markers, so the old split on that
        # marker never fired and the JSON regex then matched the schema
        # example embedded in the prompt ({{"medications": ["name1",
        # "name2"]}}), returning bogus names. Decode only the newly generated
        # tokens instead.
        response = LLM_TOKENIZER.decode(outputs[0][prompt_len:], skip_special_tokens=True)

        # Parse the first JSON object in the model's reply.
        json_match = re.search(r'\{.*?\}', response, re.DOTALL)
        if json_match:
            data = json.loads(json_match.group(0))
            medications = data.get("medications", [])
            if isinstance(medications, list) and medications:
                return [str(m).strip() for m in medications if str(m).strip()]

        return ["์•ฝ ์ด๋ฆ„์„ ์ฐพ์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค."]

    except Exception as e:
        # RuntimeError is an Exception subclass; existing callers unaffected.
        raise RuntimeError(f"LLM ๋ถ„์„ ์˜ค๋ฅ˜: {str(e)}") from e
137
+
138
+
139
@spaces.GPU(duration=120)
def extract_medication_names(image: Image.Image) -> Tuple[str, List[str]]:
    """Run the two-stage pipeline: OCR the image, then pull out drug names.

    Args:
        image: uploaded PIL image.

    Returns:
        (extracted_text, medications). On failure extracted_text is "" and
        medications holds a single error message.
    """
    try:
        # Stage 1: OCR.
        ocr_text = extract_text_from_image(image)
        if not ocr_text:
            return "", ["ํ…์ŠคํŠธ๋ฅผ ์ถ”์ถœํ•˜์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค."]

        # Stage 2: LLM-based name extraction on the OCR output.
        return ocr_text, extract_medications_from_text(ocr_text)

    except Exception as e:
        return "", [f"์˜ค๋ฅ˜ ๋ฐœ์ƒ: {str(e)}"]
156
 
157
 
158
def format_results(extracted_text: str, medications: List[str]) -> Tuple[str, str]:
    """Render the pipeline output as two Markdown panels.

    Args:
        extracted_text: raw OCR text, shown verbatim in a fenced code block.
        medications: medication names, or a single sentinel error message.

    Returns:
        (text_panel, medication_panel) Markdown strings.
    """
    text_panel = f"### ๐Ÿ“„ ์ถ”์ถœ๋œ ํ…์ŠคํŠธ\n\n```\n{extracted_text}\n```"

    # Sentinel messages produced upstream all begin with one of these prefixes.
    error_prefixes = ("์˜ค๋ฅ˜", "์•ฝ ์ด๋ฆ„์„ ์ฐพ์ง€", "ํ…์ŠคํŠธ๋ฅผ")
    if not medications or medications[0].startswith(error_prefixes):
        message = medications[0] if medications else '์•ฝ ์ด๋ฆ„์„ ์ฐพ์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค.'
        med_panel = f"### โš ๏ธ {message}"
    else:
        numbered = "".join(
            f"{idx}. **{name}**\n" for idx, name in enumerate(medications, 1)
        )
        med_panel = f"### ๐Ÿ’Š ๊ฒ€์ถœ๋œ ์•ฝ๋ฌผ ({len(medications)}๊ฐœ)\n\n" + numbered

    return text_panel, med_panel
172
 
173
 
174
def run_analysis(image: Optional[Image.Image], progress=gr.Progress()):
    """Main analysis pipeline (OCR -> LLM) backing the Gradio button.

    Args:
        image: uploaded photo, or None when nothing was provided.
        progress: Gradio progress tracker (injected by the framework).

    Returns:
        (text_panel, medication_panel) Markdown strings for the two outputs.
    """
    # Guard clause: nothing to analyze.
    if image is None:
        return "๐Ÿ“ท ์•ฝ ๋ด‰ํˆฌ๋‚˜ ์ฒ˜๋ฐฉ์ „ ์‚ฌ์ง„์„ ์—…๋กœ๋“œํ•ด์ฃผ์„ธ์š”.", ""

    progress(0.2, desc="๐Ÿ“ธ Stage 1: OCR ํ…์ŠคํŠธ ์ถ”์ถœ ์ค‘...")
    progress(0.6, desc="๐Ÿค– Stage 2: LLM ์•ฝ๋ฌผ ๋ถ„์„ ์ค‘...")
    ocr_text, med_list = extract_medication_names(image)

    progress(0.9, desc="๐Ÿ“ ๊ฒฐ๊ณผ ์ •๋ฆฌ ์ค‘...")
    panels = format_results(ocr_text, med_list)

    progress(1.0, desc="โœ… ์™„๋ฃŒ!")
    return panels
189
 
190
 
191
  # ์‹ฌํ”Œํ•œ CSS
 
283
  with gr.Column(elem_classes=["upload-section"]):
284
  gr.Markdown("### ๐Ÿ“ธ ์‚ฌ์ง„ ์—…๋กœ๋“œ")
285
  image_input = gr.Image(type="pil", label="์•ฝ๋ด‰ํˆฌ ๋˜๋Š” ์ฒ˜๋ฐฉ์ „ ์‚ฌ์ง„", height=350)
286
+ analyze_button = gr.Button("๐Ÿ” 2๋‹จ๊ณ„ ๋ถ„์„ ์‹œ์ž‘ (OCR โ†’ LLM)", elem_classes=["analyze-btn"], size="lg")
287
 
288
+ with gr.Row():
289
+ with gr.Column(elem_classes=["result-section"]):
290
+ gr.Markdown("### ๐Ÿ“‹ Stage 1: OCR ๊ฒฐ๊ณผ")
291
+ text_output = gr.Markdown("OCR๋กœ ์ถ”์ถœ๋œ ์ „์ฒด ํ…์ŠคํŠธ๊ฐ€ ์—ฌ๊ธฐ ํ‘œ์‹œ๋ฉ๋‹ˆ๋‹ค.")
292
+
293
+ with gr.Column(elem_classes=["result-section"]):
294
+ gr.Markdown("### ๐Ÿ“‹ Stage 2: LLM ๋ถ„์„ ๊ฒฐ๊ณผ")
295
+ med_output = gr.Markdown("LLM์ด ๋ถ„์„ํ•œ ์•ฝ๋ฌผ ๋ฆฌ์ŠคํŠธ๊ฐ€ ์—ฌ๊ธฐ ํ‘œ์‹œ๋ฉ๋‹ˆ๋‹ค.")
296
 
297
  analyze_button.click(
298
  run_analysis,
299
  inputs=image_input,
300
+ outputs=[text_output, med_output],
301
  )
302
 
303
  gr.Markdown("""
304
  ---
305
 
306
+ **โ„น๏ธ 2๋‹จ๊ณ„ ํŒŒ์ดํ”„๋ผ์ธ**
307
+ - **Stage 1**: Granite Docling (OCR) - ์ด๋ฏธ์ง€์—์„œ ๋ชจ๋“  ํ…์ŠคํŠธ ์ถ”์ถœ
308
+ - **Stage 2**: Llama 3.1 8B (LLM) - ์ถ”์ถœ๋œ ํ…์ŠคํŠธ์—์„œ ์•ฝ ์ด๋ฆ„๋งŒ ์‹๋ณ„
309
+
310
+ ์‹ค์ œ ๋ณต์•ฝ์€ ์˜์‚ฌยท์•ฝ์‚ฌ์˜ ์ง€์‹œ๋ฅผ ๋”ฐ๋ฅด์„ธ์š”.
311
  """)
312
 
313
  if __name__ == "__main__":