import csv
import io
import json
import re
from typing import Any, Dict, List, Optional


import gradio as gr
import spaces
import torch
from diffusers import AutoPipelineForText2Image
from PIL import Image, ImageDraw
from transformers import (
    AutoModelForCausalLM,
    AutoModelForVision2Seq,
    AutoProcessor,
    AutoTokenizer,
)
|
|
| VL_MODEL_ID = "Qwen/Qwen2.5-VL-7B-Instruct" |
| TEXT_MODEL_ID = "Qwen/Qwen2.5-7B-Instruct" |
| IMAGE_MODEL_ID = "black-forest-labs/FLUX.1-schnell" |
|
|
|
|
def _load_vl_model():
    """Load the Qwen2.5-VL vision-language model and its processor.

    Uses fp16 with automatic device placement when CUDA is available,
    otherwise fp32 pinned to CPU.
    """
    has_cuda = torch.cuda.is_available()
    model = AutoModelForVision2Seq.from_pretrained(
        VL_MODEL_ID,
        device_map="auto" if has_cuda else None,
        torch_dtype=torch.float16 if has_cuda else torch.float32,
        trust_remote_code=True,
    )
    if not has_cuda:
        # No device_map was given, so place the weights on CPU explicitly.
        model = model.to(torch.device("cpu"))
    processor = AutoProcessor.from_pretrained(VL_MODEL_ID, trust_remote_code=True)
    return model, processor
|
|
|
|
# Vision-language model + processor, loaded once at import time for reuse.
VL_MODEL, VL_PROCESSOR = _load_vl_model()
|
|
|
|
def _load_text_model():
    """Load the Qwen2.5 text LLM and its tokenizer.

    Mirrors _load_vl_model: fp16 + auto device map on CUDA, fp32 on CPU.
    """
    has_cuda = torch.cuda.is_available()
    model = AutoModelForCausalLM.from_pretrained(
        TEXT_MODEL_ID,
        device_map="auto" if has_cuda else None,
        torch_dtype=torch.float16 if has_cuda else torch.float32,
        trust_remote_code=True,
    )
    if not has_cuda:
        # No device_map was given, so place the weights on CPU explicitly.
        model = model.to(torch.device("cpu"))
    tokenizer = AutoTokenizer.from_pretrained(TEXT_MODEL_ID, trust_remote_code=True)
    return model, tokenizer
|
|
|
|
# Text LLM + tokenizer used for audience-specific narratives, loaded at import.
TEXT_MODEL, TEXT_TOKENIZER = _load_text_model()
|
|
|
|
def _load_image_pipeline():
    """Load the FLUX.1-schnell text-to-image pipeline on the best device."""
    use_gpu = torch.cuda.is_available()
    pipe = AutoPipelineForText2Image.from_pretrained(
        IMAGE_MODEL_ID,
        torch_dtype=torch.float16 if use_gpu else torch.float32,
        safety_checker=None,
    )
    pipe.to(torch.device("cuda" if use_gpu else "cpu"))
    return pipe
|
|
|
|
# Text-to-image pipeline (FLUX.1-schnell), loaded once at import time.
IMAGE_PIPELINE = _load_image_pipeline()
|
|
|
|
| def _extract_assistant_content(decoded: str) -> str: |
| if "<|im_start|>assistant" in decoded: |
| content = decoded.split("<|im_start|>assistant")[-1] |
| content = content.replace("<|im_end|>", "").strip() |
| return content |
| return decoded.strip() |
|
|
|
|
| def _extract_json_block(text: str) -> Optional[str]: |
| match = re.search(r"\{.*\}", text, re.DOTALL) |
| if not match: |
| return None |
| return match.group(0) |
|
|
|
|
| def _sanitize_list(value: Any) -> List[str]: |
| if isinstance(value, (list, tuple)): |
| return [str(v).strip() for v in value if str(v).strip()] |
| if isinstance(value, str): |
| return [v.strip() for v in re.split(r"[,;]", value) if v.strip()] |
| return [] |
|
|
|
|
def _sanitize_medication(item: Dict[str, Any]) -> Dict[str, Any]:
    """Coerce one raw medication entry into the canonical all-string schema."""

    def as_text(raw: Any) -> str:
        return "" if raw is None else str(raw).strip()

    raw_times = item.get("times_per_day")
    if isinstance(raw_times, (int, float)):
        # Render whole numbers without a trailing ".0" (e.g. 2.0 -> "2").
        times_text = str(int(raw_times)) if float(raw_times).is_integer() else str(raw_times)
    else:
        times_text = as_text(raw_times)

    sanitized: Dict[str, Any] = {
        "name": as_text(item.get("name")),
        "dose_per_intake": as_text(item.get("dose_per_intake")),
        "times_per_day": times_text,
        "time_slots": _sanitize_list(item.get("time_slots")),
    }
    for key in ("description", "usage_example", "dosage_example", "side_effects", "warnings"):
        sanitized[key] = as_text(item.get(key))
    return sanitized
