Spaces:
Sleeping
Sleeping
| import io | |
| import os | |
| import base64 | |
| import json | |
| import logging | |
| import unicodedata | |
| import tempfile | |
| from difflib import SequenceMatcher | |
| from PIL import Image, ImageDraw, ImageFont, ImageOps | |
| import cv2 | |
| import numpy as np | |
| import gradio as gr | |
| from google.cloud import vision | |
| from google.oauth2 import service_account | |
| from kospellpy import spell_init | |
# ──────────────────────────────── Environment setup ────────────────────────────────
# Root logging config: INFO and above, with timestamp and level on every line.
logging.basicConfig(level=logging.INFO, format='%(asctime)s %(levelname)s: %(message)s')
# TrueType font used to draw corrections on the overlay; looked up next to this source file.
FONT_PATH = os.path.join(os.path.dirname(__file__), "NanumGothicCoding.ttf")
# Lower bound for the size passed to ImageFont.truetype when drawing overlay text.
MIN_FONT_SIZE = 8
def get_vision_client():
    """Create the Google Cloud Vision annotator client.

    When the GCP_SERVICE_ACCOUNT_JSON environment variable is non-empty, its value
    is base64-decoded, parsed as a service-account JSON document and used as the
    client's explicit credentials. Otherwise a warning is logged and the client is
    built with no explicit credentials, leaving credential lookup to the client
    library.

    Raises:
        Exception: whatever decoding, JSON parsing or credential construction
            raised; it is logged first and then re-raised unchanged.
    """
    encoded = os.getenv("GCP_SERVICE_ACCOUNT_JSON")
    if encoded:
        try:
            account_info = json.loads(base64.b64decode(encoded).decode())
            credentials = service_account.Credentials.from_service_account_info(account_info)
            return vision.ImageAnnotatorClient(credentials=credentials)
        except Exception as err:
            logging.error(f"Vision API 인증 실패: {err}")
            raise
    logging.warning("GCP_SERVICE_ACCOUNT_JSON 환경변수가 설정되지 않았습니다. 기본 인증을 사용합니다.")
    return vision.ImageAnnotatorClient()
# Shared clients created once, at import time.
# get_vision_client() re-raises on bad service-account data, so such a failure
# aborts module import rather than surfacing on the first request.
vision_client = get_vision_client()
# kospellpy spell checker; invoked as checker(text) by safe_kospell_check.
checker = spell_init()
# ──────────────────────────────── KoSpellPy: safe handling of long text ────────────────────────────────
def chunk_text(text, max_len=500):
    """Split *text* into consecutive pieces of at most *max_len* characters.

    Each piece ends just after the last whitespace character inside its window,
    so words are not cut in half. A run of non-whitespace longer than *max_len*
    is hard-split. The pieces always concatenate back to the input, i.e.
    ``''.join(chunk_text(t)) == t``, and an empty string yields ``[]``.

    Raises:
        ValueError: if *max_len* is less than 1.
    """
    if max_len < 1:
        raise ValueError(f"max_len must be >= 1, got {max_len}")
    chunks = []
    start = 0
    while len(text) - start > max_len:
        end = start + max_len
        cut = end  # hard split unless a whitespace break point is found
        for pos in range(end - 1, start, -1):
            if text[pos].isspace():
                cut = pos + 1  # keep the whitespace with the left-hand chunk
                break
        chunks.append(text[start:cut])
        start = cut
    if start < len(text):
        chunks.append(text[start:])
    return chunks


def safe_kospell_check(text):
    """Spell-check *text* chunk by chunk, tolerating per-chunk failures.

    The text is split with chunk_text. Each chunk's leading and trailing
    whitespace is set aside, only the inner text goes to ``checker`` (the
    module-level object returned by ``spell_init()``), and the whitespace is
    re-attached afterwards. Chunks are concatenated with no separator, so chunk
    boundaries add no characters to the result.

    If ``checker`` raises for a chunk, a warning is logged and that chunk is kept
    uncorrected. Assumes ``checker`` takes and returns a str (TODO confirm
    against kospellpy).
    """
    corrected = []
    for part in chunk_text(text):
        body = part.strip()
        if not body:
            corrected.append(part)  # whitespace-only chunk: nothing to check
            continue
        lead = part[:len(part) - len(part.lstrip())]
        tail = part[len(part.rstrip()):]
        try:
            fixed = checker(body)
        except Exception as e:
            logging.warning(f"[Spell] 일부 텍스트 교정 실패: {e}")
            fixed = body  # on failure, keep this chunk's original text
        corrected.append(lead + fixed + tail)
    return ''.join(corrected)
