# MedCard / app.py
# LLDDWW's picture
# feat: add qwen vl narratives and cartoon generation
# e53f54d
# raw
# history blame
# 17 kB
import json
import re
from typing import Any, Dict, List, Optional
import gradio as gr
import spaces
import torch
from diffusers import AutoPipelineForText2Image
from PIL import Image, ImageDraw
from transformers import (
AutoModelForCausalLM,
AutoModelForVision2Seq,
AutoProcessor,
AutoTokenizer,
)
# Hugging Face model identifiers, one per pipeline stage.
# Vision-language model loaded by _load_vl_model.
VL_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"
# Causal language model loaded by _load_text_model.
TEXT_MODEL_ID = "Qwen/Qwen2.5-1.5B-Instruct"
# Text-to-image model loaded by _load_image_pipeline.
IMAGE_MODEL_ID = "stabilityai/stable-diffusion-2-1"
def _load_vl_model():
    """Load the vision-language model and its processor.

    With CUDA available the weights are loaded in float16 using
    ``device_map="auto"``; otherwise they are loaded in float32 and moved to
    the CPU explicitly.

    Returns:
        A ``(model, processor)`` tuple for ``VL_MODEL_ID``.
    """
    if torch.cuda.is_available():
        placement, precision = "auto", torch.float16
    else:
        placement, precision = None, torch.float32

    # NOTE(review): trust_remote_code=True allows code shipped with the
    # checkpoint repository to run at load time -- confirm it is required.
    vl_model = AutoModelForVision2Seq.from_pretrained(
        VL_MODEL_ID,
        device_map=placement,
        torch_dtype=precision,
        trust_remote_code=True,
    )
    if placement is None:
        vl_model = vl_model.to(torch.device("cpu"))

    vl_processor = AutoProcessor.from_pretrained(VL_MODEL_ID, trust_remote_code=True)
    return vl_model, vl_processor


# Loaded once at import time and shared by every request.
VL_MODEL, VL_PROCESSOR = _load_vl_model()
def _load_text_model():
    """Load the text-only language model and its tokenizer.

    With CUDA available the weights are loaded in float16 using
    ``device_map="auto"``; otherwise they are loaded in float32 and moved to
    the CPU explicitly.

    Returns:
        A ``(model, tokenizer)`` tuple for ``TEXT_MODEL_ID``.
    """
    if torch.cuda.is_available():
        placement, precision = "auto", torch.float16
    else:
        placement, precision = None, torch.float32

    # NOTE(review): trust_remote_code=True allows code shipped with the
    # checkpoint repository to run at load time -- confirm it is required.
    text_model = AutoModelForCausalLM.from_pretrained(
        TEXT_MODEL_ID,
        device_map=placement,
        torch_dtype=precision,
        trust_remote_code=True,
    )
    if placement is None:
        text_model = text_model.to(torch.device("cpu"))

    text_tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_ID, trust_remote_code=True)
    return text_model, text_tokenizer


# Loaded once at import time and shared by every request.
TEXT_MODEL, TEXT_TOKENIZER = _load_text_model()
def _load_image_pipeline():
    """Load the text-to-image pipeline and move it to the best device.

    CUDA with float16 is used when available, otherwise CPU with float32.

    Returns:
        The pipeline built from ``IMAGE_MODEL_ID``.
    """
    if torch.cuda.is_available():
        target, precision = torch.device("cuda"), torch.float16
    else:
        target, precision = torch.device("cpu"), torch.float32

    # NOTE(review): safety_checker=None is passed explicitly -- presumably
    # this disables the pipeline's output filter; confirm that is intended.
    pipeline = AutoPipelineForText2Image.from_pretrained(
        IMAGE_MODEL_ID,
        torch_dtype=precision,
        safety_checker=None,
    )
    pipeline.to(target)
    return pipeline


