# MedCard / app.py
# Provenance: Hugging Face Space file view (commit 8a13800,
# "feat: upgrade models and improve quality"). The viewer's page chrome
# ("raw / history / blame / 18.1 kB") was converted into this comment so
# the module parses as Python.
import csv
import io
import json
import re
from typing import Any, Dict, List, Optional

import gradio as gr
import spaces
import torch
from diffusers import AutoPipelineForText2Image
from PIL import Image, ImageDraw
from transformers import (
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoProcessor,
    AutoTokenizer,
)
# Hugging Face Hub model identifiers; all three are loaded at import time.
VL_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct"  # vision-language model that reads the photo
TEXT_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct"  # text-only model for patient-facing explanations
IMAGE_MODEL_ID = "black-forest-labs/FLUX.1-schnell"  # fast text-to-image model for the cartoon
def _load_vl_model():
    """Instantiate the Qwen vision-language model and its processor.

    With CUDA available the weights are auto-sharded in float16; without
    it everything loads in float32 and is pinned to the CPU.
    """
    use_gpu = torch.cuda.is_available()
    model = AutoModelForVision2Seq.from_pretrained(
        VL_MODEL_ID,
        device_map="auto" if use_gpu else None,
        torch_dtype=torch.float16 if use_gpu else torch.float32,
        trust_remote_code=True,
    )
    if not use_gpu:
        model = model.to(torch.device("cpu"))
    processor = AutoProcessor.from_pretrained(VL_MODEL_ID, trust_remote_code=True)
    return model, processor


# Loaded once at import time so every request reuses the same weights.
VL_MODEL, VL_PROCESSOR = _load_vl_model()
def _load_text_model():
    """Instantiate the Qwen text model and its tokenizer.

    Mirrors ``_load_vl_model``: float16 with automatic device mapping on
    GPU, float32 on CPU otherwise.
    """
    use_gpu = torch.cuda.is_available()
    model = AutoModelForCausalLM.from_pretrained(
        TEXT_MODEL_ID,
        device_map="auto" if use_gpu else None,
        torch_dtype=torch.float16 if use_gpu else torch.float32,
        trust_remote_code=True,
    )
    if not use_gpu:
        model = model.to(torch.device("cpu"))
    tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_ID, trust_remote_code=True)
    return model, tokenizer


# Shared across requests; loading happens once at import time.
TEXT_MODEL, TEXT_TOKENIZER = _load_text_model()
def _load_image_pipeline():
    """Load the FLUX.1-schnell text-to-image pipeline.

    Moves the pipeline to GPU (float16) when CUDA is available, otherwise
    runs it on CPU in float32.
    """
    use_gpu = torch.cuda.is_available()
    pipe = AutoPipelineForText2Image.from_pretrained(
        IMAGE_MODEL_ID,
        torch_dtype=torch.float16 if use_gpu else torch.float32,
        safety_checker=None,
    )
    pipe.to(torch.device("cuda" if use_gpu else "cpu"))
    return pipe


# Single shared pipeline instance, created at import time.
IMAGE_PIPELINE = _load_image_pipeline()
def _extract_assistant_content(decoded: str) -> str:
if "<|im_start|>assistant" in decoded:
content = decoded.split("<|im_start|>assistant")[-1]
content = content.replace("<|im_end|>", "").strip()
return content
return decoded.strip()
def _extract_json_block(text: str) -> Optional[str]:
match = re.search(r"\{.*\}", text, re.DOTALL)
if not match:
return None
return match.group(0)
def _sanitize_list(value: Any) -> List[str]:
if isinstance(value, (list, tuple)):
return [str(v).strip() for v in value if str(v).strip()]
if isinstance(value, str):
return [v.strip() for v in re.split(r"[,;]", value) if v.strip()]
return []
def _sanitize_medication(item: Dict[str, Any]) -> Dict[str, Any]:
    """Normalize one raw medication dict into the fixed schema the UI uses.

    Missing/None fields become empty strings; numeric ``times_per_day``
    values drop a trailing ``.0``; ``time_slots`` is coerced to a list of
    strings via ``_sanitize_list``.
    """
    def clean(value: Any) -> str:
        if value is None:
            return ""
        return str(value).strip()

    raw_times = item.get("times_per_day")
    if isinstance(raw_times, (int, float)):
        # "3.0" and "3" should both render as "3"; keep "2.5" as-is.
        times_value = str(int(raw_times)) if float(raw_times).is_integer() else str(raw_times)
    else:
        times_value = clean(raw_times)

    # Key order matters: the result is serialized to JSON for display.
    return {
        "name": clean(item.get("name")),
        "dose_per_intake": clean(item.get("dose_per_intake")),
        "times_per_day": times_value,
        "time_slots": _sanitize_list(item.get("time_slots")),
        "description": clean(item.get("description")),
        "usage_example": clean(item.get("usage_example")),
        "dosage_example": clean(item.get("dosage_example")),
        "side_effects": clean(item.get("side_effects")),
        "warnings": clean(item.get("warnings")),
    }
