"""Gradio app that extracts text from Korean medication envelopes and prescriptions.

The UI wired up at the bottom of this module runs OCR only (``run_analysis``).
The LLM-based medication-name stage (``extract_medications_from_text`` and
``extract_medication_names``) is kept as a library path; it needs
``LLM_TOKENIZER`` and ``LLM_MODEL`` to be assigned before it can run.
"""

import json
import re
from typing import List, Optional, Tuple

import gradio as gr
import numpy as np
import spaces
from PIL import Image
from paddleocr import PaddleOCR

# Initialise PaddleOCR (Korean) once at import time.
print("🔄 Loading PaddleOCR (Korean)...")
OCR_MODEL = PaddleOCR(use_angle_cls=True, lang='korean', use_gpu=True)
print("✅ PaddleOCR loaded!")

# User-facing status strings shared by several functions below.
NO_TEXT_MESSAGE = "텍스트를 찾을 수 없습니다."
NO_MEDICATION_MESSAGE = "약 이름을 찾지 못했습니다."

# Prefixes that mark the first entry of a medication list as a status message
# rather than a real medication name (see format_results).
_STATUS_PREFIXES = ("오류", "약 이름을 찾지", "텍스트를")

# Nothing in this module loads the LLM. These stay None until external code
# assigns a tokenizer / causal LM; extract_medications_from_text reports a
# clear error instead of a NameError while they are unset.
LLM_TOKENIZER = None
LLM_MODEL = None


def _extract_assistant_content(decoded: str) -> str:
    """Return the assistant part of a decoded chat transcript.

    If the ``<|im_start|>assistant`` marker is present, everything after its
    last occurrence is returned with ``<|im_end|>`` removed; otherwise the
    whole input is returned. The result is stripped in both cases.
    """
    if "<|im_start|>assistant" in decoded:
        content = decoded.split("<|im_start|>assistant")[-1]
        return content.replace("<|im_end|>", "").strip()
    return decoded.strip()


def _extract_json_block(text: str) -> Optional[str]:
    """Return the span from the first ``{`` to the last ``}``, or None.

    The match is greedy on purpose so nested objects are not cut short.
    """
    match = re.search(r"\{.*\}", text, re.DOTALL)
    if not match:
        return None
    return match.group(0)


def extract_text_from_image(image: Image.Image) -> str:
    """Run PaddleOCR on *image* and return the recognised lines joined by newlines.

    Returns ``NO_TEXT_MESSAGE`` when the OCR engine reports no text.

    Raises:
        Exception: any failure during conversion or OCR, re-raised with an
            ``OCR 오류:`` prefix and the original exception chained as cause.
    """
    try:
        # Normalise to 3-channel RGB so RGBA / palette / greyscale uploads
        # reach the OCR engine in a consistent layout.
        rgb_image = image if image.mode == "RGB" else image.convert("RGB")
        img_array = np.array(rgb_image)

        result = OCR_MODEL.ocr(img_array, cls=True)

        # result[0] holds the detections for the single input image; each
        # entry is indexed as line[1][0] for the recognised text.
        if not (result and result[0]):
            return NO_TEXT_MESSAGE
        texts = [line[1][0] for line in result[0]]
        return "\n".join(texts).strip()
    except Exception as e:
        raise Exception(f"OCR 오류: {str(e)}") from e


def extract_medications_from_text(text: str) -> List[str]:
    """Stage 2: ask the chat LLM to pull medication names out of OCR text.

    Returns the cleaned, non-empty names, or ``[NO_MEDICATION_MESSAGE]`` when
    the model output contains no usable ``medications`` list.

    Raises:
        Exception: any failure (including an unloaded model or invalid JSON),
            re-raised with an ``LLM 분석 오류:`` prefix and the cause chained.
    """
    try:
        if LLM_TOKENIZER is None or LLM_MODEL is None:
            raise RuntimeError("LLM 모델이 로드되지 않았습니다.")

        # Imported here so the OCR-only UI does not need torch installed.
        import torch

        messages = [
            {
                "role": "system",
                "content": (
                    "You are a medical text analyzer. Extract only medication "
                    "names from the given text and return them as a JSON "
                    "array. Return ONLY valid JSON format."
                ),
            },
            {
                "role": "user",
                "content": (
                    f"Extract all medication names from this text:\n\n{text}\n\n"
                    'Return format: {"medications": ["name1", "name2"]}'
                ),
            },
        ]

        prompt = LLM_TOKENIZER.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        inputs = LLM_TOKENIZER(prompt, return_tensors="pt").to(LLM_MODEL.device)

        with torch.no_grad():
            outputs = LLM_MODEL.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.3,
                top_p=0.9,
                do_sample=True,
                pad_token_id=LLM_TOKENIZER.eos_token_id,
            )

        # Decode only the newly generated tokens. The full sequence also
        # contains the prompt, whose "Return format" example would otherwise
        # be the first JSON object found and be returned as the answer.
        prompt_length = inputs["input_ids"].shape[1]
        response = LLM_TOKENIZER.decode(
            outputs[0][prompt_length:],
            skip_special_tokens=True,
        )
        response = _extract_assistant_content(response)

        json_block = _extract_json_block(response)
        if json_block:
            data = json.loads(json_block)
            medications = data.get("medications", [])
            if isinstance(medications, list):
                names = [str(m).strip() for m in medications]
                names = [name for name in names if name]
                if names:
                    return names

        return [NO_MEDICATION_MESSAGE]
    except Exception as e:
        raise Exception(f"LLM 분석 오류: {str(e)}") from e