def normalize_text(text: str) -> str:
    """Return *text* in Unicode NFC (canonically composed) form.

    For example, a decomposed sequence of conjoining Hangul jamo becomes the
    single precomposed syllable, so equivalent strings compare equal.
    """
    composed = unicodedata.normalize('NFC', text)
    return composed
def compute_font_for_word(vertices):
    """Return a font sized to about 40% of a bounding box's height.

    Args:
        vertices: polygon vertices exposing ``.y`` (Vision bounding-box vertices
            at the call sites in this file).

    Returns:
        A TrueType font loaded from FONT_PATH, at a size of at least
        MIN_FONT_SIZE. If that font cannot be loaded, a warning is logged and
        Pillow's default font is returned.
    """
    ys = [v.y for v in vertices]
    size = max(MIN_FONT_SIZE, int((max(ys) - min(ys)) * 0.4))
    try:
        return ImageFont.truetype(FONT_PATH, size)
    except Exception as e:
        # Use logging (not print) so this goes through the configured
        # timestamp/level format like every other diagnostic in the app.
        logging.warning(f"[Font] 폰트 로딩 실패: {e}")
        return ImageFont.load_default()
def preprocess_with_adaptive_threshold(img: Image.Image) -> Image.Image:
    """Binarize an image with Gaussian adaptive thresholding before OCR.

    The image is converted to grayscale, thresholded with OpenCV's Gaussian
    adaptive method (block size 25, constant 10, THRESH_BINARY, max value 255),
    and returned as a 3-channel RGB PIL image whose channels are identical.
    Assumes *img* is an RGB PIL image (callers in this file convert to 'RGB').
    """
    grayscale = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2GRAY)
    binary = cv2.adaptiveThreshold(
        grayscale,
        255,
        cv2.ADAPTIVE_THRESH_GAUSSIAN_C,
        cv2.THRESH_BINARY,
        25,  # neighbourhood (block) size
        10,  # constant subtracted from the weighted mean
    )
    return Image.fromarray(cv2.cvtColor(binary, cv2.COLOR_GRAY2RGB))
def _collect_symbols(ann):
    """Flatten a Vision text annotation into per-symbol dicts, in traversal order.

    Walks pages -> blocks -> paragraphs -> words -> symbols and returns a list of
    ``{'text': <NFC-normalized symbol text>, 'bbox': <symbol bounding-box vertices>}``.
    """
    syms = []
    for pg in ann.pages:
        for bl in pg.blocks:
            for para in bl.paragraphs:
                for w in para.words:
                    for s in w.symbols:
                        syms.append({'text': normalize_text(s.text), 'bbox': s.bounding_box.vertices})
    return syms


def _nonspace_index_map(chars):
    """Map each non-whitespace position in *chars* to its ordinal among the non-whitespace characters.

    Used to line positions in the raw OCR text up with OCR symbols, on the
    assumption that each non-whitespace character corresponds to one symbol.
    """
    mapping = {}
    for i, ch in enumerate(chars):
        if ch.strip():
            mapping[i] = len(mapping)
    return mapping


def _bbox_extent(verts):
    """Return ``(x_min, x_max, y_min, y_max)`` over vertices exposing ``.x``/``.y``."""
    xs = [v.x for v in verts]
    ys = [v.y for v in verts]
    return min(xs), max(xs), min(ys), max(ys)


