LLDDWW committed on
Commit
e53f54d
·
1 Parent(s): dbf7d32

feat: add qwen vl narratives and cartoon generation

Browse files
Files changed (2) hide show
  1. app.py +169 -73
  2. requirements.txt +2 -1
app.py CHANGED
@@ -3,12 +3,20 @@ import re
3
  from typing import Any, Dict, List, Optional
4
 
5
  import gradio as gr
6
- import torch
7
  import spaces
 
 
8
  from PIL import Image, ImageDraw
9
- from transformers import AutoModelForVision2Seq, AutoProcessor
 
 
 
 
 
10
 
11
  VL_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
 
 
12
 
13
 
14
  def _load_vl_model():
@@ -29,6 +37,39 @@ def _load_vl_model():
29
  VL_MODEL, VL_PROCESSOR = _load_vl_model()
30
 
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def _extract_assistant_content(decoded: str) -> str:
33
  if "<|im_start|>assistant" in decoded:
34
  content = decoded.split("<|im_start|>assistant")[-1]
@@ -44,40 +85,34 @@ def _extract_json_block(text: str) -> Optional[str]:
44
  return match.group(0)
45
 
46
 
47
- def _sanitize_medication(item: Dict[str, Any]) -> Dict[str, Any]:
48
- def _as_str(value: Any) -> str:
49
- if value is None:
50
- return ""
51
- return str(value).strip()
 
52
 
53
- name = _as_str(item.get("name"))
54
- dose = _as_str(item.get("dose_per_intake"))
 
 
55
 
56
  times = item.get("times_per_day")
57
  if isinstance(times, (int, float)):
58
  times_str = str(int(times)) if float(times).is_integer() else str(times)
59
  else:
60
- times_str = _as_str(times)
61
-
62
- time_slots_raw = item.get("time_slots")
63
- if isinstance(time_slots_raw, (list, tuple)):
64
- time_slots = [str(t).strip() for t in time_slots_raw if str(t).strip()]
65
- elif isinstance(time_slots_raw, str):
66
- slots = [s.strip() for s in re.split(r"[,;]\s*", time_slots_raw) if s.strip()]
67
- time_slots = slots
68
- else:
69
- time_slots = []
70
 
71
  return {
72
- "name": name,
73
- "dose_per_intake": dose,
74
  "times_per_day": times_str,
75
- "time_slots": time_slots,
76
- "description": _as_str(item.get("description")),
77
- "usage_example": _as_str(item.get("usage_example")),
78
- "dosage_example": _as_str(item.get("dosage_example")),
79
- "side_effects": _as_str(item.get("side_effects")),
80
- "warnings": _as_str(item.get("warnings")),
81
  }
82
 
83
 
@@ -87,7 +122,7 @@ def _parse_vl_response(text: str) -> Dict[str, Any]:
87
  return {
88
  "raw_text": "",
89
  "medications": [],
90
- "warnings": ["LLM ์‘๋‹ต์—์„œ JSON์„ ์ฐพ์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค.", text.strip()],
91
  }
92
  try:
93
  data = json.loads(json_block)
@@ -95,11 +130,9 @@ def _parse_vl_response(text: str) -> Dict[str, Any]:
95
  return {
96
  "raw_text": "",
97
  "medications": [],
98
- "warnings": ["LLM JSON ํŒŒ์‹ฑ ์‹คํŒจ", text.strip()],
99
  }
100
 
101
- raw_text = str(data.get("raw_text", "")).strip()
102
-
103
  meds_raw = data.get("medications") or []
104
  medications: List[Dict[str, Any]] = []
105
  if isinstance(meds_raw, list):
@@ -116,7 +149,7 @@ def _parse_vl_response(text: str) -> Dict[str, Any]:
116
  warnings = []
117
 
118
  return {
119
- "raw_text": raw_text,
120
  "medications": medications,
121
  "warnings": warnings,
122
  }
@@ -135,27 +168,26 @@ def analyze_image_with_qwen(image: Image.Image) -> Dict[str, Any]:
135
  " {\n"
136
  " \"name\": \"์•ฝ ์ด๋ฆ„\",\n"