|
|
|
|
def _parse_vl_response(text: str) -> Dict[str, Any]:
    """Parse the VL model's reply into {raw_text, medications, warnings}.

    Falls back to an explanatory warning entry when no JSON block is
    found or when the block fails to parse.
    """
    # NOTE(review): Korean literals below appear mojibake-garbled in this
    # copy of the file — confirm against the original UTF-8 source.
    json_block = _extract_json_block(text)
    if not json_block:
        return {
            "raw_text": "",
            "medications": [],
            "warnings": ["λͺ¨λΈ μλ΅μμ JSON νμμ μ°Ύμ§ λͺ»νμ΅λλ€."] + ([text.strip()] if text.strip() else []),
        }
    try:
        data = json.loads(json_block)
    except json.JSONDecodeError:
        return {
            "raw_text": "",
            "medications": [],
            "warnings": ["λͺ¨λΈ JSON νμ± μ€ν¨", text.strip()],
        }

    # Keep only well-formed dict entries; sanitize each to the string schema.
    meds_raw = data.get("medications") or []
    medications: List[Dict[str, Any]] = []
    if isinstance(meds_raw, list):
        for item in meds_raw:
            if isinstance(item, dict):
                medications.append(_sanitize_medication(item))

    # Accept warnings either as a list or as a single scalar value.
    warnings_raw = data.get("warnings")
    if isinstance(warnings_raw, list):
        warnings = [str(w).strip() for w in warnings_raw if str(w).strip()]
    elif warnings_raw:
        warnings = [str(warnings_raw).strip()]
    else:
        warnings = []

    return {
        "raw_text": str(data.get("raw_text", "")).strip(),
        "medications": medications,
        "warnings": warnings,
    }