def _parse_vl_response(text: str) -> Dict[str, Any]:
    """Convert the VL model's reply into the app's result dict.

    Returns ``{"raw_text", "medications", "warnings"}``. When no JSON
    block is found or it fails to parse, medications are empty and the
    warnings carry a Korean error message (plus the raw reply, if any).
    """
    json_block = _extract_json_block(text)
    if not json_block:
        leftover = [text.strip()] if text.strip() else []
        return {
            "raw_text": "",
            "medications": [],
            "warnings": ["λͺ¨λΈ μ‘λ‹΅μ—μ„œ JSON ν˜•μ‹μ„ μ°Ύμ§€ λͺ»ν–ˆμŠ΅λ‹ˆλ‹€."] + leftover,
        }
    try:
        data = json.loads(json_block)
    except json.JSONDecodeError:
        return {
            "raw_text": "",
            "medications": [],
            "warnings": ["λͺ¨λΈ JSON νŒŒμ‹± μ‹€νŒ¨", text.strip()],
        }
    # Only a list of dicts yields medications; any other shape is ignored.
    meds_raw = data.get("medications") or []
    medications: List[Dict[str, Any]] = []
    if isinstance(meds_raw, list):
        medications = [
            _sanitize_medication(entry) for entry in meds_raw if isinstance(entry, dict)
        ]
    # Warnings may arrive as a list or a single scalar.
    raw_warnings = data.get("warnings")
    if isinstance(raw_warnings, list):
        warnings = [w for w in (str(entry).strip() for entry in raw_warnings) if w]
    elif raw_warnings:
        warnings = [str(raw_warnings).strip()]
    else:
        warnings = []
    return {
        "raw_text": str(data.get("raw_text", "")).strip(),
        "medications": medications,
        "warnings": warnings,
    }