def ocr_overlay_and_correct_text(img: Image.Image):
    """OCR *img*, spell-check the recognized text, and mark corrections on a copy of it.

    Steps:
      1. EXIF-orient the image and binarize it.
      2. Run Cloud Vision document text detection with a Korean language hint.
      3. Flatten newlines to spaces and spell-check the text.
      4. Diff raw vs corrected text character by character.
      5. For 'replace'/'insert' differences, underline the affected characters
         in red and draw the replacement text next to them. Deletions are not
         drawn.

    Args:
        img: a PIL image, or None.

    Returns:
        ``(overlay, corrected_text)``. *overlay* is an annotated copy of the
        oriented image; *corrected_text* is the spell-checked text, or the raw
        OCR text if spell-checking fails. ``(None, "")`` when *img* is None.

    Raises:
        gr.Error: if the Vision response reports an error.
    """
    if img is None:
        return None, ""

    img = ImageOps.exif_transpose(img)
    proc = preprocess_with_adaptive_threshold(img)
    buf = io.BytesIO()
    proc.save(buf, format='PNG')
    res = vision_client.document_text_detection(
        image=vision.Image(content=buf.getvalue()),
        image_context={'language_hints': ['ko']}
    )
    # Vision reports request failures in the response's error field rather than
    # raising; without this check a failed call looks like "no text found".
    if res.error.message:
        logging.error(f"[OCR] Vision API error: {res.error.message}")
        raise gr.Error(f"Vision API 오류: {res.error.message}")

    ann = res.full_text_annotation
    raw = ann.text.replace('\n', ' ').strip()
    logging.info(f"[OCR] Raw length: {len(raw)} / Raw: {raw}")
    try:
        corrected_text = safe_kospell_check(raw)
        logging.info(f"[Spell] Corrected: {corrected_text}")
    except Exception as e:
        logging.error(f"[Spell] 교정 중 오류 발생: {e}")
        corrected_text = raw  # fall back to the uncorrected OCR text

    syms = _collect_symbols(ann)
    raw_c, corr_c = list(raw), list(corrected_text)
    mapping = _nonspace_index_map(raw_c)
    if len(mapping) != len(syms):
        logging.warning(
            f"[OCR] {len(mapping)} non-space chars vs {len(syms)} symbols; "
            "overlay positions may be off"
        )
    # Drop positions with no matching symbol so a count mismatch cannot raise
    # IndexError below; those characters are simply left unannotated.
    mapping = {i: sd for i, sd in mapping.items() if sd < len(syms)}

    overlay = img.copy()
    draw = ImageDraw.Draw(overlay)
    col = "#FF3333"
    for tag, i1, i2, j1, j2 in SequenceMatcher(None, raw_c, corr_c).get_opcodes():
        if tag not in ('replace', 'insert'):
            continue
        repl = ''.join(corr_c[j1:j2])
        if tag == 'insert' and repl == ' ':
            repl = 'V'  # an inserted space is drawn as a 'V' spacing mark
        # An insertion has no raw span of its own, so it anchors on the preceding raw character.
        anchor = max(i1 - 1, 0)
        if tag == 'replace':
            valid = [k for k in range(i1, i2) if k in mapping]
        else:
            valid = [anchor] if anchor in mapping else []
        if not valid:
            continue

        for k in valid:
            x0, x1, y0, y1 = _bbox_extent(syms[mapping[k]]['bbox'])
            ul = y0 + int((y1 - y0) * 0.9)  # underline near the bottom of the box
            draw.line([(x0, ul), (x1, ul)], fill=col, width=3)

        verts = syms[mapping[valid[0]]]['bbox']
        x0, x1, y0, _ = _bbox_extent(verts)
        if tag == 'insert' and len(repl) == 1 and not repl.isalnum():
            # Single punctuation insert: show it together with the anchor character,
            # placed at the anchor's right edge.
            overlay_str, fx = raw_c[anchor] + repl, x1
        elif repl == 'V' or not repl.isalnum():
            overlay_str, fx = repl, x1
        else:
            overlay_str, fx = repl, x0
        draw.text((fx, y0), overlay_str, font=compute_font_for_word(verts), fill=col)

    return overlay, corrected_text
