import json import re from typing import List, Optional, Tuple, Union import numpy as np import os import gradio as gr import spaces import torch from PIL import Image from transformers import AutoTokenizer, AutoModelForCausalLM from huggingface_hub import login, snapshot_download from paddleocr import PaddleOCR # Hugging Face 토큰으로 로그인 (Spaces Secret에서 가져옴) HF_TOKEN = os.getenv("HF_TOKEN") if HF_TOKEN: login(token=HF_TOKEN.strip()) # 약 정보 분석 모델 ID (빠른 추론을 위해 경량 모델 사용) MED_MODEL_ID = "google/gemma-2-2b-it" # 전역 모델 변수 (한 번만 로드) OCR_READER = None MED_MODEL = None MED_TOKENIZER = None OCR_MODEL_REPO_ID = "PaddlePaddle/korean_PP-OCRv5_mobile_rec" def _collect_ocr_texts(ocr_payload) -> List[str]: """PaddleOCR 결과 구조에서 텍스트만 추출""" texts: List[str] = [] seen = set() def add_text(candidate: str): if not isinstance(candidate, str): return normalized = candidate.strip() if normalized and normalized not in seen: seen.add(normalized) texts.append(normalized) def walk(node): if isinstance(node, str): add_text(node) return if isinstance(node, dict): for key in ("text", "label", "transcription"): add_text(node.get(key)) for key in ("texts", "labels"): values = node.get(key) if isinstance(values, (list, tuple)): for value in values: add_text(value) for key in ("text_recognition", "rec_results", "data", "results"): if key in node: walk(node[key]) return if isinstance(node, (list, tuple)): if len(node) >= 2: second = node[1] if isinstance(second, str): add_text(second) elif isinstance(second, (list, tuple)) and second: maybe_text = second[0] add_text(maybe_text) for item in node: walk(item) walk(ocr_payload) return texts def load_models(): """모델들을 한 번만 로드""" global OCR_READER, MED_MODEL, MED_TOKENIZER if OCR_READER is None: print("🔄 Loading PaddleOCR (Korean PP-OCRv5 mobile recognition)...") rec_model_dir = snapshot_download( OCR_MODEL_REPO_ID, allow_patterns=[ "*.pdmodel", "*.pdiparams", "*.pdparams", "*.json", "*.yml", ], ) OCR_READER = PaddleOCR( lang='korean', use_textline_orientation=True, text_recognition_model_dir=rec_model_dir, text_recognition_model_name="korean_PP-OCRv5_mobile_rec", ) print("✅ PaddleOCR loaded!") if MED_MODEL is None: print("🔄 Loading Gemma-2-2B for medical analysis (8bit quantization)...") MED_MODEL = AutoModelForCausalLM.from_pretrained( MED_MODEL_ID, torch_dtype=torch.bfloat16, device_map="auto", load_in_8bit=True ) MED_TOKENIZER = AutoTokenizer.from_pretrained(MED_MODEL_ID) print("✅ Medical model loaded!") # 앱 시작 시 모델 로드 load_models() def _extract_assistant_content(decoded: str) -> str: """어시스턴트 응답 추출""" if "<|im_start|>assistant" in decoded: content = decoded.split("<|im_start|>assistant")[-1] content = content.replace("<|im_end|>", "").strip() return content return decoded.strip() def _extract_json_block(text: str) -> Optional[str]: """JSON 블록 추출""" match = re.search(r"\{.*\}", text, re.DOTALL) if not match: return None return match.group(0) @spaces.GPU(duration=120) def analyze_medication_image(image: Image.Image) -> Tuple[str, str]: """이미지에서 OCR 추출 후 약 정보 분석""" import time try: # Step 1: OCR - PaddleOCR로 한글 텍스트 추출 start_time = time.time() img_array = np.array(image) try: ocr_results = OCR_READER.predict(img_array) except (TypeError, AttributeError): ocr_results = OCR_READER.ocr(img_array) ocr_time = time.time() - start_time print(f"⏱️ OCR took {ocr_time:.2f}s") if not ocr_results: return "텍스트를 찾을 수 없습니다.", "" # 텍스트 추출 texts = _collect_ocr_texts(ocr_results) if not texts: return "텍스트를 찾을 수 없습니다.", "" ocr_text = "\n".join(texts) # Step 2: 약 정보 분석 - MedGemma로 의료 정보 제공 analysis_start = time.time() analysis_prompt = f"""다음은 약 봉투나 처방전에서 추출한 텍스트입니다: {ocr_text} 위 텍스트에서 약 이름을 찾아서, 각 약에 대해 **노인과 어린이 모두 쉽게 이해할 수 있도록** 재미있고 친근하게 설명해주세요: 📋 **각 약마다 다음 정보를 포함해주세요:** 1. 💊 **약 이름**: 정확한 약 이름 2. 🎯 **효능**: 이 약이 무엇을 치료하고 어떻게 도움이 되는지 3. ⚠️ **부작용**: 주의해야 할 부작용들 4. 💡 **복용 방법**: 언제, 어떻게 먹어야 하는지 (식전/식후, 하루 몇 번 등) 5. 🚫 **주의사항**: 이 약과 함께 먹으면 안 되는 것들 (음식, 다른 약 등) **스타일 가이드:** - 이모지를 적극 활용하여 재미있게 작성 - 할머니 할아버지나 초등학생도 이해할 수 있는 쉬운 단어 사용 - 각 약마다 구분선으로 구분 - 친근하고 따뜻한 말투 사용 - 마크다운 형식으로 작성 시작해주세요!""" messages = [ {"role": "user", "content": analysis_prompt} ] input_text = MED_TOKENIZER.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) inputs = MED_TOKENIZER(input_text, return_tensors="pt").to(MED_MODEL.device) with torch.no_grad(): outputs = MED_MODEL.generate( **inputs, max_new_tokens=768, temperature=0.7, top_p=0.9, do_sample=True ) analysis_text = MED_TOKENIZER.decode(outputs[0][inputs.input_ids.shape[1]:], skip_special_tokens=True) analysis_time = time.time() - analysis_start total_time = time.time() - start_time print(f"⏱️ Medical analysis took {analysis_time:.2f}s") print(f"⏱️ Total processing time: {total_time:.2f}s") return ocr_text.strip(), analysis_text.strip() except Exception as e: raise Exception(f"분석 오류: {str(e)}") def extract_medications_from_text(text: str) -> List[str]: """Stage 2: Qwen2.5로 텍스트에서 약 이름만 추출""" try: messages = [ { "role": "system", "content": "You are a medical text analyzer. Extract only medication names from the given text and return them as a JSON array. Return ONLY valid JSON format." }, { "role": "user", "content": f"Extract all medication names from this text:\n\n{text}\n\nReturn format: {{\"medications\": [\"name1\", \"name2\"]}}" } ] prompt = LLM_TOKENIZER.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) inputs = LLM_TOKENIZER(prompt, return_tensors="pt").to(LLM_MODEL.device) with torch.no_grad(): outputs = LLM_MODEL.generate( **inputs, max_new_tokens=512, temperature=0.3, top_p=0.9, do_sample=True, pad_token_id=LLM_TOKENIZER.eos_token_id, ) response = LLM_TOKENIZER.decode(outputs[0], skip_special_tokens=True) # Extract assistant response (Qwen format) if "<|im_start|>assistant" in response: response = response.split("<|im_start|>assistant")[-1] response = response.replace("<|im_end|>", "").strip() # Parse JSON json_match = re.search(r'\{.*?\}', response, re.DOTALL) if json_match: data = json.loads(json_match.group(0)) medications = data.get("medications", []) if isinstance(medications, list) and medications: return [str(m).strip() for m in medications if str(m).strip()] return ["약 이름을 찾지 못했습니다."] except Exception as e: raise Exception(f"LLM 분석 오류: {str(e)}") @spaces.GPU(duration=120) def extract_medication_names(image: Image.Image) -> Tuple[str, List[str]]: """2단계 파이프라인: OCR → LLM 분석""" try: # Stage 1: OCR로 텍스트 추출 extracted_text = extract_text_from_image(image) if not extracted_text: return "", ["텍스트를 추출하지 못했습니다."] # Stage 2: LLM으로 약 이름 추출 medications = extract_medications_from_text(extracted_text) return extracted_text, medications except Exception as e: return "", [f"오류 발생: {str(e)}"] def format_results(extracted_text: str, medications: List[str]) -> Tuple[str, str]: """결과를 포맷팅""" # 추출된 전체 텍스트 text_output = f"### 📄 추출된 텍스트\n\n```\n{extracted_text}\n```" # 약 이름 리스트 if not medications or medications[0].startswith("오류") or medications[0].startswith("약 이름을 찾지") or medications[0].startswith("텍스트를"): med_output = f"### ⚠️ {medications[0] if medications else '약 이름을 찾지 못했습니다.'}" else: med_output = f"### 💊 검출된 약물 ({len(medications)}개)\n\n" for idx, med_name in enumerate(medications, 1): med_output += f"{idx}. **{med_name}**\n" return text_output, med_output def _ensure_pil(image_input: Optional[Union[Image.Image, np.ndarray, str]]) -> Optional[Image.Image]: """Gradio 입력을 PIL 이미지로 변환""" if image_input is None: return None if isinstance(image_input, Image.Image): return image_input if isinstance(image_input, np.ndarray): if image_input.dtype != np.uint8: image_input = np.clip(image_input, 0, 255).astype(np.uint8) return Image.fromarray(image_input).convert("RGB") if isinstance(image_input, str): if not os.path.exists(image_input): return None with Image.open(image_input) as img: return img.convert("RGB") return None def run_analysis(image: Optional[Union[Image.Image, np.ndarray, str]], progress=gr.Progress()): """메인 분석 파이프라인: OCR + 약 정보 분석""" pil_image = _ensure_pil(image) if pil_image is None: return "📷 약 봉투나 처방전 사진을 업로드해주세요.", "" progress(0.3, desc="📸 1단계: OCR 텍스트 추출 중...") progress(0.6, desc="🤖 2단계: 약 정보 분석 중...") try: ocr_text, analysis = analyze_medication_image(pil_image) progress(1.0, desc="✅ 완료!") ocr_output = f"### 📄 추출된 텍스트\n\n```\n{ocr_text}\n```" analysis_output = f"### 💊 약 정보 설명\n\n{analysis}" return ocr_output, analysis_output except Exception as e: return f"### ⚠️ 오류 발생\n\n{str(e)}", "" # 심플한 CSS CUSTOM_CSS = """ @import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap'); :root { --primary: #6366f1; --secondary: #8b5cf6; } body { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif; } .gradio-container { max-width: 900px !important; margin: auto; background: rgba(255, 255, 255, 0.98); border-radius: 24px; box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.3); padding: 40px; } .hero { text-align: center; padding: 30px 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 20px; color: white; margin-bottom: 30px; } .hero h1 { font-size: 2.5rem; font-weight: 700; margin-bottom: 10px; } .hero p { font-size: 1.1rem; opacity: 0.95; } .upload-section { background: white; border-radius: 16px; padding: 30px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.07); margin-bottom: 20px; } .result-section { background: white; border-radius: 16px; padding: 30px; box-shadow: 0 4px 6px rgba(0, 0, 0, 0.07); min-height: 200px; } .analyze-btn button { background: linear-gradient(135deg, var(--primary), var(--secondary)) !important; color: white !important; font-weight: 600 !important; font-size: 1.1rem !important; padding: 18px 40px !important; border-radius: 12px !important; border: none !important; box-shadow: 0 10px 20px -5px rgba(99, 102, 241, 0.5) !important; transition: all 0.3s ease !important; } .analyze-btn button:hover { transform: translateY(-2px) !important; box-shadow: 0 15px 30px -5px rgba(99, 102, 241, 0.6) !important; } .gr-image { border-radius: 12px !important; } """ HERO_HTML = """