@spaces.GPU(enable_queue=True)
def analyze_image_with_qwen(image: Image.Image) -> Dict[str, Any]:
    """Read a medication pouch/prescription photo with Qwen2.5-VL.

    Prompts the model (in Korean) to answer strictly as JSON matching the
    schema below, runs generation, and returns the parsed result from
    ``_parse_vl_response`` with keys ``raw_text``, ``medications`` and
    ``warnings``.
    """
    # Instruction (Korean): read the pouch/prescription in the photo and
    # reply ONLY in the JSON format below, with no extra prose.
    instructions = (
        "사진 속 μ•½λ΄‰νˆ¬/μ²˜λ°©μ „μ„ 읽고 μ•„λž˜ JSON ν˜•μ‹μœΌλ‘œλ§Œ λ‹΅λ³€ν•˜μ„Έμš”. "
        "ν…μŠ€νŠΈ μ™Έμ˜ μ„€λͺ…μ΄λ‚˜ μΆ”κ°€ λ¬Έμž₯은 μ ˆλŒ€ λ„£μ§€ λ§ˆμ„Έμš”."
    )
    # Target JSON schema, shown verbatim to the model (field hints in Korean).
    schema = (
        "{\n"
        " \"raw_text\": \"OCR둜 읽은 전체 λ¬Έμž₯\",\n"
        " \"medications\": [\n"
        " {\n"
        " \"name\": \"μ•½ 이름\",\n"
        " \"dose_per_intake\": \"1회 μš©λŸ‰ (예: 1μ •, 5mL)\",\n"
        " \"times_per_day\": \"ν•˜λ£¨ 볡용 횟수\",\n"
        " \"time_slots\": [\"볡용 μ‹œκ°„λŒ€\"],\n"
        " \"description\": \"μ•½ μ„€λͺ…\",\n"
        " \"usage_example\": \"볡용 μ˜ˆμ‹œ\",\n"
        " \"dosage_example\": \"볡용 방법 μ˜ˆμ‹œ\",\n"
        " \"side_effects\": \"μ£Όμš” λΆ€μž‘μš©\",\n"
        " \"warnings\": \"주의 문ꡬ\"\n"
        " }\n"
        " ],\n"
        " \"warnings\": [\"전체 κ²½κ³ \"]\n"
        "}"
    )
    # Final constraint (Korean): obey the schema, answer in Korean, and
    # leave unknown fields as empty strings.
    user_prompt = (
        "μœ„ JSON μŠ€ν‚€λ§ˆλ₯Ό λ°˜λ“œμ‹œ λ”°λ₯΄μ„Έμš”. λͺ¨λ“  값은 ν•œκ΅­μ–΄λ‘œ μž‘μ„±ν•˜κ³ , 빈 μ •λ³΄λŠ” 빈 λ¬Έμžμ—΄λ‘œ λ‘μ„Έμš”."
    )
    # System turn frames the model as a pharmacist; the image placeholder
    # is filled by the processor call below.
    messages = [
        {
            "role": "system",
            "content": "당신은 약사 μ„ μƒλ‹˜μž…λ‹ˆλ‹€. μ •ν™•ν•˜κ³  μΉœμ ˆν•˜κ²Œ 정보λ₯Ό μ •λ¦¬ν•˜μ„Έμš”.",
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": instructions},
                {"type": "text", "text": schema},
                {"type": "text", "text": user_prompt},
                {"type": "image"},
            ],
        },
    ]
    chat_text = VL_PROCESSOR.apply_chat_template(messages, add_generation_prompt=True)
    inputs = VL_PROCESSOR(text=[chat_text], images=[image], return_tensors="pt").to(VL_MODEL.device)
    # NOTE(review): with do_sample=False decoding is greedy, so the
    # temperature/top_p values here have no effect.
    output_ids = VL_MODEL.generate(
        **inputs,
        max_new_tokens=1024,
        temperature=0.1,
        top_p=0.9,
        do_sample=False,
    )
    # Keep special tokens so the assistant marker survives decoding, then
    # cut out just the assistant turn before parsing.
    decoded = VL_PROCESSOR.batch_decode(output_ids, skip_special_tokens=False)[0]
    assistant_text = _extract_assistant_content(decoded)
    return _parse_vl_response(assistant_text)