|
|
|
|
@spaces.GPU(enable_queue=True)
def analyze_image_with_qwen(image: Image.Image) -> Dict[str, Any]:
    """Extract structured medication info from a photo via Qwen2.5-VL.

    Builds a system+user chat prompt that pins the model to a JSON
    schema, runs greedy generation, and parses the reply with
    _parse_vl_response.

    Args:
        image: PIL image of a medication bag / prescription label.

    Returns:
        Dict with keys "raw_text", "medications", "warnings"
        (see _parse_vl_response).
    """
    # NOTE(review): Korean prompt strings below appear mojibake-garbled in
    # this copy of the file — confirm against the original UTF-8 source.
    instructions = (
        "μ¬μ§ μ μ½λ΄ν¬/μ²λ°©μ μ μ½κ³ μλ JSON νμμΌλ‘λ§ λ΅λ³νμΈμ. "
        "νμ€νΈ μΈμ μ€λͺμ΄λ μΆκ° λ¬Έμ₯μ μ λ λ£μ§ λ§μΈμ."
    )
    schema = (
        "{\n"
        "  \"raw_text\": \"OCRλ‘ μ½μ μ 체 λ¬Έμ₯\",\n"
        "  \"medications\": [\n"
        "    {\n"
        "      \"name\": \"μ½ μ΄λ¦\",\n"
        "      \"dose_per_intake\": \"1ν μ©λ (μ: 1μ , 5mL)\",\n"
        "      \"times_per_day\": \"ν루 λ³΅μ© νμ\",\n"
        "      \"time_slots\": [\"λ³΅μ© μκ°λ\"],\n"
        "      \"description\": \"μ½ μ€λͺ\",\n"
        "      \"usage_example\": \"λ³΅μ© μμ\",\n"
        "      \"dosage_example\": \"λ³΅μ© λ°©λ² μμ\",\n"
        "      \"side_effects\": \"μ£Όμ λΆμμ©\",\n"
        "      \"warnings\": \"μ£Όμ 문ꡬ\"\n"
        "    }\n"
        "  ],\n"
        "  \"warnings\": [\"μ 체 κ²½κ³ \"]\n"
        "}"
    )
    user_prompt = (
        "μ JSON μ€ν€λ§λ₯Ό λ°λμ λ°λ₯΄μΈμ. λͺ¨λ κ°μ νκ΅μ΄λ‘ μμ±νκ³ , λΉ μ 보λ λΉ λ¬Έμμ΄λ‘ λμΈμ."
    )

    # Chat messages: system persona + user content (text parts, then image).
    messages = [
        {
            "role": "system",
            "content": "λΉμ μ μ½μ¬ μ μλμλλ€. μ ννκ³ μΉμ νκ² μ 보λ₯Ό μ 리νμΈμ.",
        },
        {
            "role": "user",
            "content": [
                {"type": "text", "text": instructions},
                {"type": "text", "text": schema},
                {"type": "text", "text": user_prompt},
                {"type": "image"},
            ],
        },
    ]

    # Render the chat template, then tokenize text + image together.
    chat_text = VL_PROCESSOR.apply_chat_template(messages, add_generation_prompt=True)
    inputs = VL_PROCESSOR(text=[chat_text], images=[image], return_tensors="pt").to(VL_MODEL.device)

    # Greedy decoding (do_sample=False); temperature/top_p are inert here.
    output_ids = VL_MODEL.generate(
        **inputs,
        max_new_tokens=1024,
        temperature=0.1,
        top_p=0.9,
        do_sample=False,
    )

    # Keep special tokens so the assistant marker survives for extraction.
    decoded = VL_PROCESSOR.batch_decode(output_ids, skip_special_tokens=False)[0]
    assistant_text = _extract_assistant_content(decoded)
    return _parse_vl_response(assistant_text)
