LLDDWW committed on
Commit 6d9d526 · 1 Parent(s): 72114b8

feat: add qwen explanations and refined ui

Files changed (2)
  1. app.py +157 -18
  2. requirements.txt +1 -1
app.py CHANGED
@@ -5,20 +5,34 @@ from typing import Any, Dict, List, Optional, Sequence
 import gradio as gr
 import torch
 from PIL import Image, ImageDraw
-from transformers import pipeline
+from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
 # --- OCR pipeline ---------------------------------------------------------
 # Use a high-capacity OCR model for better accuracy on prescription labels.
-MODEL_ID = "microsoft/trocr-large-printed"
+OCR_MODEL_ID = "microsoft/trocr-large-printed"
+LLM_MODEL_ID = "Qwen/Qwen2.5-0.5B-Instruct"
 
 
 def _load_ocr():
     device = 0 if torch.cuda.is_available() else -1
-    return pipeline("image-to-text", model=MODEL_ID, device=device)
+    return pipeline("image-to-text", model=OCR_MODEL_ID, device=device)
 
 
 ocr = _load_ocr()
 
+
+def _load_llm():
+    device_map = "auto" if torch.cuda.is_available() else None
+    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+    model = AutoModelForCausalLM.from_pretrained(LLM_MODEL_ID, device_map=device_map, torch_dtype=dtype)
+    if device_map is None:
+        model = model.to(torch.device("cpu"))
+    tokenizer = AutoTokenizer.from_pretrained(LLM_MODEL_ID)
+    return model, tokenizer
+
+
+LLM_MODEL, LLM_TOKENIZER = _load_llm()
+
 # Korean keywords describing time slots on prescription labels.
 TIME_KEYWORDS = [
     "아침",
@@ -201,7 +215,7 @@ def _match_knowledge(name: str) -> Optional[Dict[str, Any]]:
     return None
 
 
-def build_explanations(output: Dict[str, Any]) -> str:
+def build_kb_explanations(output: Dict[str, Any]) -> str:
     meds = output["fields"].get("medications") or []
     if not meds:
         return (
@@ -236,30 +250,155 @@ def build_explanations(output: Dict[str, Any]) -> str:
     return "\n".join(lines)
 
 
+def generate_llm_explanations(output: Dict[str, Any]) -> str:
+    meds = output["fields"].get("medications") or []
+    if not meds:
+        return (
+            "약 이름을 제대로 인식하지 못했어요. 사진을 다시 찍거나 약사에게 직접 확인해 주세요."
+        )
+
+    med_lines = []
+    for idx, med in enumerate(meds, 1):
+        name = med.get("name") or "이름 미확인"
+        dose = med.get("dose") or "용량 정보 없음"
+        med_lines.append(f"{idx}. {name} — {dose}")
+
+    context = "\n".join(med_lines)
+    raw_text = output.get("raw_text", "")
+
+    system_prompt = (
+        "당신은 약사 선생님입니다. 어려운 의학 용어를 쓰지 말고, 중학생도 이해할 수 있는 말투로 친절하게 설명하세요."
+    )
+    user_prompt = (
+        "다음은 약봉투 OCR 결과입니다. 약 이름과 용량 정보를 참고해 각 약의 역할을 쉽게 설명하고, 언제 복용하면 좋은지 예시, 주의사항을 bullet로 정리해 주세요.\n"
+        f"약 목록:\n{context}\n\nOCR 원문:\n{raw_text}\n\n출력 형식:\n- 약 이름: ...\n - 한 줄 설명\n - 예시 상황\n - 주의할 점\n마지막에는 의료진 복약 지시를 반드시 따라야 한다는 문장을 덧붙여 주세요."
+    )
+
+    messages = [
+        {"role": "system", "content": system_prompt},
+        {"role": "user", "content": user_prompt},
+    ]
+
+    input_ids = LLM_TOKENIZER.apply_chat_template(
+        messages,
+        add_generation_prompt=True,
+        return_tensors="pt",
+    )
+    input_ids = input_ids.to(LLM_MODEL.device)
+
+    with torch.no_grad():
+        output_ids = LLM_MODEL.generate(
+            input_ids,
+            max_new_tokens=480,
+            temperature=0.7,
+            top_p=0.9,
+            do_sample=True,
+            eos_token_id=LLM_TOKENIZER.eos_token_id,
+        )
+
+    generated_ids = output_ids[0][input_ids.shape[1]:]
+    text = LLM_TOKENIZER.decode(generated_ids, skip_special_tokens=True).strip()
+    return text
+
+
+def build_explanations(output: Dict[str, Any]) -> str:
+    try:
+        llm_text = generate_llm_explanations(output)
+        if llm_text:
+            return llm_text
+    except Exception as err:  # pragma: no cover - safe fallback
+        print(f"[WARN] LLM generation failed: {err}", flush=True)
+    return build_kb_explanations(output)
+
+
+def format_warnings(warnings: List[str]) -> str:
+    if not warnings:
+        return "✅ 인식된 정보가 충분해요. 복약 시간만 잘 지켜 주세요."
+    lines = ["### 확인해 주세요"]
+    for warn in warnings:
+        lines.append(f"- {warn}")
+    lines.append("\n> 의료진의 지시가 가장 정확합니다.")
+    return "\n".join(lines)
+
+
 def run_pipeline(image: Optional[Image.Image]):
     if image is None:
-        return "이미지를 업로드하세요.", None, None, "이미지를 먼저 업로드해 주세요."
+        return (
+            "이미지를 업로드하세요.",
+            None,
+            None,
+            "이미지를 먼저 업로드해 주세요.",
+            "📷 약 봉투 사진을 올리면 인식이 시작돼요.",
+        )
 
     output = ocr_and_parse(image)
     card = render_card(output["fields"])
     csv_row = to_csv_row(output)
     json_text = json.dumps(output, ensure_ascii=False, indent=2)
     explanations = build_explanations(output)
-    return json_text, card, csv_row, explanations
-
-
-with gr.Blocks(theme=gr.themes.Soft()) as demo:
-    gr.Markdown("# MedCard-KR · 약봉투 OCR → 복용 일정 카드")
+    warnings_md = format_warnings(output.get("warnings", []))
+    return json_text, card, csv_row, explanations, warnings_md
+
+
+CUSTOM_CSS = """
+body {background: radial-gradient(circle at top left, #f5f0ff 0%, #fff7ec 60%, #ffffff 100%);}
+.gradio-container {max-width: 1180px !important; margin: auto; font-family: 'Noto Sans KR', sans-serif;}
+.hero {
+    background: linear-gradient(120deg, rgba(123, 97, 255, 0.12), rgba(255, 207, 117, 0.18));
+    border-radius: 28px;
+    padding: 36px 44px;
+    box-shadow: 0 20px 40px rgba(66, 46, 138, 0.08);
+    margin-bottom: 32px;
+}
+.hero h1 {font-size: 2.4rem; font-weight: 700; color: #1f1c3b; margin-bottom: 12px;}
+.hero p {color: #514c7b; font-size: 1.05rem; line-height: 1.6; max-width: 640px;}
+.glass-panel {background: rgba(255, 255, 255, 0.72); backdrop-filter: blur(18px); border-radius: 26px; padding: 28px; box-shadow: 0 12px 32px rgba(80, 60, 160, 0.12);}
+.panel-title {font-weight: 700; font-size: 1.2rem; margin-bottom: 18px; color: #2f2355;}
+.primary-btn button {background: linear-gradient(120deg, #7c62ff, #ffa74d); border: none; color: white; font-weight: 600; border-radius: 999px; padding: 12px 22px; box-shadow: 0 12px 24px rgba(124, 98, 255, 0.25);}
+.primary-btn button:hover {opacity: 0.95; transform: translateY(-1px);}
+.output-card {background: rgba(255, 255, 255, 0.88); border-radius: 22px; padding: 24px; box-shadow: inset 0 0 0 1px rgba(124, 98, 255, 0.08), 0 14px 30px rgba(49, 32, 114, 0.12);}
+.notice {background: rgba(255, 247, 226, 0.9); border-radius: 18px; padding: 18px; color: #7a4b00; box-shadow: inset 0 0 0 1px rgba(255, 193, 96, 0.3);}
+.csv-box textarea {font-family: 'JetBrains Mono', monospace;}
+.gr-image {border-radius: 20px !important; box-shadow: 0 10px 20px rgba(60, 40, 120, 0.15);}
+.accordion {border-radius: 20px !important;}
+"""
+
+HERO_HTML = """
+<div class="hero">
+  <h1>MedCard-KR · 약봉투 한 컷으로 이해하는 복용 안내</h1>
+  <p>사진 속 약 이름을 OCR로 읽어 들이고, Qwen LLM이 중학생도 이해할 수 있는 말투로 약을 설명해 드립니다.
+  복용 일정 카드와 CSV까지 한 번에 받아 보세요.</p>
+</div>
+"""
+
+
+with gr.Blocks(theme=gr.themes.Soft(), css=CUSTOM_CSS) as demo:
+    gr.HTML(HERO_HTML)
     with gr.Row():
-        with gr.Column():
-            img_in = gr.Image(type="pil", label="약 봉투/라벨 사진")
-            btn = gr.Button("인식 & 카드 생성", variant="primary")
-            csv_box = gr.Textbox(label="CSV(약명,1회용량,1일횟수,시간대)")
-        with gr.Column():
-            json_out = gr.Code(label="인식 결과(JSON)")
+        with gr.Column(scale=4, elem_classes=["glass-panel"]):
+            gr.Markdown("### 1. 약 봉투 사진을 업로드하세요")
+            img_in = gr.Image(type="pil", label="약 봉투/라벨 사진", height=360)
+            warn_md = gr.Markdown("📷 약 봉투 사진을 올리면 인식이 시작돼요.", elem_classes=["notice"])
+            btn = gr.Button("인식 & 설명 생성", elem_classes=["primary-btn"])
+        with gr.Column(scale=6, elem_classes=["glass-panel"]):
+            gr.Markdown("### 2. 결과를 확인하세요")
+            explain_md = gr.Markdown("여기에 약 설명이 표시됩니다.", elem_classes=["output-card"])
             card_out = gr.Image(type="pil", label="일정 카드(미리보기)")
-            explain_md = gr.Markdown(label="쉽게 알아보는 약 설명")
-    btn.click(run_pipeline, inputs=img_in, outputs=[json_out, card_out, csv_box, explain_md])
+            csv_box = gr.Textbox(label="CSV(약명,1회용량,1일횟수,시간대)", lines=2, elem_classes=["csv-box"])
+            with gr.Accordion("세부 JSON 결과", open=False, elem_classes=["accordion"]):
+                json_out = gr.Code(label="인식 결과(JSON)")
+
+    btn.click(
+        run_pipeline,
+        inputs=img_in,
+        outputs=[json_out, card_out, csv_box, explain_md, warn_md],
+    )
+
+    gr.Markdown(
+        """
+        > ℹ️ **주의**: 이 서비스는 참고용 도구이며, 실제 복약은 반드시 의사·약사의 지시에 따라 주세요.
+        """
+    )
 
 
 if __name__ == "__main__":
requirements.txt CHANGED
@@ -2,4 +2,4 @@ transformers
 torch
 gradio
 Pillow
-torch
+sentencepiece
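
A quick way to sanity-check the new explanation path without launching the Gradio UI is to call the helpers from the diff directly. The sketch below is not part of this commit: it assumes app.py is importable from the repo root (importing it loads both the TrOCR pipeline and the Qwen model, so the first run is slow), and the sample dictionary only mirrors the keys run_pipeline consumes ("fields" with "medications", "raw_text", "warnings"); the medication values are invented for illustration.

# Hypothetical smoke test for the new Qwen explanation path (not part of this commit).
from app import build_explanations, format_warnings

# Hand-built stand-in for the dict that ocr_and_parse() would return; values are made up.
sample_output = {
    "fields": {"medications": [{"name": "타이레놀정 500mg", "dose": "1정"}]},
    "raw_text": "타이레놀정 500mg 1정 1일 3회 식후 30분 복용",
    "warnings": [],
}

# build_explanations tries Qwen first and falls back to the keyword-based
# build_kb_explanations if generation fails or returns empty text.
print(build_explanations(sample_output))
print(format_warnings(sample_output["warnings"]))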