LLDDWW Claude committed on
Commit
d5aff0d
·
1 Parent(s): 0e6a905

feat: use Gemma-2-2B for medical analysis

Browse files

- Separate OCR (Qwen2.5-VL-7B) and medical analysis (Gemma-2-2B)
- Add comprehensive medication info: name, effects, side effects, usage, precautions
- Enhanced prompt for easy-to-understand explanations
- User-friendly format for elderly and children

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

Files changed (1) hide show
  1. app.py +55 -45
app.py CHANGED
@@ -7,11 +7,14 @@ import gradio as gr
7
  import spaces
8
  import torch
9
  from PIL import Image
10
- from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
11
  from qwen_vl_utils import process_vision_info
12
 
13
- # Qwen2.5-VL 모델 ID
14
- MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
 
 
 
15
 
16
 
17
  def _extract_assistant_content(decoded: str) -> str:
@@ -35,15 +38,14 @@ def _extract_json_block(text: str) -> Optional[str]:
35
  def analyze_medication_image(image: Image.Image) -> Tuple[str, str]:
36
  """이미지에서 OCR 추출 후 약 정보 분석"""
37
  try:
38
- # Qwen2.5-VL 모델 로드
39
- model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
40
- MODEL_ID,
41
  torch_dtype="auto",
42
  device_map="auto"
43
  )
44
- processor = AutoProcessor.from_pretrained(MODEL_ID)
45
 
46
- # Step 1: OCR - 이미지에서 텍스트 추출
47
  ocr_messages = [
48
  {
49
  "role": "user",
@@ -54,72 +56,79 @@ def analyze_medication_image(image: Image.Image) -> Tuple[str, str]:
54
  }
55
  ]
56
 
57
- text = processor.apply_chat_template(ocr_messages, tokenize=False, add_generation_prompt=True)
58
  image_inputs, video_inputs = process_vision_info(ocr_messages)
59
- inputs = processor(
60
  text=[text],
61
  images=image_inputs,
62
  videos=video_inputs,
63
  padding=True,
64
  return_tensors="pt",
65
  )
66
- inputs = inputs.to(model.device)
67
 
68
  with torch.no_grad():
69
- generated_ids = model.generate(**inputs, max_new_tokens=2048)
70
 
71
  generated_ids_trimmed = [
72
  out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
73
  ]
74
 
75
- ocr_text = processor.batch_decode(
76
  generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
77
  )[0]
78
 
79
  if not ocr_text or ocr_text.strip() == "":
80
  return "ํ…์ŠคํŠธ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.", ""
81
 
82
- # Step 2: 약 정보 분석 - OCR 텍스트를 LLM에게 전달
83
- analysis_messages = [
84
- {
85
- "role": "user",
86
- "content": [
87
- {"type": "text", "text": f"""๋‹ค์Œ์€ ์•ฝ ๋ด‰ํˆฌ๋‚˜ ์ฒ˜๋ฐฉ์ „์—์„œ ์ถ”์ถœํ•œ ํ…์ŠคํŠธ์ž…๋‹ˆ๋‹ค:
 
 
 
88
 
89
  {ocr_text}
90
 
91
- ์œ„ ํ…์ŠคํŠธ์—์„œ ์•ฝ ์ด๋ฆ„์„ ์ฐพ์•„์„œ, ๊ฐ ์•ฝ์— ๋Œ€ํ•ด ๋‹ค์Œ ์ •๋ณด๋ฅผ **๋…ธ์ธ๊ณผ ์–ด๋ฆฐ์ด ๋ชจ๋‘ ์‰ฝ๊ฒŒ ์ดํ•ดํ•  ์ˆ˜ ์žˆ๋„๋ก** ์žฌ๋ฏธ์žˆ๊ณ  ์นœ๊ทผํ•˜๊ฒŒ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”:
92
 
93
- 1. **์•ฝ ์ด๋ฆ„**: ์ •ํ™•ํ•œ ์•ฝ ์ด๋ฆ„
94
- 2. **ํšจ๋Šฅ**: ์ด ์•ฝ์ด ๋ฌด์—‡์„ ์น˜๋ฃŒํ•˜๊ณ  ์–ด๋–ป๊ฒŒ ๋„์›€์ด ๋˜๋Š”์ง€
95
- 3. **๋ถ€์ž‘์šฉ**: ์ฃผ์˜ํ•ด์•ผ ํ•  ๋ถ€์ž‘์šฉ๋“ค
96
 
97
- ๊ฐ ์•ฝ๋งˆ๋‹ค ์ด๋ชจ์ง€๋ฅผ ์‚ฌ์šฉํ•˜๊ณ , ์‰ฌ์šด ๋‹จ์–ด๋กœ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”. ํ• ๋จธ๋‹ˆ ํ• ์•„๋ฒ„์ง€๋‚˜ ์ดˆ๋“ฑํ•™์ƒ๋„ ์ดํ•ดํ•  ์ˆ˜ ์žˆ๊ฒŒ ์ž‘์„ฑํ•ด์ฃผ์„ธ์š”.
98
- ๋งˆํฌ๋‹ค์šด ํ˜•์‹์œผ๋กœ ์ž‘์„ฑํ•ด์ฃผ์„ธ์š”."""},
99
- ],
100
- }
101
- ]
102
 