# Loaded once at import time and shared by every request.
IMAGE_PIPELINE = _load_image_pipeline()
def _extract_assistant_content(decoded: str) -> str:
if "<|im_start|>assistant" in decoded:
content = decoded.split("<|im_start|>assistant")[-1]
content = content.replace("<|im_end|>", "").strip()
return content
return decoded.strip()
def _extract_json_block(text: str) -> Optional[str]:
match = re.search(r"\{.*\}", text, re.DOTALL)
if not match:
return None
return match.group(0)
def _sanitize_list(value: Any) -> List[str]:
if isinstance(value, (list, tuple)):
return [str(v).strip() for v in value if str(v).strip()]
if isinstance(value, str):
return [v.strip() for v in re.split(r"[,;]", value) if v.strip()]
return []
def _sanitize_medication(item: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize one medication entry into a dict of clean string fields.

    Every field is converted to a stripped string (``None`` becomes ``""``),
    except ``time_slots`` which goes through ``_sanitize_list``. A numeric
    ``times_per_day`` is rendered without a trailing ``.0`` when it is a whole
    number.
    """

    def _text(key: str) -> str:
        raw = item.get(key)
        return str(raw).strip() if raw is not None else ""

    frequency = item.get("times_per_day")
    if not isinstance(frequency, (int, float)):
        frequency_text = _text("times_per_day")
    elif float(frequency).is_integer():
        frequency_text = str(int(frequency))
    else:
        frequency_text = str(frequency)

    # Key order is kept stable because the result is later dumped as JSON.
    cleaned: Dict[str, Any] = {
        "name": _text("name"),
        "dose_per_intake": _text("dose_per_intake"),
        "times_per_day": frequency_text,
        "time_slots": _sanitize_list(item.get("time_slots")),
    }
    for key in ("description", "usage_example", "dosage_example", "side_effects", "warnings"):
        cleaned[key] = _text(key)
    return cleaned
def _parse_vl_response(text: str) -> Dict[str, Any]:
    """Convert the vision-language model's reply into a normalized result.

    Args:
        text: The assistant's reply, expected to contain one JSON object with
            ``raw_text``, ``medications`` and ``warnings`` keys.

    Returns:
        A dict with ``raw_text`` (str), ``medications`` (list of sanitized
        medication dicts) and ``warnings`` (list of str). When no JSON is
        found or it cannot be parsed, ``raw_text`` is empty, ``medications``
        is empty, and ``warnings`` carries an error message followed by the
        model's original text.
    """
    stripped_text = text.strip()

    json_block = _extract_json_block(text)
    if not json_block:
        return {
            "raw_text": "",
            "medications": [],
            "warnings": ["๋ชจ๋ธ ์‘๋‹ต์—์„œ JSON ํ˜•์‹์„ ์ฐพ์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค."] + ([stripped_text] if stripped_text else []),
        }

    try:
        data = json.loads(json_block)
    except json.JSONDecodeError:
        return {
            "raw_text": "",
            "medications": [],
            "warnings": ["๋ชจ๋ธ JSON ํŒŒ์‹ฑ ์‹คํŒจ", stripped_text],
        }

    # Only dict entries of a list-valued "medications" are accepted; anything
    # else the model may have produced is dropped.
    meds_raw = data.get("medications") or []
    medications: List[Dict[str, Any]] = []
    if isinstance(meds_raw, list):
        medications = [_sanitize_medication(item) for item in meds_raw if isinstance(item, dict)]

    warnings_raw = data.get("warnings")
    if isinstance(warnings_raw, list):
        # JSON null entries are skipped instead of becoming the text "None".
        warnings = [str(w).strip() for w in warnings_raw if w is not None and str(w).strip()]
    elif warnings_raw:
        warnings = [str(warnings_raw).strip()]
    else:
        warnings = []

    # A JSON null for raw_text must read as "no text", not the string "None".
    raw_value = data.get("raw_text")
    return {
        "raw_text": "" if raw_value is None else str(raw_value).strip(),
        "medications": medications,
        "warnings": warnings,
    }
@spaces.GPU(enable_queue=True)
def analyze_image_with_qwen(image: Image.Image) -> Dict[str, Any]:
    """Run the vision-language model on *image* and parse its JSON reply.

    The user turn contains three text parts (instructions, a JSON schema with
    ``raw_text`` / ``medications`` / ``warnings`` keys, and a follow-up
    instruction) followed by the image. The full decoded output is reduced to
    the assistant's part and handed to ``_parse_vl_response``.

    Returns:
        The normalized dict produced by ``_parse_vl_response``.
    """
    system_turn = {
        "role": "system",
        "content": "๋‹น์‹ ์€ ์•ฝ์‚ฌ ์„ ์ƒ๋‹˜์ž…๋‹ˆ๋‹ค. ์ •ํ™•ํ•˜๊ณ  ์นœ์ ˆํ•˜๊ฒŒ ์ •๋ณด๋ฅผ ์ •๋ฆฌํ•˜์„ธ์š”.",
    }

    task_text = (
        "์‚ฌ์ง„ ์† ์•ฝ๋ด‰ํˆฌ/์ฒ˜๋ฐฉ์ „์„ ์ฝ๊ณ  ์•„๋ž˜ JSON ํ˜•์‹์œผ๋กœ๋งŒ ๋‹ต๋ณ€ํ•˜์„ธ์š”. "
        "ํ…์ŠคํŠธ ์™ธ์˜ ์„ค๋ช…์ด๋‚˜ ์ถ”๊ฐ€ ๋ฌธ์žฅ์€ ์ ˆ๋Œ€ ๋„ฃ์ง€ ๋งˆ์„ธ์š”."
    )
    schema_text = (
        "{\n"
        " \"raw_text\": \"OCR๋กœ ์ฝ์€ ์ „์ฒด ๋ฌธ์žฅ\",\n"
        " \"medications\": [\n"
        " {\n"
        " \"name\": \"์•ฝ ์ด๋ฆ„\",\n"
        " \"dose_per_intake\": \"1ํšŒ ์šฉ๋Ÿ‰ (์˜ˆ: 1์ •, 5mL)\",\n"
        " \"times_per_day\": \"ํ•˜๋ฃจ ๋ณต์šฉ ํšŸ์ˆ˜\",\n"
        " \"time_slots\": [\"๋ณต์šฉ ์‹œ๊ฐ„๋Œ€\"],\n"
        " \"description\": \"์•ฝ ์„ค๋ช…\",\n"
        " \"usage_example\": \"๋ณต์šฉ ์˜ˆ์‹œ\",\n"
        " \"dosage_example\": \"๋ณต์šฉ ๋ฐฉ๋ฒ• ์˜ˆ์‹œ\",\n"
        " \"side_effects\": \"์ฃผ์š” ๋ถ€์ž‘์šฉ\",\n"
        " \"warnings\": \"์ฃผ์˜ ๋ฌธ๊ตฌ\"\n"
        " }\n"
        " ],\n"
        " \"warnings\": [\"์ „์ฒด ๊ฒฝ๊ณ \"]\n"
        "}"
    )
    closing_text = (
        "์œ„ JSON ์Šคํ‚ค๋งˆ๋ฅผ ๋ฐ˜๋“œ์‹œ ๋”ฐ๋ฅด์„ธ์š”. ๋ชจ๋“  ๊ฐ’์€ ํ•œ๊ตญ์–ด๋กœ ์ž‘์„ฑํ•˜๊ณ , ๋นˆ ์ •๋ณด๋Š” ๋นˆ ๋ฌธ์ž์—ด๋กœ ๋‘์„ธ์š”."
    )

    # Three text parts first, then the image placeholder.
    user_parts = [{"type": "text", "text": part} for part in (task_text, schema_text, closing_text)]
    user_parts.append({"type": "image"})
    conversation = [system_turn, {"role": "user", "content": user_parts}]

    prompt = VL_PROCESSOR.apply_chat_template(conversation, add_generation_prompt=True)
    model_inputs = VL_PROCESSOR(text=[prompt], images=[image], return_tensors="pt")
    model_inputs = model_inputs.to(VL_MODEL.device)

    # NOTE(review): temperature/top_p are passed together with
    # do_sample=False -- presumably they have no effect; confirm and drop.
    generated = VL_MODEL.generate(
        **model_inputs,
        max_new_tokens=1024,
        temperature=0.1,
        top_p=0.9,
        do_sample=False,
    )

    # Special tokens are kept so the assistant marker can be located.
    transcript = VL_PROCESSOR.batch_decode(generated, skip_special_tokens=False)[0]
    return _parse_vl_response(_extract_assistant_content(transcript))
@spaces.GPU(enable_queue=True)
def generate_explanations(raw_text: str, medications: List[Dict[str, Any]]) -> Dict[str, str]:
    """Ask the text model for two narratives and a cartoon prompt.

    The model is prompted to reply with a JSON object shaped like
    ``{"elderly": {"narrative", "image_prompt"}, "child": {...}}``.

    Args:
        raw_text: Text read from the photo, embedded verbatim in the prompt.
        medications: Medication dicts; only ``name`` and ``dose_per_intake``
            are summarized into the prompt.

    Returns:
        A dict with ``elderly_narrative``, ``child_narrative`` and
        ``image_prompt``. A fixed fallback is returned when the reply holds no
        parseable JSON object. Sections that are missing or not JSON objects
        are treated as empty instead of raising.
    """
    fallback_narrative = "์„ค๋ช…์„ ์ค€๋น„ํ•˜์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค. ์•ฝ์‚ฌ์—๊ฒŒ ์ง์ ‘ ๋ฌธ์˜ํ•˜์„ธ์š”."
    fallback = {
        "elderly_narrative": fallback_narrative,
        "child_narrative": fallback_narrative,
        "image_prompt": "single panel cartoon pharmacist helping family, soft colors",
    }

    med_summary = "\n".join(
        f"- {med.get('name', '์ด๋ฆ„ ๋ฏธํ™•์ธ')} {med.get('dose_per_intake', '')}".strip()
        for med in medications
    )
    system_prompt = "์•ฝ์‚ฌ ์„ ์ƒ๋‹˜์ฒ˜๋Ÿผ ์–ด๋ฅด์‹ ๊ณผ ์–ด๋ฆฐ์ด์—๊ฒŒ ๊ฐ๊ฐ ์‰ฝ๊ฒŒ ์„ค๋ช…ํ•˜์„ธ์š”."
    user_prompt = (
        "๋‹ค์Œ์€ ์•ฝ ๋ด‰ํˆฌ์—์„œ ์ฝ์€ ์›๋ฌธ๊ณผ ์•ฝ ๋ชฉ๋ก์ž…๋‹ˆ๋‹ค. \n"
        "JSON์œผ๋กœ ๋‹ต๋ณ€ํ•˜์„ธ์š”. ํ˜•์‹์€ {\"elderly\": {\"narrative\": ..., \"image_prompt\": ...}, \"child\": {\"narrative\": ..., \"image_prompt\": ...}} ์ž…๋‹ˆ๋‹ค.\n"
        "narrative๋Š” ํ•œ๊ตญ์–ด, image_prompt๋Š” ์˜์–ด๋กœ ํ•œ ์ปท ๋งŒํ™” ์Šคํƒ€์ผ์„ ๋ฌ˜์‚ฌํ•˜์„ธ์š”.\n"
        f"์•ฝ ๋ชฉ๋ก:\n{med_summary}\n\n์›๋ฌธ:\n{raw_text}\n"
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    input_ids = TEXT_TOKENIZER.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(TEXT_MODEL.device)
    with torch.no_grad():
        output_ids = TEXT_MODEL.generate(
            input_ids,
            max_new_tokens=512,
            temperature=0.3,
            top_p=0.8,
        )

    # Drop the echoed prompt so only newly generated tokens are decoded.
    generated_ids = output_ids[0][input_ids.shape[1]:]
    text = TEXT_TOKENIZER.decode(generated_ids, skip_special_tokens=True).strip()

    json_block = _extract_json_block(text)
    if not json_block:
        return fallback
    try:
        data = json.loads(json_block)
    except json.JSONDecodeError:
        return fallback
    if not isinstance(data, dict):
        return fallback

    def _section(key: str) -> Dict[str, Any]:
        # The model may return a string, list or null here; only a JSON
        # object can be queried for "narrative" / "image_prompt".
        value = data.get(key)
        return value if isinstance(value, dict) else {}

    def _clean(value: Any) -> str:
        # JSON null must not surface as the literal text "None".
        return "" if value is None else str(value).strip()

    elderly = _section("elderly")
    child = _section("child")
    return {
        "elderly_narrative": _clean(elderly.get("narrative")),
        "child_narrative": _clean(child.get("narrative")),
        "image_prompt": str(child.get("image_prompt") or elderly.get("image_prompt") or "single panel cartoon pharmacist helping family, pastel colors").strip(),
    }
@spaces.GPU(enable_queue=True)
def generate_cartoon_image(prompt: str) -> Image.Image:
    """Render one image for *prompt* with the text-to-image pipeline.

    An empty or missing prompt is replaced by a built-in default. The first
    image of the pipeline result is returned.
    """
    default_prompt = "single panel wholesome cartoon, pharmacist gently explaining medicine to family, warm pastel colors"
    result = IMAGE_PIPELINE(
        prompt=prompt if prompt else default_prompt,
        negative_prompt="text, watermark, logo, blurry",
        num_inference_steps=30,
        guidance_scale=7.5,
    )
    return result.images[0]
def render_card(primary: Dict[str, Any]) -> Image.Image:
    """Draw a 720x400 schedule card for one medication.

    The card has a tinted header band, four label/value rows (name, dose per
    intake, times per day, time slots) and a footer note. Missing or empty
    values are shown as ``-``.
    """
    card_w, card_h = 720, 400
    card = Image.new("RGB", (card_w, card_h), "white")
    pen = ImageDraw.Draw(card)

    # NOTE(review): no font is passed to draw.text, so Pillow's default font
    # is used -- confirm it can render the non-Latin labels below.
    pen.rectangle((0, 0, card_w, 60), fill=(230, 240, 255))
    pen.text((24, 18), "์˜ค๋Š˜ ๋ณต์šฉ ์ผ์ •", fill=(0, 0, 0))

    time_slots = primary.get("time_slots") or []
    rows = [
        ("์•ฝ ์ด๋ฆ„", primary.get("name")),
        ("1ํšŒ ์šฉ๋Ÿ‰", primary.get("dose_per_intake")),
        ("1์ผ ํšŸ์ˆ˜", primary.get("times_per_day")),
        ("์‹œ๊ฐ„๋Œ€", ", ".join(time_slots) if time_slots else None),
    ]

    row_top = 90
    for label, value in rows:
        shown = value if value else "-"
        pen.text((24, row_top), label, fill=(60, 60, 60))
        pen.text((200, row_top), f": {shown}", fill=(0, 0, 0))
        row_top += 34

    pen.text(
        (24, card_h - 60),
        "โ€ป ์˜๋ฃŒ์ง„ ์ฒ˜๋ฐฉ์ด ์šฐ์„ ์ด๋ฉฐ, ๋ณธ ์•ฑ์€ ์•ˆ๋‚ด์šฉ์ž…๋‹ˆ๋‹ค.",
        fill=(120, 120, 120),
    )
    return card
def medications_to_csv(medications: List[Dict[str, Any]]) -> str:
    """Serialize the first medication as a single CSV row.

    Columns, in order: name, dose per intake, times per day, and the time
    slots joined with ``;``. Fields are written with the ``csv`` module, so a
    value containing a comma, quote or line break is quoted instead of
    silently shifting the columns. ``None`` becomes an empty cell and other
    non-string values are converted with ``str``.

    Args:
        medications: Medication dicts; only the first entry is used.

    Returns:
        One CSV line without a trailing newline, or ``""`` when
        *medications* is empty.
    """
    import csv
    import io

    if not medications:
        return ""

    def _cell(value: Any) -> str:
        return "" if value is None else str(value)

    first = medications[0]
    row = [
        _cell(first.get("name")),
        _cell(first.get("dose_per_intake")),
        _cell(first.get("times_per_day")),
        ";".join(_cell(slot) for slot in (first.get("time_slots") or [])),
    ]

    buffer = io.StringIO()
    csv.writer(buffer).writerow(row)
    # The writer appends its line terminator; callers expect a bare row.
    return buffer.getvalue().rstrip("\r\n")
def format_warnings(warnings: List[str]) -> str:
    """Render *warnings* as a Markdown section.

    With no warnings a single confirmation line is returned. Otherwise the
    output is a heading, one bullet per warning, a blank line and a closing
    block quote.
    """
    if not warnings:
        return "โœ… ์ธ์‹๋œ ์ •๋ณด๊ฐ€ ์ถฉ๋ถ„ํ•ด์š”. ๋ณต์•ฝ ์‹œ๊ฐ„๋งŒ ์ž˜ ์ง€์ผœ ์ฃผ์„ธ์š”."

    bullets = [f"- {warn}" for warn in warnings]
    return "\n".join(
        ["### ํ™•์ธํ•ด ์ฃผ์„ธ์š”", *bullets, "\n> ์˜๋ฃŒ์ง„์˜ ์ง€์‹œ๊ฐ€ ๊ฐ€์žฅ ์ •ํ™•ํ•ฉ๋‹ˆ๋‹ค."]
    )
def run_pipeline(image: Optional[Image.Image]):
    """Run the whole flow for one uploaded image.

    Steps: analyze the image, generate the narratives, draw the schedule card
    for the first medication, build the CSV row and warning text, then
    generate the cartoon image.

    Returns:
        A 7-tuple ``(json_text, card_image, csv_row, explanation_markdown,
        warnings_markdown, raw_text, cartoon_image)``. When *image* is
        ``None`` the tuple holds placeholder messages and ``None`` for the
        image and CSV slots.
    """
    if image is None:
        return (
            "์ด๋ฏธ์ง€๋ฅผ ์—…๋กœ๋“œํ•˜์„ธ์š”.",
            None,
            None,
            "์ด๋ฏธ์ง€๋ฅผ ๋จผ์ € ์—…๋กœ๋“œํ•ด ์ฃผ์„ธ์š”.",
            "๐Ÿ“ท ์•ฝ ๋ด‰ํˆฌ ์‚ฌ์ง„์„ ์˜ฌ๋ฆฌ๋ฉด ์ธ์‹์ด ์‹œ์ž‘๋ผ์š”.",
            "",
            None,
        )

    analysis = analyze_image_with_qwen(image)
    med_list = analysis.get("medications") or []

    # The card shows only the first medication; use blanks when none was found.
    if med_list:
        first_med = med_list[0]
    else:
        first_med = {
            "name": "",
            "dose_per_intake": "",
            "times_per_day": "",
            "time_slots": [],
        }

    story = generate_explanations(analysis.get("raw_text", ""), med_list)
    schedule_card = render_card(first_med)
    csv_line = medications_to_csv(med_list)

    missing_note = "- ์„ค๋ช…์„ ์ค€๋น„ํ•˜์ง€ ๋ชปํ–ˆ์Šต๋‹ˆ๋‹ค."
    explanation_md = "".join(
        [
            "## ์–ด๋ฅด์‹ ์„ ์œ„ํ•œ ์„ค๋ช…\n",
            story.get("elderly_narrative") or missing_note,
            "\n\n## ์–ด๋ฆฐ์ด๋ฅผ ์œ„ํ•œ ์„ค๋ช…\n",
            story.get("child_narrative") or missing_note,
            "\n\n> ํ•ญ์ƒ ์˜๋ฃŒ์ง„์˜ ์•ˆ๋‚ด๋ฅผ ์šฐ์„ ํ•˜์„ธ์š”.",
        ]
    )

    warning_md = format_warnings(analysis.get("warnings", []))
    ocr_text = analysis.get("raw_text", "")
    analysis_json = json.dumps(analysis, ensure_ascii=False, indent=2)
    cartoon = generate_cartoon_image(story.get("image_prompt"))

    return (
        analysis_json,
        schedule_card,
        csv_line,
        explanation_md,
        warning_md,
        ocr_text,
        cartoon,
    )
# Stylesheet passed to gr.Blocks(css=...). The class names used here
# (glass-panel, primary-btn, output-card, notice, csv-box, accordion) are the
# ones attached to components through elem_classes in the layout.
CUSTOM_CSS = """
body {background: radial-gradient(circle at top left, #f5f0ff 0%, #fff7ec 60%, #ffffff 100%);}
.gradio-container {max-width: 1180px !important; margin: auto; font-family: 'Noto Sans KR', sans-serif;}
.hero {
background: linear-gradient(120deg, rgba(123, 97, 255, 0.12), rgba(255, 207, 117, 0.18));
border-radius: 28px;
padding: 36px 44px;
box-shadow: 0 20px 40px rgba(66, 46, 138, 0.08);
margin-bottom: 32px;
}
.hero h1 {font-size: 2.4rem; font-weight: 700; color: #1f1c3b; margin-bottom: 12px;}
.hero p {color: #514c7b; font-size: 1.05rem; line-height: 1.6; max-width: 640px;}
.glass-panel {background: rgba(255, 255, 255, 0.72); backdrop-filter: blur(18px); border-radius: 26px; padding: 28px; box-shadow: 0 12px 32px rgba(80, 60, 160, 0.12);}
.primary-btn button {background: linear-gradient(120deg, #7c62ff, #ffa74d); border: none; color: white; font-weight: 600; border-radius: 999px; padding: 12px 22px; box-shadow: 0 12px 24px rgba(124, 98, 255, 0.25);}
.primary-btn button:hover {opacity: 0.95; transform: translateY(-1px);}
.output-card {background: rgba(255, 255, 255, 0.88); border-radius: 22px; padding: 24px; box-shadow: inset 0 0 0 1px rgba(124, 98, 255, 0.08), 0 14px 30px rgba(49, 32, 114, 0.12);}
.notice {background: rgba(255, 247, 226, 0.9); border-radius: 18px; padding: 18px; color: #7a4b00; box-shadow: inset 0 0 0 1px rgba(255, 193, 96, 0.3);}
.csv-box textarea {font-family: 'JetBrains Mono', monospace;}
.gr-image {border-radius: 20px !important; box-shadow: 0 10px 20px rgba(60, 40, 120, 0.15);}
.accordion {border-radius: 20px !important;}
"""
# Banner markup rendered at the top of the page through gr.HTML; styled by
# the .hero rules in CUSTOM_CSS.
HERO_HTML = """
<div class="hero">
<h1>MedCard-KR ยท ์•ฝ๋ด‰ํˆฌ ํ•œ ์ปท์œผ๋กœ ์ดํ•ดํ•˜๋Š” ๋ณต์šฉ ์•ˆ๋‚ด</h1>
<p>Qwen2.5-VL์ด ์•ฝ ๋ด‰ํˆฌ๋ฅผ ์ง์ ‘ ์ฝ๊ณ , ์•ฝ์‚ฌ์ฒ˜๋Ÿผ ์‰ฝ๊ฒŒ ์„ค๋ช…๊ณผ ํ•œ ์ปท ๋งŒํ™”๋ฅผ ํ•จ๊ป˜ ์ œ๊ณตํ•ฉ๋‹ˆ๋‹ค.</p>
</div>
"""
# Page layout and event wiring.
# NOTE(review): the nesting below is reconstructed from the order of the
# `with` statements. In particular the accordion could belong inside the
# results column rather than below the row -- confirm against the original.
with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    gr.HTML(HERO_HTML)
    with gr.Row():
        # Left column: image upload, status notice and the trigger button.
        with gr.Column(scale=4, elem_classes=["glass-panel"]):
            gr.Markdown("### 1. ์•ฝ ๋ด‰ํˆฌ ์‚ฌ์ง„์„ ์—…๋กœ๋“œํ•˜์„ธ์š”")
            img_in = gr.Image(type="pil", label="์•ฝ ๋ด‰ํˆฌ/๋ผ๋ฒจ ์‚ฌ์ง„", height=360)
            warn_md = gr.Markdown("๐Ÿ“ท ์•ฝ ๋ด‰ํˆฌ ์‚ฌ์ง„์„ ์˜ฌ๋ฆฌ๋ฉด ์ธ์‹์ด ์‹œ์ž‘๋ผ์š”.", elem_classes=["notice"])
            btn = gr.Button("์ธ์‹ & ์„ค๋ช… ์ƒ์„ฑ", elem_classes=["primary-btn"])
        # Right column: explanation, raw text, cartoon, schedule card and CSV.
        with gr.Column(scale=6, elem_classes=["glass-panel"]):
            gr.Markdown("### 2. ๊ฒฐ๊ณผ๋ฅผ ํ™•์ธํ•˜์„ธ์š”")
            explain_md = gr.Markdown("์—ฌ๊ธฐ์— ์•ฝ ์„ค๋ช…์ด ํ‘œ์‹œ๋ฉ๋‹ˆ๋‹ค.", elem_classes=["output-card"])
            raw_box = gr.Textbox(label="๋ชจ๋ธ์ด ์ฝ์€ ์›๋ฌธ ํ…์ŠคํŠธ", lines=5, interactive=False)
            cartoon_img = gr.Image(type="pil", label="ํ•œ ์ปท ๋งŒํ™”")
            card_out = gr.Image(type="pil", label="์ผ์ • ์นด๋“œ(๋ฏธ๋ฆฌ๋ณด๊ธฐ)")
            csv_box = gr.Textbox(label="CSV(์•ฝ๋ช…,1ํšŒ์šฉ๋Ÿ‰,1์ผํšŸ์ˆ˜,์‹œ๊ฐ„๋Œ€)", lines=2, elem_classes=["csv-box"])
    # Collapsed by default; holds the JSON dump of the analysis result.
    with gr.Accordion("์„ธ๋ถ€ JSON ๊ฒฐ๊ณผ", open=False, elem_classes=["accordion"]):
        json_out = gr.Code(label="๋ชจ๋ธ ๋ถ„์„(JSON)")
    # The outputs list must stay in the same order as the tuple returned by
    # run_pipeline.
    btn.click(
        run_pipeline,
        inputs=img_in,
        outputs=[json_out, card_out, csv_box, explain_md, warn_md, raw_box, cartoon_img],
    )
    gr.Markdown(
        "> โ„น๏ธ **์ฃผ์˜**: ์ด ์„œ๋น„์Šค๋Š” ์ฐธ๊ณ ์šฉ ๋„๊ตฌ์ด๋ฉฐ, ์‹ค์ œ ๋ณต์•ฝ์€ ๋ฐ˜๋“œ์‹œ ์˜์‚ฌยท์•ฝ์‚ฌ์˜ ์ง€์‹œ์— ๋”ฐ๋ผ ์ฃผ์„ธ์š”."
    )
# Start the app only when this file is run as a script; queuing is enabled
# before launch.
if __name__ == "__main__":
    demo.queue().launch()