|
|
|
|
@spaces.GPU(enable_queue=True)
def generate_explanations(raw_text: str, medications: List[Dict[str, Any]]) -> Dict[str, str]:
    """Write audience-specific explanations (elderly / child) with the text LLM.

    Args:
        raw_text: OCR text recovered from the photo.
        medications: Sanitized medication dicts from the VL stage.

    Returns:
        Dict with "elderly_narrative", "child_narrative" (Korean) and
        "image_prompt" (English); falls back to canned values when the
        model reply contains no parseable JSON.
    """
    # NOTE(review): Korean literals below appear mojibake-garbled in this
    # copy of the file — confirm against the original UTF-8 source.
    # Compact one-line-per-drug summary fed into the prompt.
    med_summary_lines = []
    for med in medications:
        summary = f"- {med.get('name', 'μ΄λ¦ λ―ΈνμΈ')} {med.get('dose_per_intake', '')}"
        med_summary_lines.append(summary.strip())
    med_summary = "\n".join(med_summary_lines)

    system_prompt = "λΉμ μ νμ κ΅μ‘ μ λ¬Έ μ½μ¬μλλ€. μ΄λ₯΄μ κ³Ό μ΄λ¦°μ΄μκ² μ½μ μ½κ³ μΉμ νκ² μ€λͺνλ©°, λ³΅μ© λ°©λ²κ³Ό μ£Όμμ¬νμ λͺνν μ λ¬ν©λλ€."
    user_prompt = (
        "λ€μ μ½ μ 보λ₯Ό λ°νμΌλ‘ μ΄λ₯΄μ κ³Ό μ΄λ¦°μ΄λ₯Ό μν λ³΅μ½ μλ΄λ₯Ό μμ±νμΈμ.\n\n"
        f"μ½ λͺ©λ‘:\n{med_summary}\n\nμλ¬Έ:\n{raw_text}\n\n"
        "JSON νμμΌλ‘ λ΅λ³νμΈμ:\n"
        "{\n"
        '  "elderly": {\n'
        '    "narrative": "μ΄λ₯΄μ κ» λ리λ μ€λͺ(μ‘΄λλ§, ꡬ체μ λ³΅μ© μκ°κ³Ό λ°©λ², μ£Όμμ¬ν ν¬ν¨, 3-5λ¬Έμ₯)",\n'
        '    "image_prompt": "detailed cartoon illustration showing elderly person taking medicine with family support, warm pastel colors, professional medical setting, clear and caring atmosphere"\n'
        "  },\n"
        '  "child": {\n'
        '    "narrative": "μ΄λ¦°μ΄λ₯Ό μν μ€λͺ(μ¬μ΄ λ§, μ¬λ―Έμκ², μ λ¨Ήμ΄μΌ νλμ§ μ€λͺ, 3-5λ¬Έμ₯)",\n'
        '    "image_prompt": "cheerful illustrated cartoon of child taking medicine with parent helping, colorful and friendly, encouraging atmosphere, high quality digital art"\n'
        "  }\n"
        "}\n\n"
        "narrativeλ λ°λμ νκ΅μ΄λ‘, image_promptλ λ°λμ μμ΄λ‘ μμ±νμΈμ. "
        "image_promptλ ꡬ체μ μ΄κ³ μμΈνκ² μ₯λ©΄μ λ¬μ¬νμΈμ."
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    input_ids = TEXT_TOKENIZER.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(TEXT_MODEL.device)

    # Sampled generation for varied, natural phrasing.
    with torch.no_grad():
        output_ids = TEXT_MODEL.generate(
            input_ids,
            max_new_tokens=768,
            temperature=0.7,
            top_p=0.9,
            do_sample=True,
        )

    # Drop the prompt tokens; decode only what the model appended.
    generated_ids = output_ids[0][input_ids.shape[1]:]
    text = TEXT_TOKENIZER.decode(generated_ids, skip_special_tokens=True).strip()

    # Fallback payload used when the reply has no parseable JSON block.
    json_block = _extract_json_block(text)
    if not json_block:
        return {
            "elderly_narrative": "μ€λͺμ μ€λΉνμ§ λͺ»νμ΅λλ€. μ½μ¬μκ² μ§μ λ¬ΈμνμΈμ.",
            "child_narrative": "μ€λͺμ μ€λΉνμ§ λͺ»νμ΅λλ€. μ½μ¬μκ² μ§μ λ¬ΈμνμΈμ.",
            "image_prompt": "single panel cartoon pharmacist helping family, soft colors",
        }

    try:
        data = json.loads(json_block)
    except json.JSONDecodeError:
        return {
            "elderly_narrative": "μ€λͺμ μ€λΉνμ§ λͺ»νμ΅λλ€. μ½μ¬μκ² μ§μ λ¬ΈμνμΈμ.",
            "child_narrative": "μ€λͺμ μ€λΉνμ§ λͺ»νμ΅λλ€. μ½μ¬μκ² μ§μ λ¬ΈμνμΈμ.",
            "image_prompt": "single panel cartoon pharmacist helping family, soft colors",
        }

    elderly = data.get("elderly", {})
    child = data.get("child", {})

    # Prefer the child's image prompt, then the elderly one, then a default.
    return {
        "elderly_narrative": str(elderly.get("narrative", "")).strip(),
        "child_narrative": str(child.get("narrative", "")).strip(),
        "image_prompt": str(child.get("image_prompt") or elderly.get("image_prompt") or "single panel cartoon pharmacist helping family, pastel colors").strip(),
    }
|
|
|
|
@spaces.GPU(enable_queue=True)
def generate_cartoon_image(prompt: str) -> Image.Image:
    """Render one cartoon illustration from *prompt* via the FLUX pipeline.

    Falls back to a default wholesome-pharmacy scene when the prompt is
    empty, and always wraps the prompt with quality/style modifiers.
    """
    fallback = "wholesome illustrated cartoon scene, friendly pharmacist explaining medicine to elderly and children, warm soft pastel colors, professional medical setting, gentle and caring atmosphere, high quality digital illustration"
    base = prompt if prompt else fallback

    full_prompt = f"high quality illustration, {base}, soft lighting, detailed, professional artwork, clean composition"

    # FLUX.1-schnell settings: few steps, no guidance.
    result = IMAGE_PIPELINE(
        prompt=full_prompt,
        num_inference_steps=4,
        guidance_scale=0.0,
        height=768,
        width=1024,
        max_sequence_length=256,
    )
    return result.images[0]
|
|
|
|
def render_card(primary: Dict[str, Any]) -> Image.Image:
    """Draw a simple 720x400 schedule card for the primary medication.

    Args:
        primary: Sanitized medication dict (name, dose_per_intake,
            times_per_day, time_slots).

    Returns:
        A white PIL image with a header band, label/value rows and a
        footer disclaimer.
    """
    # NOTE(review): Korean literals below appear mojibake-garbled in this
    # copy of the file — confirm against the original UTF-8 source.
    width, height = 720, 400
    canvas = Image.new("RGB", (width, height), "white")
    draw = ImageDraw.Draw(canvas)

    # Header band (light blue) with the card title.
    header = "μ€λ λ³΅μ© μΌμ "
    draw.rectangle((0, 0, width, 60), fill=(230, 240, 255))
    draw.text((24, 18), header, fill=(0, 0, 0))

    y = 90

    def add_line(label: str, value: Optional[str]):
        # One label/value row; advances the shared y cursor by 34 px.
        nonlocal y
        text_value = value if value else "-"
        draw.text((24, y), label, fill=(60, 60, 60))
        draw.text((200, y), f": {text_value}", fill=(0, 0, 0))
        y += 34

    add_line("μ½ μ΄λ¦", primary.get("name"))
    add_line("1ν μ©λ", primary.get("dose_per_intake"))
    add_line("1μΌ νμ", primary.get("times_per_day"))

    slots = primary.get("time_slots") or []
    add_line("μκ°λ", ", ".join(slots) if slots else None)

    # Footer disclaimer in muted gray.
    footer = "β» μλ£μ§ μ²λ°©μ΄ μ°μ μ΄λ©°, λ³Έ μ±μ μλ΄μ©μλλ€."
    draw.text((24, height - 60), footer, fill=(120, 120, 120))
    return canvas
|
|
|
|
def medications_to_csv(medications: List[Dict[str, Any]]) -> str:
    """Serialize medications to CSV text (name, dose, times/day, slots).

    Fixes two defects in the previous version: field values containing
    commas or quotes are now properly escaped via the csv module, and
    every medication is exported (one row per line) instead of silently
    dropping all but the first.

    Args:
        medications: Sanitized medication dicts.

    Returns:
        CSV rows separated by newlines, or "" when the list is empty.
        For a single medication with simple values the output is
        identical to the old comma-join.
    """
    if not medications:
        return ""
    buffer = io.StringIO()
    writer = csv.writer(buffer, lineterminator="\n")
    for med in medications:
        writer.writerow([
            med.get("name", ""),
            med.get("dose_per_intake", ""),
            med.get("times_per_day", ""),
            ";".join(med.get("time_slots") or []),
        ])
    # Drop the trailing newline so a single row matches the old output.
    return buffer.getvalue().rstrip("\n")
|
|
|
|
def format_warnings(warnings: List[str]) -> str:
    """Format warning strings as a Markdown bullet list.

    Returns a reassuring default message when there are no warnings.
    """
    # NOTE(review): Korean literals below appear mojibake-garbled in this
    # copy of the file — confirm against the original UTF-8 source.
    if not warnings:
        return "βμΈμλ μ λ³΄κ° μΆ©λΆν΄μ. λ³΅μ½ μκ°λ§ μ μ§μΌ μ£ΌμΈμ."
    lines = ["### νμΈν΄ μ£ΌμΈμ"]
    for warn in warnings:
        lines.append(f"- {warn}")
    lines.append("\n> μλ£μ§μ μ§μκ° κ°μ₯ μ νν©λλ€.")
    return "\n".join(lines)
|
|
|
|
def run_pipeline(image: Optional[Image.Image]):
    """End-to-end handler: OCR -> explanations -> card/CSV/cartoon outputs.

    Returns a 7-tuple matching the Gradio outputs wiring:
    (json_text, card_image, csv_row, explanation_md, warnings_md,
     raw_text, cartoon_image).
    """
    # NOTE(review): Korean literals below appear mojibake-garbled in this
    # copy of the file — confirm against the original UTF-8 source.
    # Guard: nothing uploaded yet — return placeholder widget values.
    if image is None:
        return (
            "μ΄λ―Έμ§λ₯Ό μλ‘λνμΈμ.",
            None,
            None,
            "μ΄λ―Έμ§λ₯Ό λ¨Όμ μλ‘λν΄ μ£ΌμΈμ.",
            "π· μ½ λ΄ν¬ μ¬μ§μ μ¬λ¦¬λ©΄ μΈμμ΄ μμλΌμ.",
            "",
            None,
        )

    # Stage 1: vision-language extraction.
    result = analyze_image_with_qwen(image)

    medications = result.get("medications") or []
    # Use the first medication for the schedule card; empty schema otherwise.
    primary = medications[0] if medications else {
        "name": "",
        "dose_per_intake": "",
        "times_per_day": "",
        "time_slots": [],
    }

    # Stage 2: audience-specific narratives + English cartoon prompt.
    narratives = generate_explanations(result.get("raw_text", ""), medications)

    # Stage 3: renderable artifacts.
    card_img = render_card(primary)
    csv_row = medications_to_csv(medications)
    markdown = (
        "## μ΄λ₯΄μ μ μν μ€λͺ\n"
        + (narratives.get("elderly_narrative") or "- μ€λͺμ μ€λΉνμ§ λͺ»νμ΅λλ€.")
        + "\n\n## μ΄λ¦°μ΄λ₯Ό μν μ€λͺ\n"
        + (narratives.get("child_narrative") or "- μ€λͺμ μ€λΉνμ§ λͺ»νμ΅λλ€.")
        + "\n\n> νμ μλ£μ§μ μλ΄λ₯Ό μ°μ νμΈμ."
    )
    warnings_md = format_warnings(result.get("warnings", []))
    raw_text = result.get("raw_text", "")
    json_text = json.dumps(result, ensure_ascii=False, indent=2)
    # Stage 4: one-panel cartoon from the generated prompt.
    cartoon_image = generate_cartoon_image(narratives.get("image_prompt"))

    return json_text, card_img, csv_row, markdown, warnings_md, raw_text, cartoon_image
|
|
|
|
# Custom CSS: gradient backdrop, glassmorphism panels, pill-shaped buttons.
CUSTOM_CSS = """
body {background: radial-gradient(circle at top left, #f5f0ff 0%, #fff7ec 60%, #ffffff 100%);}
.gradio-container {max-width: 1180px !important; margin: auto; font-family: 'Noto Sans KR', sans-serif;}
.hero {
background: linear-gradient(120deg, rgba(123, 97, 255, 0.12), rgba(255, 207, 117, 0.18));
border-radius: 28px;
padding: 36px 44px;
box-shadow: 0 20px 40px rgba(66, 46, 138, 0.08);
margin-bottom: 32px;
}
.hero h1 {font-size: 2.4rem; font-weight: 700; color: #1f1c3b; margin-bottom: 12px;}
.hero p {color: #514c7b; font-size: 1.05rem; line-height: 1.6; max-width: 640px;}
.glass-panel {background: rgba(255, 255, 255, 0.72); backdrop-filter: blur(18px); border-radius: 26px; padding: 28px; box-shadow: 0 12px 32px rgba(80, 60, 160, 0.12);}
.primary-btn button {background: linear-gradient(120deg, #7c62ff, #ffa74d); border: none; color: white; font-weight: 600; border-radius: 999px; padding: 12px 22px; box-shadow: 0 12px 24px rgba(124, 98, 255, 0.25);}
.primary-btn button:hover {opacity: 0.95; transform: translateY(-1px);}
.output-card {background: rgba(255, 255, 255, 0.88); border-radius: 22px; padding: 24px; box-shadow: inset 0 0 0 1px rgba(124, 98, 255, 0.08), 0 14px 30px rgba(49, 32, 114, 0.12);}
.notice {background: rgba(255, 247, 226, 0.9); border-radius: 18px; padding: 18px; color: #7a4b00; box-shadow: inset 0 0 0 1px rgba(255, 193, 96, 0.3);}
.csv-box textarea {font-family: 'JetBrains Mono', monospace;}
.gr-image {border-radius: 20px !important; box-shadow: 0 10px 20px rgba(60, 40, 120, 0.15);}
.accordion {border-radius: 20px !important;}
"""
|
|
# Hero banner markup injected at the top of the Blocks layout.
# NOTE(review): Korean text appears mojibake-garbled in this copy — confirm
# against the original UTF-8 source.
HERO_HTML = """
<div class="hero">
<h1>MedCard-KR Β· μ½λ΄ν¬ ν μ»·μΌλ‘ μ΄ν΄νλ λ³΅μ© μλ΄</h1>
<p>Qwen2.5-VLμ΄ μ½ λ΄ν¬λ₯Ό μ§μ μ½κ³ , μ½μ¬μ²λΌ μ½κ² μ€λͺκ³Ό ν μ»· λ§νλ₯Ό ν¨κ» μ 곡ν©λλ€.</p>
</div>
"""
|
|
|
|
# Gradio UI: two-column layout — upload/trigger on the left, results right.
# NOTE(review): Korean widget labels appear mojibake-garbled in this copy —
# confirm against the original UTF-8 source.
with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    gr.HTML(HERO_HTML)
    with gr.Row():
        # Left column: image upload, notice area, and the run button.
        with gr.Column(scale=4, elem_classes=["glass-panel"]):
            gr.Markdown("### 1. μ½ λ΄ν¬ μ¬μ§μ μλ‘λνμΈμ")
            img_in = gr.Image(type="pil", label="μ½ λ΄ν¬/λΌλ²¨ μ¬μ§", height=360)
            warn_md = gr.Markdown("π· μ½ λ΄ν¬ μ¬μ§μ μ¬λ¦¬λ©΄ μΈμμ΄ μμλΌμ.", elem_classes=["notice"])
            btn = gr.Button("μΈμ & μ€λͺμμ±", elem_classes=["primary-btn"])
        # Right column: every result widget run_pipeline feeds.
        with gr.Column(scale=6, elem_classes=["glass-panel"]):
            gr.Markdown("### 2. κ²°κ³Όλ₯Ό νμΈνμΈμ")
            explain_md = gr.Markdown("μ¬κΈ°μ μ½ μ€λͺμ΄ νμλ©λλ€.", elem_classes=["output-card"])
            raw_box = gr.Textbox(label="λͺ¨λΈμ΄ μ½μ μλ¬Έ νμ€νΈ", lines=5, interactive=False)
            cartoon_img = gr.Image(type="pil", label="ν μ»· λ§ν")
            card_out = gr.Image(type="pil", label="μΌμ μΉ΄λ(미리보기)")
            csv_box = gr.Textbox(label="CSV(μ½λͺ,1νμ©λ,1μΌνμ,μκ°λ)", lines=2, elem_classes=["csv-box"])
            with gr.Accordion("μΈλΆ JSON κ²°κ³Ό", open=False, elem_classes=["accordion"]):
                json_out = gr.Code(label="λͺ¨λΈ λΆμ(JSON)")

    # Wire the button to the pipeline; order matches run_pipeline's tuple.
    btn.click(
        run_pipeline,
        inputs=img_in,
        outputs=[json_out, card_out, csv_box, explain_md, warn_md, raw_box, cartoon_img],
    )

    gr.Markdown(
        "> βΉοΈ **μ£Όμ**: μ΄ μλΉμ€λ μ°Έκ³ μ© λꡬμ΄λ©°, μ€μ 볡μ½μ λ°λμ μμ¬Β·μ½μ¬μ μ§μμ λ°λΌ μ£ΌμΈμ."
    )
|
|
|
|
# Enable request queuing (required by @spaces.GPU handlers) before launching.
if __name__ == "__main__":
    demo.queue().launch()
|
|