@spaces.GPU(enable_queue=True)
def generate_explanations(raw_text: str, medications: List[Dict[str, Any]]) -> Dict[str, str]:
    """Generate elderly/child-friendly Korean narratives plus an image prompt.

    Summarizes the recognized medications, asks the text model for a JSON
    reply with ``elderly``/``child`` sections, and flattens it into
    ``{"elderly_narrative", "child_narrative", "image_prompt"}``. Any
    extraction or parse failure falls back to fixed apology strings and a
    generic image prompt.
    """
    # One bullet per medication: "- <name> <dose>".
    med_summary_lines = []
    for med in medications:
        summary = f"- {med.get('name', '이름 λ―Έν™•μΈ')} {med.get('dose_per_intake', '')}"
        med_summary_lines.append(summary.strip())
    med_summary = "\n".join(med_summary_lines)
    system_prompt = "당신은 ν™˜μž ꡐ윑 μ „λ¬Έ μ•½μ‚¬μž…λ‹ˆλ‹€. μ–΄λ₯΄μ‹ κ³Ό μ–΄λ¦°μ΄μ—κ²Œ 약을 쉽고 μΉœμ ˆν•˜κ²Œ μ„€λͺ…ν•˜λ©°, 볡용 방법과 μ£Όμ˜μ‚¬ν•­μ„ λͺ…ν™•νžˆ μ „λ‹¬ν•©λ‹ˆλ‹€."
    # Korean instructions + inline JSON template: narratives must be in
    # Korean, image prompts in English.
    user_prompt = (
        "λ‹€μŒ μ•½ 정보λ₯Ό λ°”νƒ•μœΌλ‘œ μ–΄λ₯΄μ‹ κ³Ό 어린이λ₯Ό μœ„ν•œ 볡약 μ•ˆλ‚΄λ₯Ό μž‘μ„±ν•˜μ„Έμš”.\n\n"
        f"μ•½ λͺ©λ‘:\n{med_summary}\n\n원문:\n{raw_text}\n\n"
        "JSON ν˜•μ‹μœΌλ‘œ λ‹΅λ³€ν•˜μ„Έμš”:\n"
        "{\n"
        ' "elderly": {\n'
        ' "narrative": "μ–΄λ₯΄μ‹ κ»˜ λ“œλ¦¬λŠ” μ„€λͺ… (μ‘΄λŒ“λ§, ꡬ체적 볡용 μ‹œκ°„κ³Ό 방법, μ£Όμ˜μ‚¬ν•­ 포함, 3-5λ¬Έμž₯)",\n'
        ' "image_prompt": "detailed cartoon illustration showing elderly person taking medicine with family support, warm pastel colors, professional medical setting, clear and caring atmosphere"\n'
        " },\n"
        ' "child": {\n'
        ' "narrative": "어린이λ₯Ό μœ„ν•œ μ„€λͺ… (μ‰¬μš΄ 말, 재미있게, μ™œ λ¨Ήμ–΄μ•Ό ν•˜λŠ”μ§€ μ„€λͺ…, 3-5λ¬Έμž₯)",\n'
        ' "image_prompt": "cheerful illustrated cartoon of child taking medicine with parent helping, colorful and friendly, encouraging atmosphere, high quality digital art"\n'
        " }\n"
        "}\n\n"
        "narrativeλŠ” λ°˜λ“œμ‹œ ν•œκ΅­μ–΄λ‘œ, image_promptλŠ” λ°˜λ“œμ‹œ μ˜μ–΄λ‘œ μž‘μ„±ν•˜μ„Έμš”. "
        "image_promptλŠ” ꡬ체적이고 μƒμ„Έν•˜κ²Œ μž₯면을 λ¬˜μ‚¬ν•˜μ„Έμš”."
    )
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    input_ids = TEXT_TOKENIZER.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(TEXT_MODEL.device)
    with torch.no_grad():
        output_ids = TEXT_MODEL.generate(
            input_ids,
            max_new_tokens=768,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )
    # Slice off the prompt tokens so only newly generated text is decoded.
    generated_ids = output_ids[0][input_ids.shape[1]:]
    text = TEXT_TOKENIZER.decode(generated_ids, skip_special_tokens=True).strip()
    json_block = _extract_json_block(text)
    if not json_block:
        # No JSON found: fall back to apology strings (Korean) + generic prompt.
        return {
            "elderly_narrative": "μ„€λͺ…을 μ€€λΉ„ν•˜μ§€ λͺ»ν–ˆμŠ΅λ‹ˆλ‹€. μ•½μ‚¬μ—κ²Œ 직접 λ¬Έμ˜ν•˜μ„Έμš”.",
            "child_narrative": "μ„€λͺ…을 μ€€λΉ„ν•˜μ§€ λͺ»ν–ˆμŠ΅λ‹ˆλ‹€. μ•½μ‚¬μ—κ²Œ 직접 λ¬Έμ˜ν•˜μ„Έμš”.",
            "image_prompt": "single panel cartoon pharmacist helping family, soft colors",
        }
    try:
        data = json.loads(json_block)
    except json.JSONDecodeError:
        # Malformed JSON: same fallback as above.
        return {
            "elderly_narrative": "μ„€λͺ…을 μ€€λΉ„ν•˜μ§€ λͺ»ν–ˆμŠ΅λ‹ˆλ‹€. μ•½μ‚¬μ—κ²Œ 직접 λ¬Έμ˜ν•˜μ„Έμš”.",
            "child_narrative": "μ„€λͺ…을 μ€€λΉ„ν•˜μ§€ λͺ»ν–ˆμŠ΅λ‹ˆλ‹€. μ•½μ‚¬μ—κ²Œ 직접 λ¬Έμ˜ν•˜μ„Έμš”.",
            "image_prompt": "single panel cartoon pharmacist helping family, soft colors",
        }
    elderly = data.get("elderly", {})
    child = data.get("child", {})
    # Prefer the child image prompt, then the elderly one, then a default.
    return {
        "elderly_narrative": str(elderly.get("narrative", "")).strip(),
        "child_narrative": str(child.get("narrative", "")).strip(),
        "image_prompt": str(child.get("image_prompt") or elderly.get("image_prompt") or "single panel cartoon pharmacist helping family, pastel colors").strip(),
    }
