|
|
import json |
|
|
import re |
|
|
from typing import List, Optional, Tuple |
|
|
import numpy as np |
|
|
|
|
|
import gradio as gr |
|
|
import spaces |
|
|
from PIL import Image |
|
|
from paddleocr import PaddleOCR |
|
|
|
|
|
|
|
|
# Load the Korean PaddleOCR pipeline once at import time so every request
# reuses the same model instance.
# The status messages had been mis-encoded (UTF-8 read as a Thai code page),
# which also split the second literal across two lines; they are restored here.
# NOTE(review): the pictograph in the first message was unrecoverable from the
# damaged source and was chosen from context -- confirm against the original.
print("🔄 Loading PaddleOCR (Korean)...")
# use_angle_cls=True pairs with the cls=True argument that
# extract_text_from_image passes on every OCR call.
OCR_MODEL = PaddleOCR(use_angle_cls=True, lang='korean', use_gpu=True)
print("✅ PaddleOCR loaded!")
|
|
|
|
|
|
|
|
def _extract_assistant_content(decoded: str) -> str: |
|
|
"""์ด์์คํดํธ ์๋ต ์ถ์ถ""" |
|
|
if "<|im_start|>assistant" in decoded: |
|
|
content = decoded.split("<|im_start|>assistant")[-1] |
|
|
content = content.replace("<|im_end|>", "").strip() |
|
|
return content |
|
|
return decoded.strip() |
|
|
|
|
|
|
|
|
def _extract_json_block(text: str) -> Optional[str]: |
|
|
"""JSON ๋ธ๋ก ์ถ์ถ""" |
|
|
match = re.search(r"\{.*\}", text, re.DOTALL) |
|
|
if not match: |
|
|
return None |
|
|
return match.group(0) |
|
|
|
|
|
|
|
|
def extract_text_from_image(image: Image.Image) -> str:
    """Extract the text of *image* with PaddleOCR.

    Args:
        image: PIL image to read.

    Returns:
        The recognised lines joined with newlines and stripped, or the Korean
        message "텍스트를 찾을 수 없습니다." ("no text found") when OCR yields
        nothing.

    Raises:
        RuntimeError: any failure during conversion or OCR, prefixed with
            "OCR 오류" and chained to the original exception.
    """
    try:
        # Palette, RGBA, CMYK or greyscale uploads are normalised to three
        # channels so the OCR model always receives an H x W x 3 array.
        rgb_image = image if image.mode == "RGB" else image.convert("RGB")
        img_array = np.array(rgb_image)

        result = OCR_MODEL.ocr(img_array, cls=True)

        # result[0] holds the detections for the single input image; it is
        # empty or None when nothing was recognised.
        if not result or not result[0]:
            return "텍스트를 찾을 수 없습니다."

        # Presumably each detection is [box, (text, confidence)]; only the
        # text component is used here.
        texts = [line[1][0] for line in result[0]]
        return "\n".join(texts).strip()
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(f"OCR 오류: {e}") from e
|
|
|
|
|
|
|
|
def extract_medications_from_text(text: str) -> List[str]:
    """Stage 2: ask the chat LLM to extract medication names from OCR text.

    Args:
        text: Raw text produced by the OCR stage.

    Returns:
        The non-blank medication names reported by the model, or the
        single-element list ["약 이름을 찾지 못했습니다."] ("no medication name
        found") when the reply contains no usable names.

    Raises:
        RuntimeError: any failure while prompting, generating or parsing,
            prefixed with "LLM 분석 오류" and chained to the original error.
    """
    try:
        # Imported lazily: the module does not import torch at the top, and
        # this stage is not reachable from the OCR-only UI.
        # NOTE(review): LLM_TOKENIZER and LLM_MODEL are not defined anywhere
        # in this file; they must be provided before this stage is enabled.
        import torch

        messages = [
            {
                "role": "system",
                "content": "You are a medical text analyzer. Extract only medication names from the given text and return them as a JSON array. Return ONLY valid JSON format."
            },
            {
                "role": "user",
                "content": f"Extract all medication names from this text:\n\n{text}\n\nReturn format: {{\"medications\": [\"name1\", \"name2\"]}}"
            }
        ]

        prompt = LLM_TOKENIZER.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True
        )

        inputs = LLM_TOKENIZER(prompt, return_tensors="pt").to(LLM_MODEL.device)

        with torch.no_grad():
            outputs = LLM_MODEL.generate(
                **inputs,
                max_new_tokens=512,
                temperature=0.3,
                top_p=0.9,
                do_sample=True,
                pad_token_id=LLM_TOKENIZER.eos_token_id,
            )

        # The generated sequence starts with the prompt tokens. Decoding all
        # of it with skip_special_tokens=True removes the chat markers, so the
        # prompt's own example JSON ({"medications": ["name1", "name2"]})
        # would be the first object found. Decode only the new tokens.
        prompt_length = inputs["input_ids"].shape[-1]
        response = LLM_TOKENIZER.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )
        # Harmless when no markers remain; keeps marker handling in one place.
        response = _extract_assistant_content(response)

        json_block = _extract_json_block(response)
        if json_block:
            data = json.loads(json_block)
            medications = data.get("medications", []) if isinstance(data, dict) else []
            if isinstance(medications, list):
                names = [str(m).strip() for m in medications if str(m).strip()]
                if names:
                    return names

        return ["약 이름을 찾지 못했습니다."]
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise RuntimeError(f"LLM 분석 오류: {e}") from e
|
|
|
|
|
|
|
|
@spaces.GPU(duration=120)
def extract_medication_names(image: Image.Image) -> Tuple[str, List[str]]:
    """Two-stage pipeline: OCR the image, then let the LLM pick out drug names.

    Args:
        image: PIL image of a medicine bag or prescription.

    Returns:
        ``(extracted_text, medications)``. On failure the text is ``""`` and
        the list holds a single Korean message: "텍스트를 추출하지 못했습니다."
        when OCR produced no text, or "오류 발생: <reason>" when any exception
        was raised by either stage. Exceptions are not propagated.
    """
    try:
        extracted_text = extract_text_from_image(image)

        # NOTE(review): extract_text_from_image returns a non-empty
        # "not found" message when OCR detects nothing, so this guard only
        # fires when the recognised text is blank after stripping.
        if not extracted_text:
            return "", ["텍스트를 추출하지 못했습니다."]

        medications = extract_medications_from_text(extracted_text)
        return extracted_text, medications
    except Exception as e:
        # Report the failure through the return value instead of raising.
        return "", [f"오류 발생: {e}"]