💊 우리 가족 약 도우미

약봉투/처방전 사진에서 약 정보를 쉽고 재미있게 알려드려요!

""" # Gradio 인터페이스 with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo: gr.HTML(HERO_HTML) with gr.Column(elem_classes=["upload-section"]): gr.Markdown("### 📸 사진 업로드") image_input = gr.Image(type="numpy", image_mode="RGB", label="약봉투 또는 처방전 사진", height=350) analyze_button = gr.Button("🔍 약 정보 분석하기", elem_classes=["analyze-btn"], size="lg") with gr.Row(): with gr.Column(elem_classes=["result-section"]): gr.Markdown("### 📋 1단계: 추출된 텍스트") ocr_output = gr.Markdown("OCR로 추출된 텍스트가 여기 표시됩니다.") with gr.Column(elem_classes=["result-section"]): gr.Markdown("### 📋 2단계: 쉬운 약 설명") analysis_output = gr.Markdown("노인과 어린이도 이해하기 쉬운 약 정보가 여기 표시됩니다.") analyze_button.click( run_analysis, inputs=image_input, outputs=[ocr_output, analysis_output], ) gr.Markdown(""" --- **ℹ️ 사용 방법** 1. 약 봉투나 처방전 사진을 업로드하세요 2. '약 정보 분석하기' 버튼을 클릭하세요 3. 왼쪽에는 추출된 텍스트, 오른쪽에는 쉬운 설명이 나타납니다! **⚠️ 주의사항** - 이 앱은 참고용이며, 실제 복약은 반드시 의사나 약사의 지시를 따르세요 - AI가 생성한 정보이므로 정확하지 않을 수 있습니다 **🤖 기술 스택** - PaddleOCR PP-OCRv5 (한국어 최적화 OCR) - Google Gemma-2-2B-IT (8bit 양자화, 빠른 의료 정보 분석) **🔑 설정 방법** - Hugging Face Spaces의 Settings → Repository secrets에서 `HF_TOKEN` 추가 필요 """) if __name__ == "__main__": demo.queue().launch()