@spaces.GPU(duration=120)
def extract_medication_names(image: Image.Image) -> Tuple[str, List[str]]:
    """Two-stage pipeline: OCR, then LLM analysis.

    Returns ``(extracted_text, medications)``. On failure, or when OCR finds
    no text, the text is ``""`` and the list holds a single status message.
    This function never raises.
    """
    try:
        # Stage 1: OCR.
        extracted_text = extract_text_from_image(image)

        # The OCR step signals "no text" with a placeholder string; do not
        # send that placeholder to the LLM as if it were document text.
        if not extracted_text or extracted_text == NO_TEXT_MESSAGE:
            return "", ["텍스트를 추출하지 못했습니다."]

        # Stage 2: medication-name extraction.
        medications = extract_medications_from_text(extracted_text)
        return extracted_text, medications
    except Exception as e:
        return "", [f"오류 발생: {str(e)}"]


def format_results(extracted_text: str, medications: List[str]) -> Tuple[str, str]:
    """Render the pipeline output as two Markdown strings.

    Returns ``(text_markdown, medication_markdown)``. The second is a warning
    heading when the list is empty or its first entry starts with one of the
    status prefixes, otherwise a numbered list of names.
    """
    text_output = f"### 📄 추출된 텍스트\n\n```\n{extracted_text}\n```"

    if not medications:
        return text_output, f"### ⚠️ {NO_MEDICATION_MESSAGE}"
    if medications[0].startswith(_STATUS_PREFIXES):
        return text_output, f"### ⚠️ {medications[0]}"

    header = f"### 💊 검출된 약물 ({len(medications)}개)\n\n"
    items = "".join(
        f"{idx}. **{med_name}**\n" for idx, med_name in enumerate(medications, 1)
    )
    return text_output, header + items


def run_analysis(image: Optional[Image.Image], progress=gr.Progress()):
    """UI callback: run OCR only and return the result as Markdown.

    Returns an upload prompt when *image* is None, and an error section
    instead of raising when OCR fails.
    """
    if image is None:
        return "📷 약 봉투나 처방전 사진을 업로드해주세요."

    progress(0.5, desc="📸 OCR 텍스트 추출 중...")

    try:
        extracted_text = extract_text_from_image(image)
        progress(1.0, desc="✅ 완료!")
        return f"### 📄 OCR 추출 결과\n\n```\n{extracted_text}\n```"
    except Exception as e:
        return f"### ⚠️ 오류 발생\n\n{str(e)}"


# Simple CSS for the page.
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');

:root {
    --primary: #6366f1;
    --secondary: #8b5cf6;
}

body {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
}

.gradio-container {
    max-width: 900px !important;
    margin: auto;
    background: rgba(255, 255, 255, 0.98);
    border-radius: 24px;
    box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.3);
    padding: 40px;
}

.hero {
    text-align: center;
    padding: 30px 20px;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    border-radius: 20px;
    color: white;
    margin-bottom: 30px;
}

.hero h1 {
    font-size: 2.5rem;
    font-weight: 700;
    margin-bottom: 10px;
}

.hero p {
    font-size: 1.1rem;
    opacity: 0.95;
}

.upload-section {
    background: white;
    border-radius: 16px;
    padding: 30px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.07);
    margin-bottom: 20px;
}

.result-section {
    background: white;
    border-radius: 16px;
    padding: 30px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.07);
    min-height: 200px;
}

.analyze-btn button {
    background: linear-gradient(135deg, var(--primary), var(--secondary)) !important;
    color: white !important;
    font-weight: 600 !important;
    font-size: 1.1rem !important;
    padding: 18px 40px !important;
    border-radius: 12px !important;
    border: none !important;
    box-shadow: 0 10px 20px -5px rgba(99, 102, 241, 0.5) !important;
    transition: all 0.3s ease !important;
}

.analyze-btn button:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 15px 30px -5px rgba(99, 102, 241, 0.6) !important;
}

.gr-image {
    border-radius: 12px !important;
}
"""

# The wrapper, heading and paragraph are what the `.hero`, `.hero h1` and
# `.hero p` rules in CUSTOM_CSS select; without them that styling is unused.
# NOTE(review): markup reconstructed from the CSS selectors - confirm it
# matches the intended design.
HERO_HTML = """
<div class="hero">
    <h1>💊 약 이름 추출기</h1>
    <p>약봉투/처방전 사진에서 약 이름을 자동으로 추출합니다</p>
</div>
"""

# Gradio interface.
with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
    gr.HTML(HERO_HTML)

    with gr.Column(elem_classes=["upload-section"]):
        gr.Markdown("### 📸 사진 업로드")
        image_input = gr.Image(type="pil", label="약봉투 또는 처방전 사진", height=350)
        analyze_button = gr.Button(
            "🔍 OCR 텍스트 추출",
            elem_classes=["analyze-btn"],
            size="lg",
        )

    with gr.Column(elem_classes=["result-section"]):
        gr.Markdown("### 📋 OCR 추출 결과")
        text_output = gr.Markdown("OCR로 추출된 전체 텍스트가 여기 표시됩니다.")

    analyze_button.click(
        run_analysis,
        inputs=image_input,
        outputs=text_output,
    )

    # Blank lines keep the rule, the heading and the list as separate
    # Markdown blocks.
    gr.Markdown(
        "---\n"
        "\n"
        "**ℹ️ OCR 모델**\n"
        "\n"
        "- PaddleOCR (Korean) - 한국어 텍스트 인식에 최적화된 OCR 엔진\n"
    )


if __name__ == "__main__":
    demo.queue().launch()