137
  " \"dose_per_intake\": \"1ํšŒ ์šฉ๋Ÿ‰ (์˜ˆ: 1์ •, 5mL)\",\n"
138
- " \"times_per_day\": \"ํ•˜๋ฃจ ๋ณต์šฉ ํšŸ์ˆ˜ (๋ชจ๋ฅด๋ฉด ๋นˆ ๋ฌธ์ž์—ด)\",\n"
139
  " \"time_slots\": [\"๋ณต์šฉ ์‹œ๊ฐ„๋Œ€\"],\n"
140
- " \"description\": \"์–ด๋–ค ์•ฝ์ธ์ง€ ํ•œ ์ค„ ์„ค๋ช…\",\n"
141
- " \"usage_example\": \"์–ธ์ œ ๋ณต์šฉํ•˜๋ฉด ์ข‹์€์ง€ ์˜ˆ์‹œ\",\n"
142
- " \"dosage_example\": \"๋ณต์šฉ ๋ฐฉ๋ฒ• ์˜ˆ์‹œ(์˜ˆ: ์‹ํ›„ 30๋ถ„, 1ํšŒ 1์ •)\",\n"
143
- " \"side_effects\": \"์ฃผ์š” ๋ถ€์ž‘์šฉ ๋˜๋Š” ์ฃผ์˜์‚ฌํ•ญ\",\n"
144
- " \"warnings\": \"์ถ”๊ฐ€ ์ฃผ์˜ ๋ฌธ๊ตฌ\"\n"
145
  " }\n"
146
  " ],\n"
147
- " \"warnings\": [\"์ „์ฒด์ ์ธ ๊ฒฝ๊ณ  ๋ฌธ๊ตฌ\"]\n"
148
  "}"
149
  )
150
  user_prompt = (
151
- "์œ„ JSON ์Šคํ‚ค๋งˆ๋ฅผ ๊ทธ๋Œ€๋กœ ๋”ฐ๋ฅด์„ธ์š”. ๋นˆ ๊ฐ’์€ ๋นˆ ๋ฌธ์ž์—ด๋กœ ๋‘ก๋‹ˆ๋‹ค. "
152
- "๋ชจ๋“  ๊ฐ’์€ ํ•œ๊ตญ์–ด๋กœ ์ž‘์„ฑํ•˜๊ณ , ์ค‘ํ•™์ƒ๋„ ์ดํ•ดํ•  ์ˆ˜ ์žˆ๋Š” ๋งํˆฌ๋กœ ์„ค๋ช…ํ•˜์„ธ์š”."
153
  )
154
 