103
- text = processor.apply_chat_template(analysis_messages, tokenize=False, add_generation_prompt=True)
104
- inputs = processor(
105
- text=[text],
106
- images=None,
107
- videos=None,
108
- padding=True,
109
- return_tensors="pt",
110
- )
111
- inputs = inputs.to(model.device)
112
 
113
- with torch.no_grad():
114
- generated_ids = model.generate(**inputs, max_new_tokens=3072, temperature=0.7)
115
 
116
- generated_ids_trimmed = [
117
- out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
118
  ]
119
 
120
- analysis_text = processor.batch_decode(
121
- generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
122
- )[0]
 
 
 
 
 
 
 
 
 
 
123
 
124
  return ocr_text.strip(), analysis_text.strip()
125
 
@@ -360,7 +369,8 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
360
  - AI๊ฐ€ ์ƒ์„ฑํ•œ ์ •๋ณด์ด๋ฏ€๋กœ ์ •ํ™•ํ•˜์ง€ ์•Š์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค
361
 
362
  **🤖 기술 스택**
363
- - Qwen2.5-VL-7B-Instruct (OCR + 약 정보 분석)
 
364
  """)
365
 
366
  if __name__ == "__main__":
 
7
  import spaces
8
  import torch
9
  from PIL import Image
10
+ from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor, AutoTokenizer, AutoModelForCausalLM
11
  from qwen_vl_utils import process_vision_info
12
 
13
+ # OCR 모델 ID
14
+ OCR_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
15
+
16
+ # 약 정보 분석 모델 ID (의료 전문)
17
+ MED_MODEL_ID = "google/gemma-2-2b-it"
18
 
19
 
20
  def _extract_assistant_content(decoded: str) -> str:
 
38
  def analyze_medication_image(image: Image.Image) -> Tuple[str, str]:
39
  """이미지에서 OCR 추출 후 약 정보 분석"""
40
  try:
41
+ # Step 1: OCR - Qwen2.5-VL로 이미지에서 텍스트 추출
42
+ ocr_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
43
+ OCR_MODEL_ID,
44
  torch_dtype="auto",
45
  device_map="auto"
46
  )
47
+ ocr_processor = AutoProcessor.from_pretrained(OCR_MODEL_ID)
48
 
 
49
  ocr_messages = [
50
  {
51
  "role": "user",
 
56
  }
57
  ]
58
 
59
+ text = ocr_processor.apply_chat_template(ocr_messages, tokenize=False, add_generation_prompt=True)
60
  image_inputs, video_inputs = process_vision_info(ocr_messages)
61
+ inputs = ocr_processor(
62
  text=[text],
63
  images=image_inputs,
64
  videos=video_inputs,
65
  padding=True,
66
  return_tensors="pt",
67
  )
68
+ inputs = inputs.to(ocr_model.device)
69
 
70
  with torch.no_grad():
71
+ generated_ids = ocr_model.generate(**inputs, max_new_tokens=2048)
72
 
73
  generated_ids_trimmed = [
74
  out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
75
  ]
76
 
77
+ ocr_text = ocr_processor.batch_decode(
78
  generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
79
  )[0]
80
 
81
  if not ocr_text or ocr_text.strip() == "":
82
  return "ํ…์ŠคํŠธ๋ฅผ ์ฐพ์„ ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค.", ""
83
 
84
+ # Step 2: 약 정보 분석 - Gemma-2로 의료 정보 제공
85
+ med_model = AutoModelForCausalLM.from_pretrained(
86
+ MED_MODEL_ID,
87
+ torch_dtype=torch.bfloat16,
88
+ device_map="auto"
89
+ )
90
+ med_tokenizer = AutoTokenizer.from_pretrained(MED_MODEL_ID)
91
+
92
+ analysis_prompt = f"""๋‹ค์Œ์€ ์•ฝ ๋ด‰ํˆฌ๋‚˜ ์ฒ˜๋ฐฉ์ „์—์„œ ์ถ”์ถœํ•œ ํ…์ŠคํŠธ์ž…๋‹ˆ๋‹ค:
93
 
94
  {ocr_text}
95
 