|
|
|
|
|
|
|
|
def format_results(extracted_text: str, medications: List[str]) -> Tuple[str, str]:
    """Format the pipeline output as two Markdown strings.

    Args:
        extracted_text: OCR text, shown verbatim inside a fenced block.
        medications: Medication names, or a single status/error message
            produced by the extraction stages.

    Returns:
        ``(text_markdown, medication_markdown)``. The second element is a
        warning line when *medications* is empty or its first entry is one of
        the known status messages; otherwise it is a numbered list with one
        bold entry per medication.
    """
    # NOTE(review): the 📝 and 💊 pictographs were unrecoverable from the
    # mis-encoded source and were chosen from context -- confirm.
    text_output = f"### 📝 추출된 텍스트\n\n```\n{extracted_text}\n```"

    if not medications:
        return text_output, "### ⚠️ 약 이름을 찾지 못했습니다."

    # Prefixes of the status messages the extraction stages put in slot 0:
    # "error ...", "no medication name found ...", "text could not be ...".
    status_prefixes = ("오류", "약 이름을 찾지", "텍스트를")
    if medications[0].startswith(status_prefixes):
        return text_output, f"### ⚠️ {medications[0]}"

    header = f"### 💊 검출된 약물 ({len(medications)}개)\n\n"
    items = "".join(
        f"{idx}. **{med_name}**\n" for idx, med_name in enumerate(medications, 1)
    )
    return text_output, header + items
|
|
|
|
|
|
|
|
def run_analysis(image: Optional[Image.Image], progress=gr.Progress()):
    """Main analysis entry point for the UI: runs the OCR stage only.

    Args:
        image: Uploaded PIL image, or ``None`` when nothing was uploaded.
        progress: Gradio progress tracker. The ``gr.Progress()`` default is
            kept as-is so Gradio can supply the tracker for the running event.

    Returns:
        A Markdown string: an upload prompt when *image* is ``None``, the OCR
        text inside a fenced block on success, or an error section containing
        the exception message on failure.
    """
    if image is None:
        return "📷 약 봉투나 처방전 사진을 업로드해주세요."

    progress(0.5, desc="📸 OCR 텍스트 추출 중...")

    try:
        extracted_text = extract_text_from_image(image)
    except Exception as e:
        # Surface the failure in the result pane instead of raising.
        return f"### ⚠️ 오류 발생\n\n{e}"

    progress(1.0, desc="✅ 완료!")
    # NOTE(review): the heading pictograph was unrecoverable from the
    # mis-encoded source and was chosen from context -- confirm.
    return f"### 📝 OCR 추출 결과\n\n```\n{extracted_text}\n```"