155
  messages = [
156
  {
157
  "role": "system",
158
- "content": "๋‹น์‹ ์€ ์•ฝ์‚ฌ ์„ ์ƒ๋‹˜์œผ๋กœ์„œ ์•ฝ๋ด‰ํˆฌ ์ด๋ฏธ์ง€๋ฅผ ํ•ด์„ํ•˜๊ณ  ์นœ์ ˆํ•˜๊ฒŒ ์„ค๋ช…ํ•ฉ๋‹ˆ๋‹ค.",
159
  },
160
  {
161
  "role": "user",
@@ -169,11 +201,7 @@ def analyze_image_with_qwen(image: Image.Image) -> Dict[str, Any]:
169
  ]
170
 
171
  chat_text = VL_PROCESSOR.apply_chat_template(messages, add_generation_prompt=True)
172
- inputs = VL_PROCESSOR(
173
- text=[chat_text],
174
- images=[image],
175
- return_tensors="pt",
176
- ).to(VL_MODEL.device)
177
 
178
  output_ids = VL_MODEL.generate(
179
  **inputs,
@@ -188,6 +216,85 @@ def analyze_image_with_qwen(image: Image.Image) -> Dict[str, Any]:
188
  return _parse_vl_response(assistant_text)
189
 
190
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  def render_card(primary: Dict[str, Any]) -> Image.Image:
192
  width, height = 720, 400
193
  canvas = Image.new("RGB", (width, height), "white")
@@ -231,28 +338,6 @@ def medications_to_csv(medications: List[Dict[str, Any]]) -> str:
231
  return ",".join(row)
232
 
233
 
234
- def build_markdown(medications: List[Dict[str, Any]]) -> str:
235
- if not medications:
236
- return "### ์•ฝ ์„ค๋ช…\n- ์•ฝ ์ •๋ณด๋ฅผ ์ธ์‹ํ•˜์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค. ์•ฝ์‚ฌ์—๊ฒŒ ์ง์ ‘ ํ™•์ธํ•ด ์ฃผ์„ธ์š”."
237
-
238
- lines: List[str] = ["### ์‰ฝ๊ฒŒ ์•Œ์•„๋ณด๋Š” ์•ฝ ์„ค๋ช…"]
239
- for med in medications:
240
- lines.append(f"- **{med.get('name') or '์ด๋ฆ„ ๋ฏธํ™•์ธ'}**")
241
- if med.get("description"):
242
- lines.append(f" - ํ•˜๋Š” ์ผ: {med['description']}")
243
- if med.get("usage_example"):
244
- lines.append(f" - ๋ณต์šฉ ์˜ˆ์‹œ: {med['usage_example']}")
245
- if med.get("dosage_example"):
246
- lines.append(f" - ๋ณต์šฉ ๋ฐฉ๋ฒ• ์˜ˆ์‹œ: {med['dosage_example']}")
247
- if med.get("side_effects"):
248
- lines.append(f" - ๋ถ€์ž‘์šฉ/์ฃผ์˜: {med['side_effects']}")
249
- if med.get("warnings"):
250
- lines.append(f" - ์ถ”๊ฐ€ ์ฃผ์˜: {med['warnings']}")
251
-
252
- lines.append("\n> โš ๏ธ ์‹ค์ œ ๋ณต์•ฝ์€ ์˜์‚ฌยท์•ฝ์‚ฌ์˜ ์ง€์‹œ์— ๋ฐ˜๋“œ์‹œ ๋”ฐ๋ฅด์„ธ์š”.")
253
- return "\n".join(lines)
254
-
255
-
256
  def format_warnings(warnings: List[str]) -> str:
257
  if not warnings:
258
  return "โœ… ์ธ์‹๋œ ์ •๋ณด๊ฐ€ ์ถฉ๋ถ„ํ•ด์š”. ๋ณต์•ฝ ์‹œ๊ฐ„๋งŒ ์ž˜ ์ง€์ผœ ์ฃผ์„ธ์š”."
@@ -272,6 +357,7 @@ def run_pipeline(image: Optional[Image.Image]):
272
  "์ด๋ฏธ์ง€๋ฅผ ๋จผ์ € ์—…๋กœ๋“œํ•ด ์ฃผ์„ธ์š”.",
273
  "๐Ÿ“ท ์•ฝ ๋ด‰ํˆฌ ์‚ฌ์ง„์„ ์˜ฌ๋ฆฌ๋ฉด ์ธ์‹์ด ์‹œ์ž‘๋ผ์š”.",
274
  "",
 
275
  )
276
 
277
  result = analyze_image_with_qwen(image)
@@ -284,14 +370,23 @@ def run_pipeline(image: Optional[Image.Image]):
284
  "time_slots": [],
285
  }
286
 
 
 
287
  card_img = render_card(primary)
288
  csv_row = medications_to_csv(medications)
289
- markdown = build_markdown(medications)
 
 
 
 
 
 
290
  warnings_md = format_warnings(result.get("warnings", []))
291
  raw_text = result.get("raw_text", "")
292
  json_text = json.dumps(result, ensure_ascii=False, indent=2)
 
293
 
294
- return json_text, card_img, csv_row, markdown, warnings_md, raw_text
295
 
296
 
297
  CUSTOM_CSS = """
@@ -319,7 +414,7 @@ body {background: radial-gradient(circle at top left, #f5f0ff 0%, #fff7ec 60%, #
319
  HERO_HTML = """
320
  <div class="hero">
321
  <h1>MedCard-KR ยท ์•ฝ๋ด‰ํˆฌ ํ•œ ์ปท์œผ๋กœ ์ดํ•ดํ•˜๋Š” ๋ณต์šฉ ์•ˆ๋‚ด</h1>