96
+ ์œ„ ํ…์ŠคํŠธ์—์„œ ์•ฝ ์ด๋ฆ„์„ ์ฐพ์•„์„œ, ๊ฐ ์•ฝ์— ๋Œ€ํ•ด **๋…ธ์ธ๊ณผ ์–ด๋ฆฐ์ด ๋ชจ๋‘ ์‰ฝ๊ฒŒ ์ดํ•ดํ•  ์ˆ˜ ์žˆ๋„๋ก** ์žฌ๋ฏธ์žˆ๊ณ  ์นœ๊ทผํ•˜๊ฒŒ ์„ค๋ช…ํ•ด์ฃผ์„ธ์š”:
97
 
98
+ ๐Ÿ“‹ **๊ฐ ์•ฝ๋งˆ๋‹ค ๋‹ค์Œ ์ •๋ณด๋ฅผ ํฌํ•จํ•ด์ฃผ์„ธ์š”:**
 
 
99
 
100
+ 1. ๐Ÿ’Š **์•ฝ ์ด๋ฆ„**: ์ •ํ™•ํ•œ ์•ฝ ์ด๋ฆ„
101
+ 2. ๐ŸŽฏ **ํšจ๋Šฅ**: ์ด ์•ฝ์ด ๋ฌด์—‡์„ ์น˜๋ฃŒํ•˜๊ณ  ์–ด๋–ป๊ฒŒ ๋„์›€์ด ๋˜๋Š”์ง€
102
+ 3. โš ๏ธ **๋ถ€์ž‘์šฉ**: ์ฃผ์˜ํ•ด์•ผ ํ•  ๋ถ€์ž‘์šฉ๋“ค
103
+ 4. ๐Ÿ’ก **๋ณต์šฉ ๋ฐฉ๋ฒ•**: ์–ธ์ œ, ์–ด๋–ป๊ฒŒ ๋จน์–ด์•ผ ํ•˜๋Š”์ง€ (์‹์ „/์‹ํ›„, ํ•˜๋ฃจ ๋ช‡ ๋ฒˆ ๋“ฑ)
104
+ 5. ๐Ÿšซ **์ฃผ์˜์‚ฌํ•ญ**: ์ด ์•ฝ๊ณผ ํ•จ๊ป˜ ๋จน์œผ๋ฉด ์•ˆ ๋˜๋Š” ๊ฒƒ๋“ค (์Œ์‹, ๋‹ค๋ฅธ ์•ฝ ๋“ฑ)
105
 
106
+ **์Šคํƒ€์ผ ๊ฐ€์ด๋“œ:**
107
+ - ์ด๋ชจ์ง€๋ฅผ ์ ๊ทน ํ™œ์šฉํ•˜์—ฌ ์žฌ๋ฏธ์žˆ๊ฒŒ ์ž‘์„ฑ
108
+ - ํ• ๋จธ๋‹ˆ ํ• ์•„๋ฒ„์ง€๋‚˜ ์ดˆ๋“ฑํ•™์ƒ๋„ ์ดํ•ดํ•  ์ˆ˜ ์žˆ๋Š” ์‰ฌ์šด ๋‹จ์–ด ์‚ฌ์šฉ
109
+ - ๊ฐ ์•ฝ๋งˆ๋‹ค ๊ตฌ๋ถ„์„ ์œผ๋กœ ๊ตฌ๋ถ„
110
+ - ์นœ๊ทผํ•˜๊ณ  ๋”ฐ๋œปํ•œ ๋งํˆฌ ์‚ฌ์šฉ
111
+ - ๋งˆํฌ๋‹ค์šด ํ˜•์‹์œผ๋กœ ์ž‘์„ฑ
 
 
 
112
 
113
+ ์‹œ์ž‘ํ•ด์ฃผ์„ธ์š”!"""
 
114
 
115
+ messages = [
116
+ {"role": "user", "content": analysis_prompt}
117
  ]
118
 
119
+ input_text = med_tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
120
+ inputs = med_tokenizer(input_text, return_tensors="pt").to(med_model.device)
121
+
122
+ with torch.no_grad():
123
+ outputs = med_model.generate(
124
+ **inputs,
125
+ max_new_tokens=3072,
126
+ temperature=0.7,
127
+ top_p=0.9,
128
+ do_sample=True
129
+ )
130
+
131
+ analysis_text = med_tokenizer.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True)
132
 
133
  return ocr_text.strip(), analysis_text.strip()
134
 
 
369
  - AI๊ฐ€ ์ƒ์„ฑํ•œ ์ •๋ณด์ด๋ฏ€๋กœ ์ •ํ™•ํ•˜์ง€ ์•Š์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค
370
 
371
  **🤖 기술 스택**
372
+ - Qwen2.5-VL-7B-Instruct (OCR 텍스트 추출)
373
+ - Google Gemma-2-2B-IT (의료 정보 분석 및 설명)
374
  """)
375
 
376
  if __name__ == "__main__":