@spaces.GPU(enable_queue=True)
def generate_cartoon_image(prompt: str) -> Image.Image:
    """Render a single cartoon panel from *prompt* with FLUX.1-schnell.

    An empty prompt falls back to a generic pharmacist scene; the prompt
    is wrapped with quality keywords before generation.
    """
    fallback = "wholesome illustrated cartoon scene, friendly pharmacist explaining medicine to elderly and children, warm soft pastel colors, professional medical setting, gentle and caring atmosphere, high quality digital illustration"
    scene = prompt if prompt else fallback
    enhanced_prompt = f"high quality illustration, {scene}, soft lighting, detailed, professional artwork, clean composition"
    result = IMAGE_PIPELINE(
        prompt=enhanced_prompt,
        num_inference_steps=4,  # schnell is distilled for very few steps
        guidance_scale=0.0,
        height=768,
        width=1024,
        max_sequence_length=256,
    )
    return result.images[0]
def render_card(primary: Dict[str, Any]) -> Image.Image:
    """Draw a 720x400 schedule card for the primary medication.

    Shows the name, per-intake dose, daily frequency and time slots under
    a header band, with a disclaimer footer. Missing values render as "-".
    """
    width, height = 720, 400
    canvas = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(canvas)
    # Header band.
    draw.rectangle((0, 0, width, 60), fill=(230, 240, 255))
    draw.text((24, 18), "였늘 볡용 일정", fill=(0, 0, 0))

    y = 90

    def add_line(label: str, value: Optional[str]):
        # Label in grey at a fixed column, value (or "-") to its right.
        nonlocal y
        shown = value if value else "-"
        draw.text((24, y), label, fill=(60, 60, 60))
        draw.text((200, y), f": {shown}", fill=(0, 0, 0))
        y += 34

    add_line("μ•½ 이름", primary.get("name"))
    add_line("1회 μš©λŸ‰", primary.get("dose_per_intake"))
    add_line("1일 횟수", primary.get("times_per_day"))
    slots = primary.get("time_slots") or []
    add_line("μ‹œκ°„λŒ€", ", ".join(slots) if slots else None)
    # Disclaimer footer.
    draw.text(
        (24, height - 60),
        "β€» μ˜λ£Œμ§„ 처방이 μš°μ„ μ΄λ©°, λ³Έ 앱은 μ•ˆλ‚΄μš©μž…λ‹ˆλ‹€.",
        fill=(120, 120, 120),
    )
    return canvas
def medications_to_csv(medications: List[Dict[str, Any]]) -> str:
    """Serialize medications to CSV: name, dose, times/day, slots.

    Fixes two defects in the previous version: it silently exported only
    the first medication despite receiving a list, and fields containing
    commas corrupted the row because nothing was quoted. Now every
    medication becomes one properly quoted CSV line (``csv`` module),
    with time slots joined by ';'.

    Returns an empty string for an empty list; otherwise the rows joined
    by newlines with no trailing newline.
    """
    if not medications:
        return ""
    buffer = io.StringIO()
    writer = csv.writer(buffer, lineterminator="\n")
    for med in medications:
        writer.writerow([
            med.get("name", ""),
            med.get("dose_per_intake", ""),
            med.get("times_per_day", ""),
            ";".join(med.get("time_slots") or []),
        ])
    return buffer.getvalue().rstrip("\n")
def format_warnings(warnings: List[str]) -> str:
    """Format warning strings as a markdown checklist section.

    An empty list yields a reassuring one-liner instead of a list.
    """
    if not warnings:
        return "βœ… μΈμ‹λœ 정보가 μΆ©λΆ„ν•΄μš”. 볡약 μ‹œκ°„λ§Œ 잘 μ§€μΌœ μ£Όμ„Έμš”."
    bullets = [f"- {item}" for item in warnings]
    parts = ["### 확인해 μ£Όμ„Έμš”"] + bullets + ["\n> μ˜λ£Œμ§„μ˜ μ§€μ‹œκ°€ κ°€μž₯ μ •ν™•ν•©λ‹ˆλ‹€."]
    return "\n".join(parts)