|
|
|
|
|
|
|
|
|
|
|
# Stylesheet passed to gr.Blocks(css=...). It defines the page background, the
# centred card container, the hero banner, the upload/result panels (matched
# via elem_classes in the layout below) and the gradient "analyze" button.
CUSTOM_CSS = """
@import url('https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600;700&display=swap');

:root {
    --primary: #6366f1;
    --secondary: #8b5cf6;
}

body {
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif;
}

.gradio-container {
    max-width: 900px !important;
    margin: auto;
    background: rgba(255, 255, 255, 0.98);
    border-radius: 24px;
    box-shadow: 0 25px 50px -12px rgba(0, 0, 0, 0.3);
    padding: 40px;
}

.hero {
    text-align: center;
    padding: 30px 20px;
    background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
    border-radius: 20px;
    color: white;
    margin-bottom: 30px;
}

.hero h1 {
    font-size: 2.5rem;
    font-weight: 700;
    margin-bottom: 10px;
}

.hero p {
    font-size: 1.1rem;
    opacity: 0.95;
}

.upload-section {
    background: white;
    border-radius: 16px;
    padding: 30px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.07);
    margin-bottom: 20px;
}

.result-section {
    background: white;
    border-radius: 16px;
    padding: 30px;
    box-shadow: 0 4px 6px rgba(0, 0, 0, 0.07);
    min-height: 200px;
}

.analyze-btn button {
    background: linear-gradient(135deg, var(--primary), var(--secondary)) !important;
    color: white !important;
    font-weight: 600 !important;
    font-size: 1.1rem !important;
    padding: 18px 40px !important;
    border-radius: 12px !important;
    border: none !important;
    box-shadow: 0 10px 20px -5px rgba(99, 102, 241, 0.5) !important;
    transition: all 0.3s ease !important;
}

.analyze-btn button:hover {
    transform: translateY(-2px) !important;
    box-shadow: 0 15px 30px -5px rgba(99, 102, 241, 0.6) !important;
}

.gr-image {
    border-radius: 12px !important;
}
"""
|
|
|
|
|
# Hero banner rendered at the top of the page via gr.HTML; styled by the
# .hero rules in CUSTOM_CSS. The Korean heading ("medication name extractor")
# and tagline were restored from mis-encoded text.
# NOTE(review): the heading pictograph was unrecoverable from the damaged
# source and was chosen from context -- confirm against the original.
HERO_HTML = """
<div class="hero">
    <h1>💊 약 이름 추출기</h1>
    <p>약봉투/처방전 사진에서 약 이름을 자동으로 추출합니다</p>
</div>
"""
|
|
|
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo: |
|
|
gr.HTML(HERO_HTML) |
|
|
|
|
|
with gr.Column(elem_classes=["upload-section"]): |
|
|
gr.Markdown("### ๐ธ ์ฌ์ง ์
๋ก๋") |
|
|
image_input = gr.Image(type="pil", label="์ฝ๋ดํฌ ๋๋ ์ฒ๋ฐฉ์ ์ฌ์ง", height=350) |
|
|
analyze_button = gr.Button("๐ OCR ํ
์คํธ ์ถ์ถ", elem_classes=["analyze-btn"], size="lg") |
|
|
|
|
|
with gr.Column(elem_classes=["result-section"]): |
|
|
gr.Markdown("### ๐ OCR ์ถ์ถ ๊ฒฐ๊ณผ") |
|
|
text_output = gr.Markdown("OCR๋ก ์ถ์ถ๋ ์ ์ฒด ํ
์คํธ๊ฐ ์ฌ๊ธฐ ํ์๋ฉ๋๋ค.") |
|
|
|
|
|
analyze_button.click( |
|
|
run_analysis, |
|
|
inputs=image_input, |
|
|
outputs=text_output, |
|
|
) |
|
|
|
|
|
gr.Markdown(""" |
|
|
--- |
|
|
|
|
|
**โน๏ธ OCR ๋ชจ๋ธ** |
|
|
- PaddleOCR (Korean) - ํ๊ตญ์ด ํ
์คํธ ์ธ์์ ์ต์ ํ๋ OCR ์์ง |
|
|
""") |
|
|
|
|
|
# Script entry point: enable Gradio's request queue, then start the server.
if __name__ == "__main__":
    demo.queue().launch()
|
|
|