322
- <p>Qwen2.5-VL์ด ์‚ฌ์ง„ ์† ๊ธ€์ž๋ฅผ ์ง์ ‘ ์ฝ๊ณ , ์•ฝ ์„ค๋ช…ยท๋ณต์šฉ ์˜ˆ์‹œยท๋ถ€์ž‘์šฉ๊นŒ์ง€ ํ•œ ๋ฒˆ์— ์ •๋ฆฌํ•ด ๋“œ๋ฆฝ๋‹ˆ๋‹ค.</p>
323
  </div>
324
  """
325
 
@@ -336,6 +431,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
336
  gr.Markdown("### 2. ๊ฒฐ๊ณผ๋ฅผ ํ™•์ธํ•˜์„ธ์š”")
337
  explain_md = gr.Markdown("์—ฌ๊ธฐ์— ์•ฝ ์„ค๋ช…์ด ํ‘œ์‹œ๋ฉ๋‹ˆ๋‹ค.", elem_classes=["output-card"])
338
  raw_box = gr.Textbox(label="๋ชจ๋ธ์ด ์ฝ์€ ์›๋ฌธ ํ…์ŠคํŠธ", lines=5, interactive=False)
 
339
  card_out = gr.Image(type="pil", label="์ผ์ • ์นด๋“œ(๋ฏธ๋ฆฌ๋ณด๊ธฐ)")
340
  csv_box = gr.Textbox(label="CSV(์•ฝ๋ช…,1ํšŒ์šฉ๋Ÿ‰,1์ผํšŸ์ˆ˜,์‹œ๊ฐ„๋Œ€)", lines=2, elem_classes=["csv-box"])
341
  with gr.Accordion("์„ธ๋ถ€ JSON ๊ฒฐ๊ณผ", open=False, elem_classes=["accordion"]):
@@ -344,7 +440,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
344
  btn.click(
345
  run_pipeline,
346
  inputs=img_in,
347
- outputs=[json_out, card_out, csv_box, explain_md, warn_md, raw_box],
348
  )
349
 
350
  gr.Markdown(
 
3
  from typing import Any, Dict, List, Optional
4
 
5
  import gradio as gr
 
6
  import spaces
7
+ import torch
8
+ from diffusers import AutoPipelineForText2Image
9
  from PIL import Image, ImageDraw
10
+ from transformers import (
11
+ AutoModelForCausalLM,
12
+ AutoModelForVision2Seq,
13
+ AutoProcessor,
14
+ AutoTokenizer,
15
+ )
16
 
17
  VL_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
18
+ TEXT_MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
19
+ IMAGE_MODEL_ID = "stabilityai/stable-diffusion-2-1"
20
 
21
 
22
  def _load_vl_model():
 
37
  VL_MODEL, VL_PROCESSOR = _load_vl_model()
38
 
39
 
40
+ def _load_text_model():
41
+ device_map = "auto" if torch.cuda.is_available() else None
42
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
43
+ model = AutoModelForCausalLM.from_pretrained(
44
+ TEXT_MODEL_ID,
45
+ device_map=device_map,
46
+ torch_dtype=dtype,
47
+ trust_remote_code=True,
48
+ )
49
+ if device_map is None:
50
+ model = model.to(torch.device("cpu"))
51
+ tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_ID, trust_remote_code=True)
52
+ return model, tokenizer
53
+
54
+
55
+ TEXT_MODEL, TEXT_TOKENIZER = _load_text_model()
56
+
57
+
58
+ def _load_image_pipeline():
59
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
60
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
61
+ pipe = AutoPipelineForText2Image.from_pretrained(
62
+ IMAGE_MODEL_ID,
63
+ torch_dtype=dtype,
64
+ safety_checker=None,
65
+ )
66
+ pipe.to(device)
67
+ return pipe
68
+
69
+
70
+ IMAGE_PIPELINE = _load_image_pipeline()
71
+
72
+
73
  def _extract_assistant_content(decoded: str) -> str:
74
  if "<|im_start|>assistant" in decoded:
75
  content = decoded.split("<|im_start|>assistant")[-1]
 
85
  return match.group(0)
86
 
87
 
88
+ def _sanitize_list(value: Any) -> List[str]:
89
+ if isinstance(value, (list, tuple)):
90
+ return [str(v).strip() for v in value if str(v).strip()]
91
+ if isinstance(value, str):
92
+ return [v.strip() for v in re.split(r"[,;]", value) if v.strip()]
93
+ return []
94
 
95
+
96
+ def _sanitize_medication(item: Dict[str, Any]) -> Dict[str, Any]:
97
+ def _to_str(val: Any) -> str:
98
+ return "" if val is None else str(val).strip()
99
 
100
  times = item.get("times_per_day")
101
  if isinstance(times, (int, float)):
102
  times_str = str(int(times)) if float(times).is_integer() else str(times)
103
  else:
104
+ times_str = _to_str(times)
 
 
 
 
 
 
 
 
 
105
 
106
  return {
107
+ "name": _to_str(item.get("name")),
108
+ "dose_per_intake": _to_str(item.get("dose_per_intake")),
109
  "times_per_day": times_str,
110
+ "time_slots": _sanitize_list(item.get("time_slots")),
111
+ "description": _to_str(item.get("description")),
112
+ "usage_example": _to_str(item.get("usage_example")),
113
+ "dosage_example": _to_str(item.get("dosage_example")),
114
+ "side_effects": _to_str(item.get("side_effects")),
115
+ "warnings": _to_str(item.get("warnings")),
116
  }
117
 
118
 
 
122
  return {
123
  "raw_text": "",
124
  "medications": [],
125
+ "warnings": ["๋ชจ๋ธ ์‘๋‹ต์—์„œ JSON ํ˜•์‹์„ ์ฐพ์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค."] + ([text.strip()] if text.strip() else []),
126
  }
127
  try:
128
  data = json.loads(json_block)
 
130
  return {
131
  "raw_text": "",
132
  "medications": [],
133
+ "warnings": ["๋ชจ๋ธ JSON ํŒŒ์‹ฑ ์‹คํŒจ", text.strip()],
134
  }
135
 
 
 
136
  meds_raw = data.get("medications") or []
137
  medications: List[Dict[str, Any]] = []
138
  if isinstance(meds_raw, list):
 
149
  warnings = []
150
 
151
  return {
152
+ "raw_text": str(data.get("raw_text", "")).strip(),
153
  "medications": medications,
154
  "warnings": warnings,
155
  }
 
168
  " {\n"
169
  " \"name\": \"์•ฝ ์ด๋ฆ„\",\n"
170
  " \"dose_per_intake\": \"1ํšŒ ์šฉ๋Ÿ‰ (์˜ˆ: 1์ •, 5mL)\",\n"
171
+ " \"times_per_day\": \"ํ•˜๋ฃจ ๋ณต์šฉ ํšŸ์ˆ˜\",\n"
172
  " \"time_slots\": [\"๋ณต์šฉ ์‹œ๊ฐ„๋Œ€\"],\n"
173
+ " \"description\": \"์•ฝ ์„ค๋ช…\",\n"
174
+ " \"usage_example\": \"๋ณต์šฉ ์˜ˆ์‹œ\",\n"
175
+ " \"dosage_example\": \"๋ณต์šฉ ๋ฐฉ๋ฒ• ์˜ˆ์‹œ\",\n"
176
+ " \"side_effects\": \"์ฃผ์š” ๋ถ€์ž‘์šฉ\",\n"
177
+ " \"warnings\": \"์ฃผ์˜ ๋ฌธ๊ตฌ\"\n"
178
  " }\n"
179
  " ],\n"
180
+ " \"warnings\": [\"์ „์ฒด ๊ฒฝ๊ณ \"]\n"
181
  "}"
182
  )
183
  user_prompt = (
184
+ "์œ„ JSON ์Šคํ‚ค๋งˆ๋ฅผ ๋ฐ˜๋“œ์‹œ ๋”ฐ๋ฅด์„ธ์š”. ๋ชจ๋“  ๊ฐ’์€ ํ•œ๊ตญ์–ด๋กœ ์ž‘์„ฑํ•˜๊ณ , ๋นˆ ์ •๋ณด๋Š” ๋นˆ ๋ฌธ์ž์—ด๋กœ ๋‘์„ธ์š”."
 
185
  )
186
 
187
  messages = [
188
  {
189
  "role": "system",
190
+ "content": "๋‹น์‹ ์€ ์•ฝ์‚ฌ ์„ ์ƒ๋‹˜์ž…๋‹ˆ๋‹ค. ์ •ํ™•ํ•˜๊ณ  ์นœ์ ˆํ•˜๊ฒŒ ์ •๋ณด๋ฅผ ์ •๋ฆฌํ•˜์„ธ์š”.",
191
  },
192
  {
193
  "role": "user",
 
201
  ]
202
 
203
  chat_text = VL_PROCESSOR.apply_chat_template(messages, add_generation_prompt=True)
204
+ inputs = VL_PROCESSOR(text=[chat_text], images=[image], return_tensors="pt").to(VL_MODEL.device)
 
 
 
 
205
 
206
  output_ids = VL_MODEL.generate(
207
  **inputs,
 
216
  return _parse_vl_response(assistant_text)
217
 
218
 
219
+ @spaces.GPU(enable_queue=True)
220
+ def generate_explanations(raw_text: str, medications: List[Dict[str, Any]]) -> Dict[str, str]:
221
+ med_summary_lines = []
222
+ for med in medications:
223
+ summary = f"- {med.get('name', '์ด๋ฆ„ ๋ฏธํ™•์ธ')} {med.get('dose_per_intake', '')}"
224
+ med_summary_lines.append(summary.strip())
225
+ med_summary = "\n".join(med_summary_lines)
226
+
227
+ system_prompt = "์•ฝ์‚ฌ ์„ ์ƒ๋‹˜์ฒ˜๋Ÿผ ์–ด๋ฅด์‹ ๊ณผ ์–ด๋ฆฐ์ด์—๊ฒŒ ๊ฐ๊ฐ ์‰ฝ๊ฒŒ ์„ค๋ช…ํ•˜์„ธ์š”."
228
+ user_prompt = (
229
+ "๋‹ค์Œ์€ ์•ฝ ๋ด‰ํˆฌ์—์„œ ์ฝ์€ ์›๋ฌธ๊ณผ ์•ฝ ๋ชฉ๋ก์ž…๋‹ˆ๋‹ค. \n"
230
+ "JSON์œผ๋กœ ๋‹ต๋ณ€ํ•˜์„ธ์š”. ํ˜•์‹์€ {\"elderly\": {\"narrative\": ..., \"image_prompt\": ...}, \"child\": {\"narrative\": ..., \"image_prompt\": ...}} ์ž…๋‹ˆ๋‹ค.\n"
231
+ "narrative๋Š” ํ•œ๊ตญ์–ด, image_prompt๋Š” ์˜์–ด๋กœ ํ•œ ์ปท ๋งŒํ™” ์Šคํƒ€์ผ์„ ๋ฌ˜์‚ฌํ•˜์„ธ์š”.\n"
232
+ f"์•ฝ ๋ชฉ๋ก:\n{med_summary}\n\n์›๋ฌธ:\n{raw_text}\n"
233
+ )
234
+
235
+ messages = [
236
+ {"role": "system", "content": system_prompt},
237
+ {"role": "user", "content": user_prompt},
238
+ ]
239
+
240
+ input_ids = TEXT_TOKENIZER.apply_chat_template(
241
+ messages,
242
+ add_generation_prompt=True,
243
+ return_tensors="pt",
244
+ ).to(TEXT_MODEL.device)
245
+
246
+ with torch.no_grad():
247
+ output_ids = TEXT_MODEL.generate(
248
+ input_ids,
249
+ max_new_tokens=512,
250
+ temperature=0.3,
251
+ top_p=0.8,
252
+ )
253
+
254
+ generated_ids = output_ids[0][input_ids.shape[1]:]
255
+ text = TEXT_TOKENIZER.decode(generated_ids, skip_special_tokens=True).strip()
256
+
257
+ json_block = _extract_json_block(text)
258
+ if not json_block:
259
+ return {
260
+ "elderly_narrative": "์„ค๋ช…์„ ์ค€๋น„ํ•˜์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค. ์•ฝ์‚ฌ์—๊ฒŒ ์ง์ ‘ ๋ฌธ์˜ํ•˜์„ธ์š”.",
261
+ "child_narrative": "์„ค๋ช…์„ ์ค€๋น„ํ•˜์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค. ์•ฝ์‚ฌ์—๊ฒŒ ์ง์ ‘ ๋ฌธ์˜ํ•˜์„ธ์š”.",
262
+ "image_prompt": "single panel cartoon pharmacist helping family, soft colors",
263
+ }
264
+
265
+ try:
266
+ data = json.loads(json_block)
267
+ except json.JSONDecodeError:
268
+ return {
269
+ "elderly_narrative": "์„ค๋ช…์„ ์ค€๋น„ํ•˜์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค. ์•ฝ์‚ฌ์—๊ฒŒ ์ง์ ‘ ๋ฌธ์˜ํ•˜์„ธ์š”.",
270
+ "child_narrative": "์„ค๋ช…์„ ์ค€๋น„ํ•˜์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค. ์•ฝ์‚ฌ์—๊ฒŒ ์ง์ ‘ ๋ฌธ์˜ํ•˜์„ธ์š”.",
271
+ "image_prompt": "single panel cartoon pharmacist helping family, soft colors",
272
+ }
273
+
274
+ elderly = data.get("elderly", {})
275
+ child = data.get("child", {})
276
+
277
+ return {
278
+ "elderly_narrative": str(elderly.get("narrative", "")).strip(),
279
+ "child_narrative": str(child.get("narrative", "")).strip(),
280
+ "image_prompt": str(child.get("image_prompt") or elderly.get("image_prompt") or "single panel cartoon pharmacist helping family, pastel colors").strip(),
281
+ }
282
+
283
+
284
@spaces.GPU(enable_queue=True)
def generate_cartoon_image(prompt: str) -> Image.Image:
    """Render a single-panel cartoon for *prompt* with the diffusion pipeline.

    Falls back to a generic wholesome-pharmacist prompt when *prompt* is
    empty, and always applies a fixed negative prompt to suppress text,
    watermarks and blur.
    """
    effective_prompt = prompt or (
        "single panel wholesome cartoon, pharmacist gently explaining "
        "medicine to family, warm pastel colors"
    )
    result = IMAGE_PIPELINE(
        prompt=effective_prompt,
        negative_prompt="text, watermark, logo, blurry",
        num_inference_steps=30,
        guidance_scale=7.5,
    )
    return result.images[0]
296
+
297
+
298
  def render_card(primary: Dict[str, Any]) -> Image.Image:
299
  width, height = 720, 400
300
  canvas = Image.new("RGB", (width, height), "white")
 
338
  return ",".join(row)
339
 
340
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
341
  def format_warnings(warnings: List[str]) -> str:
342
  if not warnings:
343
  return "โœ… ์ธ์‹๋œ ์ •๋ณด๊ฐ€ ์ถฉ๋ถ„ํ•ด์š”. ๋ณต์•ฝ ์‹œ๊ฐ„๋งŒ ์ž˜ ์ง€์ผœ ์ฃผ์„ธ์š”."
 
357
  "์ด๋ฏธ์ง€๋ฅผ ๋จผ์ € ์—…๋กœ๋“œํ•ด ์ฃผ์„ธ์š”.",
358
  "๐Ÿ“ท ์•ฝ ๋ด‰ํˆฌ ์‚ฌ์ง„์„ ์˜ฌ๋ฆฌ๋ฉด ์ธ์‹์ด ์‹œ์ž‘๋ผ์š”.",
359
  "",
360
+ None,
361
  )
362
 
363
  result = analyze_image_with_qwen(image)
 
370
  "time_slots": [],
371
  }
372
 
373
+ narratives = generate_explanations(result.get("raw_text", ""), medications)
374
+
375
  card_img = render_card(primary)
376
  csv_row = medications_to_csv(medications)
377
+ markdown = (
378
+ "## ์–ด๋ฅด์‹ ์„ ์œ„ํ•œ ์„ค๋ช…\n"
379
+ + (narratives.get("elderly_narrative") or "- ์„ค๋ช…์„ ์ค€๋น„ํ•˜์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค.")
380
+ + "\n\n## ์–ด๋ฆฐ์ด๋ฅผ ์œ„ํ•œ ์„ค๋ช…\n"
381
+ + (narratives.get("child_narrative") or "- ์„ค๋ช…์„ ์ค€๋น„ํ•˜์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค.")
382
+ + "\n\n> ํ•ญ์ƒ ์˜๋ฃŒ์ง„์˜ ์•ˆ๋‚ด๋ฅผ ์šฐ์„ ํ•˜์„ธ์š”."
383
+ )
384
  warnings_md = format_warnings(result.get("warnings", []))
385
  raw_text = result.get("raw_text", "")
386
  json_text = json.dumps(result, ensure_ascii=False, indent=2)
387
+ cartoon_image = generate_cartoon_image(narratives.get("image_prompt"))
388
 
389
+ return json_text, card_img, csv_row, markdown, warnings_md, raw_text, cartoon_image
390
 
391
 
392
  CUSTOM_CSS = """
 