def run_pipeline(image: Optional[Image.Image]):
    """End-to-end handler wired to the Gradio button.

    Returns a 7-tuple whose order must match the ``outputs`` list of
    ``btn.click``: (JSON analysis text, schedule-card image, CSV row,
    explanation markdown, warnings markdown, raw OCR text, cartoon
    image). When no image was uploaded, placeholder messages are
    returned instead.
    """
    if image is None:
        # Placeholder tuple shown before any upload.
        return (
            "이미지λ₯Ό μ—…λ‘œλ“œν•˜μ„Έμš”.",
            None,
            None,
            "이미지λ₯Ό λ¨Όμ € μ—…λ‘œλ“œν•΄ μ£Όμ„Έμš”.",
            "πŸ“· μ•½ λ΄‰νˆ¬ 사진을 올리면 인식이 μ‹œμž‘λΌμš”.",
            "",
            None,
        )
    result = analyze_image_with_qwen(image)
    medications = result.get("medications") or []
    # Fall back to an empty "primary" medication when nothing was recognized,
    # so the schedule card still renders.
    primary = medications[0] if medications else {
        "name": "",
        "dose_per_intake": "",
        "times_per_day": "",
        "time_slots": [],
    }
    narratives = generate_explanations(result.get("raw_text", ""), medications)
    card_img = render_card(primary)
    csv_row = medications_to_csv(medications)
    # Two-audience markdown: elderly section then child section.
    markdown = (
        "## μ–΄λ₯΄μ‹ μ„ μœ„ν•œ μ„€λͺ…\n"
        + (narratives.get("elderly_narrative") or "- μ„€λͺ…을 μ€€λΉ„ν•˜μ§€ λͺ»ν–ˆμŠ΅λ‹ˆλ‹€.")
        + "\n\n## 어린이λ₯Ό μœ„ν•œ μ„€λͺ…\n"
        + (narratives.get("child_narrative") or "- μ„€λͺ…을 μ€€λΉ„ν•˜μ§€ λͺ»ν–ˆμŠ΅λ‹ˆλ‹€.")
        + "\n\n> 항상 μ˜λ£Œμ§„μ˜ μ•ˆλ‚΄λ₯Ό μš°μ„ ν•˜μ„Έμš”."
    )
    warnings_md = format_warnings(result.get("warnings", []))
    raw_text = result.get("raw_text", "")
    json_text = json.dumps(result, ensure_ascii=False, indent=2)
    # Image generation last: it is the slowest step.
    cartoon_image = generate_cartoon_image(narratives.get("image_prompt"))
    return json_text, card_img, csv_row, markdown, warnings_md, raw_text, cartoon_image
