feat: add qwen vl narratives and cartoon generation
Browse files- app.py +169 -73
- requirements.txt +2 -1
app.py
CHANGED
|
@@ -3,12 +3,20 @@ import re
|
|
| 3 |
from typing import Any, Dict, List, Optional
|
| 4 |
|
| 5 |
import gradio as gr
|
| 6 |
-
import torch
|
| 7 |
import spaces
|
|
|
|
|
|
|
| 8 |
from PIL import Image, ImageDraw
|
| 9 |
-
from transformers import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
VL_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
|
|
|
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
def _load_vl_model():
|
|
@@ -29,6 +37,39 @@ def _load_vl_model():
|
|
| 29 |
VL_MODEL, VL_PROCESSOR = _load_vl_model()
|
| 30 |
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
def _extract_assistant_content(decoded: str) -> str:
|
| 33 |
if "<|im_start|>assistant" in decoded:
|
| 34 |
content = decoded.split("<|im_start|>assistant")[-1]
|
|
@@ -44,40 +85,34 @@ def _extract_json_block(text: str) -> Optional[str]:
|
|
| 44 |
return match.group(0)
|
| 45 |
|
| 46 |
|
| 47 |
-
def
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
return
|
|
|
|
| 52 |
|
| 53 |
-
|
| 54 |
-
|
|
|
|
|
|
|
| 55 |
|
| 56 |
times = item.get("times_per_day")
|
| 57 |
if isinstance(times, (int, float)):
|
| 58 |
times_str = str(int(times)) if float(times).is_integer() else str(times)
|
| 59 |
else:
|
| 60 |
-
times_str =
|
| 61 |
-
|
| 62 |
-
time_slots_raw = item.get("time_slots")
|
| 63 |
-
if isinstance(time_slots_raw, (list, tuple)):
|
| 64 |
-
time_slots = [str(t).strip() for t in time_slots_raw if str(t).strip()]
|
| 65 |
-
elif isinstance(time_slots_raw, str):
|
| 66 |
-
slots = [s.strip() for s in re.split(r"[,;]\s*", time_slots_raw) if s.strip()]
|
| 67 |
-
time_slots = slots
|
| 68 |
-
else:
|
| 69 |
-
time_slots = []
|
| 70 |
|
| 71 |
return {
|
| 72 |
-
"name": name,
|
| 73 |
-
"dose_per_intake":
|
| 74 |
"times_per_day": times_str,
|
| 75 |
-
"time_slots": time_slots,
|
| 76 |
-
"description":
|
| 77 |
-
"usage_example":
|
| 78 |
-
"dosage_example":
|
| 79 |
-
"side_effects":
|
| 80 |
-
"warnings":
|
| 81 |
}
|
| 82 |
|
| 83 |
|
|
@@ -87,7 +122,7 @@ def _parse_vl_response(text: str) -> Dict[str, Any]:
|
|
| 87 |
return {
|
| 88 |
"raw_text": "",
|
| 89 |
"medications": [],
|
| 90 |
-
"warnings": ["
|
| 91 |
}
|
| 92 |
try:
|
| 93 |
data = json.loads(json_block)
|
|
@@ -95,11 +130,9 @@ def _parse_vl_response(text: str) -> Dict[str, Any]:
|
|
| 95 |
return {
|
| 96 |
"raw_text": "",
|
| 97 |
"medications": [],
|
| 98 |
-
"warnings": ["
|
| 99 |
}
|
| 100 |
|
| 101 |
-
raw_text = str(data.get("raw_text", "")).strip()
|
| 102 |
-
|
| 103 |
meds_raw = data.get("medications") or []
|
| 104 |
medications: List[Dict[str, Any]] = []
|
| 105 |
if isinstance(meds_raw, list):
|
|
@@ -116,7 +149,7 @@ def _parse_vl_response(text: str) -> Dict[str, Any]:
|
|
| 116 |
warnings = []
|
| 117 |
|
| 118 |
return {
|
| 119 |
-
"raw_text": raw_text,
|
| 120 |
"medications": medications,
|
| 121 |
"warnings": warnings,
|
| 122 |
}
|
|
@@ -135,27 +168,26 @@ def analyze_image_with_qwen(image: Image.Image) -> Dict[str, Any]:
|
|
| 135 |
" {\n"
|
| 136 |
" \"name\": \"์ฝ ์ด๋ฆ\",\n"
|
| 137 |
" \"dose_per_intake\": \"1ํ ์ฉ๋ (์: 1์ , 5mL)\",\n"
|
| 138 |
-
" \"times_per_day\": \"ํ๋ฃจ ๋ณต์ฉ
|
| 139 |
" \"time_slots\": [\"๋ณต์ฉ ์๊ฐ๋\"],\n"
|
| 140 |
-
" \"description\": \"
|
| 141 |
-
" \"usage_example\": \"
|
| 142 |
-
" \"dosage_example\": \"๋ณต์ฉ ๋ฐฉ๋ฒ
|
| 143 |
-
" \"side_effects\": \"์ฃผ์
|
| 144 |
-
" \"warnings\": \"
|
| 145 |
" }\n"
|
| 146 |
" ],\n"
|
| 147 |
-
" \"warnings\": [\"
|
| 148 |
"}"
|
| 149 |
)
|
| 150 |
user_prompt = (
|
| 151 |
-
"์ JSON ์คํค๋ง๋ฅผ
|
| 152 |
-
"๋ชจ๋ ๊ฐ์ ํ๊ตญ์ด๋ก ์์ฑํ๊ณ , ์คํ์๋ ์ดํดํ ์ ์๋ ๋งํฌ๋ก ์ค๋ช
ํ์ธ์."
|
| 153 |
)
|
| 154 |
|
| 155 |
messages = [
|
| 156 |
{
|
| 157 |
"role": "system",
|
| 158 |
-
"content": "๋น์ ์ ์ฝ์ฌ
|
| 159 |
},
|
| 160 |
{
|
| 161 |
"role": "user",
|
|
@@ -169,11 +201,7 @@ def analyze_image_with_qwen(image: Image.Image) -> Dict[str, Any]:
|
|
| 169 |
]
|
| 170 |
|
| 171 |
chat_text = VL_PROCESSOR.apply_chat_template(messages, add_generation_prompt=True)
|
| 172 |
-
inputs = VL_PROCESSOR(
|
| 173 |
-
text=[chat_text],
|
| 174 |
-
images=[image],
|
| 175 |
-
return_tensors="pt",
|
| 176 |
-
).to(VL_MODEL.device)
|
| 177 |
|
| 178 |
output_ids = VL_MODEL.generate(
|
| 179 |
**inputs,
|
|
@@ -188,6 +216,85 @@ def analyze_image_with_qwen(image: Image.Image) -> Dict[str, Any]:
|
|
| 188 |
return _parse_vl_response(assistant_text)
|
| 189 |
|
| 190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
def render_card(primary: Dict[str, Any]) -> Image.Image:
|
| 192 |
width, height = 720, 400
|
| 193 |
canvas = Image.new("RGB", (width, height), "white")
|
|
@@ -231,28 +338,6 @@ def medications_to_csv(medications: List[Dict[str, Any]]) -> str:
|
|
| 231 |
return ",".join(row)
|
| 232 |
|
| 233 |
|
| 234 |
-
def build_markdown(medications: List[Dict[str, Any]]) -> str:
|
| 235 |
-
if not medications:
|
| 236 |
-
return "### ์ฝ ์ค๋ช
\n- ์ฝ ์ ๋ณด๋ฅผ ์ธ์ํ์ง ๋ชปํ์ต๋๋ค. ์ฝ์ฌ์๊ฒ ์ง์ ํ์ธํด ์ฃผ์ธ์."
|
| 237 |
-
|
| 238 |
-
lines: List[str] = ["### ์ฝ๊ฒ ์์๋ณด๋ ์ฝ ์ค๋ช
"]
|
| 239 |
-
for med in medications:
|
| 240 |
-
lines.append(f"- **{med.get('name') or '์ด๋ฆ ๋ฏธํ์ธ'}**")
|
| 241 |
-
if med.get("description"):
|
| 242 |
-
lines.append(f" - ํ๋ ์ผ: {med['description']}")
|
| 243 |
-
if med.get("usage_example"):
|
| 244 |
-
lines.append(f" - ๋ณต์ฉ ์์: {med['usage_example']}")
|
| 245 |
-
if med.get("dosage_example"):
|
| 246 |
-
lines.append(f" - ๋ณต์ฉ ๋ฐฉ๋ฒ ์์: {med['dosage_example']}")
|
| 247 |
-
if med.get("side_effects"):
|
| 248 |
-
lines.append(f" - ๋ถ์์ฉ/์ฃผ์: {med['side_effects']}")
|
| 249 |
-
if med.get("warnings"):
|
| 250 |
-
lines.append(f" - ์ถ๊ฐ ์ฃผ์: {med['warnings']}")
|
| 251 |
-
|
| 252 |
-
lines.append("\n> โ ๏ธ ์ค์ ๋ณต์ฝ์ ์์ฌยท์ฝ์ฌ์ ์ง์์ ๋ฐ๋์ ๋ฐ๋ฅด์ธ์.")
|
| 253 |
-
return "\n".join(lines)
|
| 254 |
-
|
| 255 |
-
|
| 256 |
def format_warnings(warnings: List[str]) -> str:
|
| 257 |
if not warnings:
|
| 258 |
return "โ
์ธ์๋ ์ ๋ณด๊ฐ ์ถฉ๋ถํด์. ๋ณต์ฝ ์๊ฐ๋ง ์ ์ง์ผ ์ฃผ์ธ์."
|
|
@@ -272,6 +357,7 @@ def run_pipeline(image: Optional[Image.Image]):
|
|
| 272 |
"์ด๋ฏธ์ง๋ฅผ ๋จผ์ ์
๋ก๋ํด ์ฃผ์ธ์.",
|
| 273 |
"๐ท ์ฝ ๋ดํฌ ์ฌ์ง์ ์ฌ๋ฆฌ๋ฉด ์ธ์์ด ์์๋ผ์.",
|
| 274 |
"",
|
|
|
|
| 275 |
)
|
| 276 |
|
| 277 |
result = analyze_image_with_qwen(image)
|
|
@@ -284,14 +370,23 @@ def run_pipeline(image: Optional[Image.Image]):
|
|
| 284 |
"time_slots": [],
|
| 285 |
}
|
| 286 |
|
|
|
|
|
|
|
| 287 |
card_img = render_card(primary)
|
| 288 |
csv_row = medications_to_csv(medications)
|
| 289 |
-
markdown =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
warnings_md = format_warnings(result.get("warnings", []))
|
| 291 |
raw_text = result.get("raw_text", "")
|
| 292 |
json_text = json.dumps(result, ensure_ascii=False, indent=2)
|
|
|
|
| 293 |
|
| 294 |
-
return json_text, card_img, csv_row, markdown, warnings_md, raw_text
|
| 295 |
|
| 296 |
|
| 297 |
CUSTOM_CSS = """
|
|
@@ -319,7 +414,7 @@ body {background: radial-gradient(circle at top left, #f5f0ff 0%, #fff7ec 60%, #
|
|
| 319 |
HERO_HTML = """
|
| 320 |
<div class="hero">
|
| 321 |
<h1>MedCard-KR ยท ์ฝ๋ดํฌ ํ ์ปท์ผ๋ก ์ดํดํ๋ ๋ณต์ฉ ์๋ด</h1>
|
| 322 |
-
<p>Qwen2.5-VL์ด
|
| 323 |
</div>
|
| 324 |
"""
|
| 325 |
|
|
@@ -336,6 +431,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
|
|
| 336 |
gr.Markdown("### 2. ๊ฒฐ๊ณผ๋ฅผ ํ์ธํ์ธ์")
|
| 337 |
explain_md = gr.Markdown("์ฌ๊ธฐ์ ์ฝ ์ค๋ช
์ด ํ์๋ฉ๋๋ค.", elem_classes=["output-card"])
|
| 338 |
raw_box = gr.Textbox(label="๋ชจ๋ธ์ด ์ฝ์ ์๋ฌธ ํ
์คํธ", lines=5, interactive=False)
|
|
|
|
| 339 |
card_out = gr.Image(type="pil", label="์ผ์ ์นด๋(๋ฏธ๋ฆฌ๋ณด๊ธฐ)")
|
| 340 |
csv_box = gr.Textbox(label="CSV(์ฝ๋ช
,1ํ์ฉ๋,1์ผํ์,์๊ฐ๋)", lines=2, elem_classes=["csv-box"])
|
| 341 |
with gr.Accordion("์ธ๋ถ JSON ๊ฒฐ๊ณผ", open=False, elem_classes=["accordion"]):
|
|
@@ -344,7 +440,7 @@ with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
|
|
| 344 |
btn.click(
|
| 345 |
run_pipeline,
|
| 346 |
inputs=img_in,
|
| 347 |
-
outputs=[json_out, card_out, csv_box, explain_md, warn_md, raw_box],
|
| 348 |
)
|
| 349 |
|
| 350 |
gr.Markdown(
|
|
|
|
| 3 |
from typing import Any, Dict, List, Optional
|
| 4 |
|
| 5 |
import gradio as gr
|
|
|
|
| 6 |
import spaces
|
| 7 |
+
import torch
|
| 8 |
+
from diffusers import AutoPipelineForText2Image
|
| 9 |
from PIL import Image, ImageDraw
|
| 10 |
+
from transformers import (
|
| 11 |
+
AutoModelForCausalLM,
|
| 12 |
+
AutoModelForVision2Seq,
|
| 13 |
+
AutoProcessor,
|
| 14 |
+
AutoTokenizer,
|
| 15 |
+
)
|
| 16 |
|
| 17 |
VL_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
|
| 18 |
+
TEXT_MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
|
| 19 |
+
IMAGE_MODEL_ID = "stabilityai/stable-diffusion-2-1"
|
| 20 |
|
| 21 |
|
| 22 |
def _load_vl_model():
|
|
|
|
| 37 |
VL_MODEL, VL_PROCESSOR = _load_vl_model()
|
| 38 |
|
| 39 |
|
| 40 |
+
def _load_text_model():
|
| 41 |
+
device_map = "auto" if torch.cuda.is_available() else None
|
| 42 |
+
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 43 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 44 |
+
TEXT_MODEL_ID,
|
| 45 |
+
device_map=device_map,
|
| 46 |
+
torch_dtype=dtype,
|
| 47 |
+
trust_remote_code=True,
|
| 48 |
+
)
|
| 49 |
+
if device_map is None:
|
| 50 |
+
model = model.to(torch.device("cpu"))
|
| 51 |
+
tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_ID, trust_remote_code=True)
|
| 52 |
+
return model, tokenizer
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
TEXT_MODEL, TEXT_TOKENIZER = _load_text_model()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _load_image_pipeline():
|
| 59 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 60 |
+
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
|
| 61 |
+
pipe = AutoPipelineForText2Image.from_pretrained(
|
| 62 |
+
IMAGE_MODEL_ID,
|
| 63 |
+
torch_dtype=dtype,
|
| 64 |
+
safety_checker=None,
|
| 65 |
+
)
|
| 66 |
+
pipe.to(device)
|
| 67 |
+
return pipe
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
IMAGE_PIPELINE = _load_image_pipeline()
|
| 71 |
+
|
| 72 |
+
|
| 73 |
def _extract_assistant_content(decoded: str) -> str:
|
| 74 |
if "<|im_start|>assistant" in decoded:
|
| 75 |
content = decoded.split("<|im_start|>assistant")[-1]
|
|
|
|
| 85 |
return match.group(0)
|
| 86 |
|
| 87 |
|
| 88 |
+
def _sanitize_list(value: Any) -> List[str]:
|
| 89 |
+
if isinstance(value, (list, tuple)):
|
| 90 |
+
return [str(v).strip() for v in value if str(v).strip()]
|
| 91 |
+
if isinstance(value, str):
|
| 92 |
+
return [v.strip() for v in re.split(r"[,;]", value) if v.strip()]
|
| 93 |
+
return []
|
| 94 |
|
| 95 |
+
|
| 96 |
+
def _sanitize_medication(item: Dict[str, Any]) -> Dict[str, Any]:
|
| 97 |
+
def _to_str(val: Any) -> str:
|
| 98 |
+
return "" if val is None else str(val).strip()
|
| 99 |
|
| 100 |
times = item.get("times_per_day")
|
| 101 |
if isinstance(times, (int, float)):
|
| 102 |
times_str = str(int(times)) if float(times).is_integer() else str(times)
|
| 103 |
else:
|
| 104 |
+
times_str = _to_str(times)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
|
| 106 |
return {
|
| 107 |
+
"name": _to_str(item.get("name")),
|
| 108 |
+
"dose_per_intake": _to_str(item.get("dose_per_intake")),
|
| 109 |
"times_per_day": times_str,
|
| 110 |
+
"time_slots": _sanitize_list(item.get("time_slots")),
|
| 111 |
+
"description": _to_str(item.get("description")),
|
| 112 |
+
"usage_example": _to_str(item.get("usage_example")),
|
| 113 |
+
"dosage_example": _to_str(item.get("dosage_example")),
|
| 114 |
+
"side_effects": _to_str(item.get("side_effects")),
|
| 115 |
+
"warnings": _to_str(item.get("warnings")),
|
| 116 |
}
|
| 117 |
|
| 118 |
|
|
|
|
| 122 |
return {
|
| 123 |
"raw_text": "",
|
| 124 |
"medications": [],
|
| 125 |
+
"warnings": ["๋ชจ๋ธ ์๋ต์์ JSON ํ์์ ์ฐพ์ง ๋ชปํ์ต๋๋ค."] + ([text.strip()] if text.strip() else []),
|
| 126 |
}
|
| 127 |
try:
|
| 128 |
data = json.loads(json_block)
|
|
|
|
| 130 |
return {
|
| 131 |
"raw_text": "",
|
| 132 |
"medications": [],
|
| 133 |
+
"warnings": ["๋ชจ๋ธ JSON ํ์ฑ ์คํจ", text.strip()],
|
| 134 |
}
|
| 135 |
|
|
|
|
|
|
|
| 136 |
meds_raw = data.get("medications") or []
|
| 137 |
medications: List[Dict[str, Any]] = []
|
| 138 |
if isinstance(meds_raw, list):
|
|
|
|
| 149 |
warnings = []
|
| 150 |
|
| 151 |
return {
|
| 152 |
+
"raw_text": str(data.get("raw_text", "")).strip(),
|
| 153 |
"medications": medications,
|
| 154 |
"warnings": warnings,
|
| 155 |
}
|
|
|
|
| 168 |
" {\n"
|
| 169 |
" \"name\": \"์ฝ ์ด๋ฆ\",\n"
|
| 170 |
" \"dose_per_intake\": \"1ํ ์ฉ๋ (์: 1์ , 5mL)\",\n"
|
| 171 |
+
" \"times_per_day\": \"ํ๋ฃจ ๋ณต์ฉ ํ์\",\n"
|
| 172 |
" \"time_slots\": [\"๋ณต์ฉ ์๊ฐ๋\"],\n"
|
| 173 |
+
" \"description\": \"์ฝ ์ค๋ช
\",\n"
|
| 174 |
+
" \"usage_example\": \"๋ณต์ฉ ์์\",\n"
|
| 175 |
+
" \"dosage_example\": \"๋ณต์ฉ ๋ฐฉ๋ฒ ์์\",\n"
|
| 176 |
+
" \"side_effects\": \"์ฃผ์ ๋ถ์์ฉ\",\n"
|
| 177 |
+
" \"warnings\": \"์ฃผ์ ๋ฌธ๊ตฌ\"\n"
|
| 178 |
" }\n"
|
| 179 |
" ],\n"
|
| 180 |
+
" \"warnings\": [\"์ ์ฒด ๊ฒฝ๊ณ \"]\n"
|
| 181 |
"}"
|
| 182 |
)
|
| 183 |
user_prompt = (
|
| 184 |
+
"์ JSON ์คํค๋ง๋ฅผ ๋ฐ๋์ ๋ฐ๋ฅด์ธ์. ๋ชจ๋ ๊ฐ์ ํ๊ตญ์ด๋ก ์์ฑํ๊ณ , ๋น ์ ๋ณด๋ ๋น ๋ฌธ์์ด๋ก ๋์ธ์."
|
|
|
|
| 185 |
)
|
| 186 |
|
| 187 |
messages = [
|
| 188 |
{
|
| 189 |
"role": "system",
|
| 190 |
+
"content": "๋น์ ์ ์ฝ์ฌ ์ ์๋์
๋๋ค. ์ ํํ๊ณ ์น์ ํ๊ฒ ์ ๋ณด๋ฅผ ์ ๋ฆฌํ์ธ์.",
|
| 191 |
},
|
| 192 |
{
|
| 193 |
"role": "user",
|
|
|
|
| 201 |
]
|
| 202 |
|
| 203 |
chat_text = VL_PROCESSOR.apply_chat_template(messages, add_generation_prompt=True)
|
| 204 |
+
inputs = VL_PROCESSOR(text=[chat_text], images=[image], return_tensors="pt").to(VL_MODEL.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 205 |
|
| 206 |
output_ids = VL_MODEL.generate(
|
| 207 |
**inputs,
|
|
|
|
| 216 |
return _parse_vl_response(assistant_text)
|
| 217 |
|
| 218 |
|
| 219 |
+
@spaces.GPU(enable_queue=True)
|
| 220 |
+
def generate_explanations(raw_text: str, medications: List[Dict[str, Any]]) -> Dict[str, str]:
|
| 221 |
+
med_summary_lines = []
|
| 222 |
+
for med in medications:
|
| 223 |
+
summary = f"- {med.get('name', '์ด๋ฆ ๋ฏธํ์ธ')} {med.get('dose_per_intake', '')}"
|
| 224 |
+
med_summary_lines.append(summary.strip())
|
| 225 |
+
med_summary = "\n".join(med_summary_lines)
|
| 226 |
+
|
| 227 |
+
system_prompt = "์ฝ์ฌ ์ ์๋์ฒ๋ผ ์ด๋ฅด์ ๊ณผ ์ด๋ฆฐ์ด์๊ฒ ๊ฐ๊ฐ ์ฝ๊ฒ ์ค๋ช
ํ์ธ์."
|
| 228 |
+
user_prompt = (
|
| 229 |
+
"๋ค์์ ์ฝ ๋ดํฌ์์ ์ฝ์ ์๋ฌธ๊ณผ ์ฝ ๋ชฉ๋ก์
๋๋ค. \n"
|
| 230 |
+
"JSON์ผ๋ก ๋ต๋ณํ์ธ์. ํ์์ {\"elderly\": {\"narrative\": ..., \"image_prompt\": ...}, \"child\": {\"narrative\": ..., \"image_prompt\": ...}} ์
๋๋ค.\n"
|
| 231 |
+
"narrative๋ ํ๊ตญ์ด, image_prompt๋ ์์ด๋ก ํ ์ปท ๋งํ ์คํ์ผ์ ๋ฌ์ฌํ์ธ์.\n"
|
| 232 |
+
f"์ฝ ๋ชฉ๋ก:\n{med_summary}\n\n์๋ฌธ:\n{raw_text}\n"
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
messages = [
|
| 236 |
+
{"role": "system", "content": system_prompt},
|
| 237 |
+
{"role": "user", "content": user_prompt},
|
| 238 |
+
]
|
| 239 |
+
|
| 240 |
+
input_ids = TEXT_TOKENIZER.apply_chat_template(
|
| 241 |
+
messages,
|
| 242 |
+
add_generation_prompt=True,
|
| 243 |
+
return_tensors="pt",
|
| 244 |
+
).to(TEXT_MODEL.device)
|
| 245 |
+
|
| 246 |
+
with torch.no_grad():
|
| 247 |
+
output_ids = TEXT_MODEL.generate(
|
| 248 |
+
input_ids,
|
| 249 |
+
max_new_tokens=512,
|
| 250 |
+
temperature=0.3,
|
| 251 |
+
top_p=0.8,
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
generated_ids = output_ids[0][input_ids.shape[1]:]
|
| 255 |
+
text = TEXT_TOKENIZER.decode(generated_ids, skip_special_tokens=True).strip()
|
| 256 |
+
|
| 257 |
+
json_block = _extract_json_block(text)
|
| 258 |
+
if not json_block:
|
| 259 |
+
return {
|
| 260 |
+
"elderly_narrative": "์ค๋ช
์ ์ค๋นํ์ง ๋ชปํ์ต๋๋ค. ์ฝ์ฌ์๊ฒ ์ง์ ๋ฌธ์ํ์ธ์.",
|
| 261 |
+
"child_narrative": "์ค๋ช
์ ์ค๋นํ์ง ๋ชปํ์ต๋๋ค. ์ฝ์ฌ์๊ฒ ์ง์ ๋ฌธ์ํ์ธ์.",
|
| 262 |
+
"image_prompt": "single panel cartoon pharmacist helping family, soft colors",
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
try:
|
| 266 |
+
data = json.loads(json_block)
|
| 267 |
+
except json.JSONDecodeError:
|
| 268 |
+
return {
|
| 269 |
+
"elderly_narrative": "์ค๋ช
์ ์ค๋นํ์ง ๋ชปํ์ต๋๋ค. ์ฝ์ฌ์๊ฒ ์ง์ ๋ฌธ์ํ์ธ์.",
|
| 270 |
+
"child_narrative": "์ค๋ช
์ ์ค๋นํ์ง ๋ชปํ์ต๋๋ค. ์ฝ์ฌ์๊ฒ ์ง์ ๋ฌธ์ํ์ธ์.",
|
| 271 |
+
"image_prompt": "single panel cartoon pharmacist helping family, soft colors",
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
elderly = data.get("elderly", {})
|
| 275 |
+
child = data.get("child", {})
|
| 276 |
+
|
| 277 |
+
return {
|
| 278 |
+
"elderly_narrative": str(elderly.get("narrative", "")).strip(),
|
| 279 |
+
"child_narrative": str(child.get("narrative", "")).strip(),
|
| 280 |
+
"image_prompt": str(child.get("image_prompt") or elderly.get("image_prompt") or "single panel cartoon pharmacist helping family, pastel colors").strip(),
|
| 281 |
+
}
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
@spaces.GPU(enable_queue=True)
|
| 285 |
+
def generate_cartoon_image(prompt: str) -> Image.Image:
|
| 286 |
+
if not prompt:
|
| 287 |
+
prompt = "single panel wholesome cartoon, pharmacist gently explaining medicine to family, warm pastel colors"
|
| 288 |
+
negative_prompt = "text, watermark, logo, blurry"
|
| 289 |
+
image = IMAGE_PIPELINE(
|
| 290 |
+
prompt=prompt,
|
| 291 |
+
negative_prompt=negative_prompt,
|
| 292 |
+
num_inference_steps=30,
|
| 293 |
+
guidance_scale=7.5,
|
| 294 |
+
).images[0]
|
| 295 |
+
return image
|
| 296 |
+
|
| 297 |
+
|
| 298 |
def render_card(primary: Dict[str, Any]) -> Image.Image:
|
| 299 |
width, height = 720, 400
|
| 300 |
canvas = Image.new("RGB", (width, height), "white")
|
|
|
|
| 338 |
return ",".join(row)
|
| 339 |
|
| 340 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
def format_warnings(warnings: List[str]) -> str:
|
| 342 |
if not warnings:
|
| 343 |
return "โ
์ธ์๋ ์ ๋ณด๊ฐ ์ถฉ๋ถํด์. ๋ณต์ฝ ์๊ฐ๋ง ์ ์ง์ผ ์ฃผ์ธ์."
|
|
|
|
| 357 |
"์ด๋ฏธ์ง๋ฅผ ๋จผ์ ์
๋ก๋ํด ์ฃผ์ธ์.",
|
| 358 |
"๐ท ์ฝ ๋ดํฌ ์ฌ์ง์ ์ฌ๋ฆฌ๋ฉด ์ธ์์ด ์์๋ผ์.",
|
| 359 |
"",
|
| 360 |
+
None,
|
| 361 |
)
|
| 362 |
|
| 363 |
result = analyze_image_with_qwen(image)
|
|
|
|
| 370 |
"time_slots": [],
|
| 371 |
}
|
| 372 |
|
| 373 |
+
narratives = generate_explanations(result.get("raw_text", ""), medications)
|
| 374 |
+
|
| 375 |
card_img = render_card(primary)
|
| 376 |
csv_row = medications_to_csv(medications)
|
| 377 |
+
markdown = (
|
| 378 |
+
"## ์ด๋ฅด์ ์ ์ํ ์ค๋ช
\n"
|
| 379 |
+
+ (narratives.get("elderly_narrative") or "- ์ค๋ช
์ ์ค๋นํ์ง ๋ชปํ์ต๋๋ค.")
|
| 380 |
+
+ "\n\n## ์ด๋ฆฐ์ด๋ฅผ ์ํ ์ค๋ช
\n"
|
| 381 |
+
+ (narratives.get("child_narrative") or "- ์ค๋ช
์ ์ค๋นํ์ง ๋ชปํ์ต๋๋ค.")
|
| 382 |
+
+ "\n\n> ํญ์ ์๋ฃ์ง์ ์๋ด๋ฅผ ์ฐ์ ํ์ธ์."
|
| 383 |
+
)
|
| 384 |
warnings_md = format_warnings(result.get("warnings", []))
|
| 385 |
raw_text = result.get("raw_text", "")
|
| 386 |
json_text = json.dumps(result, ensure_ascii=False, indent=2)
|
| 387 |
+
cartoon_image = generate_cartoon_image(narratives.get("image_prompt"))
|
| 388 |
|
| 389 |
+
return json_text, card_img, csv_row, markdown, warnings_md, raw_text, cartoon_image
|
| 390 |
|
| 391 |
|
| 392 |
CUSTOM_CSS = """
|
|
|
|
| 414 |
HERO_HTML = """
|
| 415 |
<div class="hero">
|
| 416 |
<h1>MedCard-KR ยท ์ฝ๋ดํฌ ํ ์ปท์ผ๋ก ์ดํดํ๋ ๋ณต์ฉ ์๋ด</h1>
|
| 417 |
+
<p>Qwen2.5-VL์ด ์ฝ ๋ดํฌ๋ฅผ ์ง์ ์ฝ๊ณ , ์ฝ์ฌ์ฒ๋ผ ์ฝ๊ฒ ์ค๋ช
๊ณผ ํ ์ปท ๋งํ๋ฅผ ํจ๊ป ์ ๊ณตํฉ๋๋ค.</p>
|
| 418 |
</div>
|
| 419 |
"""
|
| 420 |
|
|
|
|
| 431 |
gr.Markdown("### 2. ๊ฒฐ๊ณผ๋ฅผ ํ์ธํ์ธ์")
|
| 432 |
explain_md = gr.Markdown("์ฌ๊ธฐ์ ์ฝ ์ค๋ช
์ด ํ์๋ฉ๋๋ค.", elem_classes=["output-card"])
|
| 433 |
raw_box = gr.Textbox(label="๋ชจ๋ธ์ด ์ฝ์ ์๋ฌธ ํ
์คํธ", lines=5, interactive=False)
|
| 434 |
+
cartoon_img = gr.Image(type="pil", label="ํ ์ปท ๋งํ")
|
| 435 |
card_out = gr.Image(type="pil", label="์ผ์ ์นด๋(๋ฏธ๋ฆฌ๋ณด๊ธฐ)")
|
| 436 |
csv_box = gr.Textbox(label="CSV(์ฝ๋ช
,1ํ์ฉ๋,1์ผํ์,์๊ฐ๋)", lines=2, elem_classes=["csv-box"])
|
| 437 |
with gr.Accordion("์ธ๋ถ JSON ๊ฒฐ๊ณผ", open=False, elem_classes=["accordion"]):
|
|
|
|
| 440 |
btn.click(
|
| 441 |
run_pipeline,
|
| 442 |
inputs=img_in,
|
| 443 |
+
outputs=[json_out, card_out, csv_box, explain_md, warn_md, raw_box, cartoon_img],
|
| 444 |
)
|
| 445 |
|
| 446 |
gr.Markdown(
|
requirements.txt
CHANGED
|
@@ -2,7 +2,8 @@ transformers
|
|
| 2 |
torch
|
| 3 |
accelerate
|
| 4 |
einops
|
|
|
|
|
|
|
| 5 |
gradio
|
| 6 |
Pillow
|
| 7 |
sentencepiece
|
| 8 |
-
torchvision
|
|
|
|
| 2 |
torch
|
| 3 |
accelerate
|
| 4 |
einops
|
| 5 |
+
diffusers
|
| 6 |
+
safetensors
|
| 7 |
gradio
|
| 8 |
Pillow
|
| 9 |
sentencepiece
|
|
|