414
  HERO_HTML = """
415
  <div class="hero">
416
  <h1>MedCard-KR ยท ์•ฝ๋ด‰ํˆฌ ํ•œ ์ปท์œผ๋กœ ์ดํ•ดํ•˜๋Š” ๋ณต์šฉ ์•ˆ๋‚ด</h1>
417
+ <p>Qwen2.5-VL์ด ์•ฝ ๋ด‰ํˆฌ๋ฅผ ์ง์ ‘ ์ฝ๊ณ , ์•ฝ์‚ฌ์ฒ˜๋Ÿผ ์‰ฝ๊ฒŒ ์„ค๋ช…๊ณผ ํ•œ ์ปท ๋งŒํ™”๋ฅผ ํ•จ๊ป˜ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.</p>
418
  </div>
419
  """
420
 
 
431
  gr.Markdown("### 2. ๊ฒฐ๊ณผ๋ฅผ ํ™•์ธํ•˜์„ธ์š”")
432
  explain_md = gr.Markdown("์—ฌ๊ธฐ์— ์•ฝ ์„ค๋ช…์ด ํ‘œ์‹œ๋ฉ๋‹ˆ๋‹ค.", elem_classes=["output-card"])
433
  raw_box = gr.Textbox(label="๋ชจ๋ธ์ด ์ฝ์€ ์›๋ฌธ ํ…์ŠคํŠธ", lines=5, interactive=False)