# Gradient/glassmorphism styling injected into the Gradio page.
CUSTOM_CSS = """
body {background: radial-gradient(circle at top left, #f5f0ff 0%, #fff7ec 60%, #ffffff 100%);}
.gradio-container {max-width: 1180px !important; margin: auto; font-family: 'Noto Sans KR', sans-serif;}
.hero {
background: linear-gradient(120deg, rgba(123, 97, 255, 0.12), rgba(255, 207, 117, 0.18));
border-radius: 28px;
padding: 36px 44px;
box-shadow: 0 20px 40px rgba(66, 46, 138, 0.08);
margin-bottom: 32px;
}
.hero h1 {font-size: 2.4rem; font-weight: 700; color: #1f1c3b; margin-bottom: 12px;}
.hero p {color: #514c7b; font-size: 1.05rem; line-height: 1.6; max-width: 640px;}
.glass-panel {background: rgba(255, 255, 255, 0.72); backdrop-filter: blur(18px); border-radius: 26px; padding: 28px; box-shadow: 0 12px 32px rgba(80, 60, 160, 0.12);}
.primary-btn button {background: linear-gradient(120deg, #7c62ff, #ffa74d); border: none; color: white; font-weight: 600; border-radius: 999px; padding: 12px 22px; box-shadow: 0 12px 24px rgba(124, 98, 255, 0.25);}
.primary-btn button:hover {opacity: 0.95; transform: translateY(-1px);}
.output-card {background: rgba(255, 255, 255, 0.88); border-radius: 22px; padding: 24px; box-shadow: inset 0 0 0 1px rgba(124, 98, 255, 0.08), 0 14px 30px rgba(49, 32, 114, 0.12);}
.notice {background: rgba(255, 247, 226, 0.9); border-radius: 18px; padding: 18px; color: #7a4b00; box-shadow: inset 0 0 0 1px rgba(255, 193, 96, 0.3);}
.csv-box textarea {font-family: 'JetBrains Mono', monospace;}
.gr-image {border-radius: 20px !important; box-shadow: 0 10px 20px rgba(60, 40, 120, 0.15);}
.accordion {border-radius: 20px !important;}
"""
# Static hero banner (Korean title and tagline) shown at the top of the page.
HERO_HTML = """
<div class="hero">
<h1>MedCard-KR Β· μ•½λ΄‰νˆ¬ ν•œ 컷으둜 μ΄ν•΄ν•˜λŠ” 볡용 μ•ˆλ‚΄</h1>
<p>Qwen2.5-VL이 μ•½ λ΄‰νˆ¬λ₯Ό 직접 읽고, μ•½μ‚¬μ²˜λŸΌ μ‰½κ²Œ μ„€λͺ…κ³Ό ν•œ μ»· λ§Œν™”λ₯Ό ν•¨κ»˜ μ œκ³΅ν•©λ‹ˆλ‹€.</p>
</div>
"""
# ---- Gradio UI ------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    gr.HTML(HERO_HTML)
    with gr.Row():
        # Left column: image upload, status notice, and the trigger button.
        with gr.Column(scale=4, elem_classes=["glass-panel"]):
            gr.Markdown("### 1. μ•½ λ΄‰νˆ¬ 사진을 μ—…λ‘œλ“œν•˜μ„Έμš”")
            img_in = gr.Image(type="pil", label="μ•½ λ΄‰νˆ¬/라벨 사진", height=360)
            warn_md = gr.Markdown("πŸ“· μ•½ λ΄‰νˆ¬ 사진을 올리면 인식이 μ‹œμž‘λΌμš”.", elem_classes=["notice"])
            btn = gr.Button("인식 & μ„€λͺ… 생성", elem_classes=["primary-btn"])
        # Right column: generated explanation, raw OCR text, images, CSV,
        # and the collapsible raw-JSON view.
        with gr.Column(scale=6, elem_classes=["glass-panel"]):
            gr.Markdown("### 2. κ²°κ³Όλ₯Ό ν™•μΈν•˜μ„Έμš”")
            explain_md = gr.Markdown("여기에 μ•½ μ„€λͺ…이 ν‘œμ‹œλ©λ‹ˆλ‹€.", elem_classes=["output-card"])
            raw_box = gr.Textbox(label="λͺ¨λΈμ΄ 읽은 원문 ν…μŠ€νŠΈ", lines=5, interactive=False)
            cartoon_img = gr.Image(type="pil", label="ν•œ μ»· λ§Œν™”")
            card_out = gr.Image(type="pil", label="일정 μΉ΄λ“œ(미리보기)")
            csv_box = gr.Textbox(label="CSV(μ•½λͺ…,1νšŒμš©λŸ‰,1일횟수,μ‹œκ°„λŒ€)", lines=2, elem_classes=["csv-box"])
            with gr.Accordion("μ„ΈλΆ€ JSON κ²°κ³Ό", open=False, elem_classes=["accordion"]):
                json_out = gr.Code(label="λͺ¨λΈ 뢄석(JSON)")
    # Output order here must match run_pipeline's 7-tuple return order.
    btn.click(
        run_pipeline,
        inputs=img_in,
        outputs=[json_out, card_out, csv_box, explain_md, warn_md, raw_box, cartoon_img],
    )
    gr.Markdown(
        "> ℹ️ **주의**: 이 μ„œλΉ„μŠ€λŠ” 참고용 도ꡬ이며, μ‹€μ œ 볡약은 λ°˜λ“œμ‹œ μ˜μ‚¬Β·μ•½μ‚¬μ˜ μ§€μ‹œμ— 따라 μ£Όμ„Έμš”."
    )


if __name__ == "__main__":
    # queue() serializes GPU-bound requests before launching the app.
    demo.queue().launch()