def text_correct_fn(text):
    """Spell-check text typed by the user.

    The input is stripped and NFC-normalized before checking. Returns
    ``(None, corrected)`` so it can feed the (image, text) outputs; typed input
    has no overlay image. If spell-checking raises, the error is logged and the
    normalized input is returned unchanged.
    """
    normalized = normalize_text(text.strip())
    try:
        return None, safe_kospell_check(normalized)
    except Exception as err:
        logging.error(f"[Spell/TextInput] 교정 중 오류 발생: {err}")
        return None, normalized
def img_correct_fn(blob):
    """Decode uploaded image bytes as RGB and run OCR-based correction on them.

    A falsy *blob* (no upload) is forwarded as None. Returns whatever
    ocr_overlay_and_correct_text returns.
    """
    image = Image.open(io.BytesIO(blob)).convert('RGB') if blob else None
    return ocr_overlay_and_correct_text(image)
with gr.Blocks(
    css="""
.gradio-container {background-color: #fafaf5}
footer {display: none !important;}
.gr-box {border: 2px solid black !important;}
* { font-family: 'Quicksand', ui-sans-serif, sans-serif !important; }
""",
    theme="dark"
) as demo:
    # Raw bytes of the last uploaded image, carried from the upload step to the check step.
    state = gr.State()
    gr.Markdown("## 📷찰칵! 맞춤법 검사기")
    with gr.Row():
        with gr.Column():
            upload = gr.UploadButton(label='사진 촬영 및 업로드', file_types=['image'], type='binary')
            img_check_btn = gr.Button('✔️검사하기', interactive=False)
        with gr.Column():
            text_in = gr.Textbox(lines=3, placeholder='텍스트를 직접 입력하세요 (선택)', label='💻직접 입력 텍스트')
            text_check_btn = gr.Button('텍스트 검사', interactive=False)
    img_out = gr.Image(type='pil', label='교정 결과')
    txt_out = gr.Textbox(label='교정된 텍스트')
    clear_btn = gr.Button('초기화')

    def on_upload_start():
        """Show an in-progress label and lock both the upload and check buttons."""
        return gr.update(label="업로드 중...", interactive=False), gr.update(interactive=False)

    def on_upload_complete(blob):
        """Store the uploaded bytes in state and enable the check button."""
        return blob, gr.update(label="업로드 완료", interactive=False), gr.update(interactive=True)

    # Chain with .then(): listeners attached separately to the same event run
    # independently, so on_upload_start could finish last and re-disable the
    # check button after on_upload_complete enabled it.
    upload.upload(
        on_upload_start, None, [upload, img_check_btn], queue=False, preprocess=False
    ).then(
        on_upload_complete, inputs=[upload], outputs=[state, upload, img_check_btn]
    )

    def on_img_check(blob):
        """Run OCR + spell check on the stored image, then reset the upload controls."""
        overlay, corrected = img_correct_fn(blob)
        return (
            gr.update(label="사진 촬영 및 업로드", interactive=True, value=None),
            gr.update(interactive=False),
            overlay,
            corrected,
        )

    img_check_btn.click(on_img_check, inputs=[state], outputs=[upload, img_check_btn, img_out, txt_out])

    def enable_text_check(text):
        """Enable the text-check button only when the textbox holds non-blank text."""
        return gr.update(interactive=bool(text and text.strip()))

    text_in.change(enable_text_check, inputs=[text_in], outputs=[text_check_btn])
    text_check_btn.click(text_correct_fn, inputs=[text_in], outputs=[img_out, txt_out])

    def on_clear():
        """Reset state, inputs, buttons and outputs to their initial values."""
        return None, gr.update(label="사진 촬영 및 업로드", interactive=True, value=None), '', gr.update(interactive=False), None, ''

    clear_btn.click(on_clear, None, [state, upload, text_in, img_check_btn, img_out, txt_out])
if __name__ == '__main__':
    # Listen on all interfaces; the PORT environment variable overrides the default 7860.
    listen_port = int(os.getenv("PORT", 7860))
    demo.launch(server_name="0.0.0.0", server_port=listen_port)