434
+ cartoon_img = gr.Image(type="pil", label="ํ•œ ์ปท ๋งŒํ™”")
435
  card_out = gr.Image(type="pil", label="์ผ์ • ์นด๋“œ(๋ฏธ๋ฆฌ๋ณด๊ธฐ)")
436
  csv_box = gr.Textbox(label="CSV(์•ฝ๋ช…,1ํšŒ์šฉ๋Ÿ‰,1์ผํšŸ์ˆ˜,์‹œ๊ฐ„๋Œ€)", lines=2, elem_classes=["csv-box"])
437
  with gr.Accordion("์„ธ๋ถ€ JSON ๊ฒฐ๊ณผ", open=False, elem_classes=["accordion"]):
 
440
  btn.click(
441
  run_pipeline,
442
  inputs=img_in,
443
+ outputs=[json_out, card_out, csv_box, explain_md, warn_md, raw_box, cartoon_img],
444
  )
445
 
446
  gr.Markdown(
requirements.txt CHANGED
@@ -2,7 +2,8 @@ transformers
2
  torch
3
  accelerate
4
  einops
 
 
5
  gradio
6
  Pillow
7
  sentencepiece
8
- torchvision
 
2
  torch
3
  accelerate
4
  einops
5
+ diffusers
6
+ safetensors
7
  gradio
8
  Pillow
9
  sentencepiece