import os
import re
import traceback

import gradio as gr
import torch
from transformers import DonutProcessor, VisionEncoderDecoderModel

# Donut checkpoint fine-tuned on CORD v2 (receipt/document parsing).
MODEL_NAME = "naver-clova-ix/donut-base-finetuned-cord-v2"

# Load processor and model once at import time so every request reuses them.
processor = DonutProcessor.from_pretrained(MODEL_NAME)
model = VisionEncoderDecoderModel.from_pretrained(MODEL_NAME)

# Prefer a visible GPU, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)
|
|
| |
def ocr_donut(image):
    """Run Donut OCR on an uploaded document image.

    Args:
        image: PIL.Image from the Gradio image widget, or None when the
            user clicked "Run OCR" without uploading anything.

    Returns:
        dict: ``{"result": <parsed fields>}`` on success, where the value is
        the output of ``processor.token2json``; ``{"error": <message>}`` on
        missing input or any runtime failure.
    """
    try:
        if image is None:
            return {"error": "No image provided."}

        # Donut is prompted with a task token that selects the output schema.
        task_prompt = "<s_cord-v2>"
        decoder_input_ids = processor.tokenizer(
            task_prompt, add_special_tokens=False, return_tensors="pt"
        ).input_ids.to(device)
        pixel_values = processor(image.convert("RGB"), return_tensors="pt").pixel_values.to(device)

        # Fix: run generation without autograd tracking — the original built
        # gradient state during pure inference, wasting memory.
        with torch.inference_mode():
            outputs = model.generate(
                pixel_values,
                decoder_input_ids=decoder_input_ids,
                max_length=model.config.decoder.max_position_embeddings,
                pad_token_id=processor.tokenizer.pad_token_id,
                eos_token_id=processor.tokenizer.eos_token_id,
                use_cache=True,
                bad_words_ids=[[processor.tokenizer.unk_token_id]],
                return_dict_in_generate=True,
            )

        # Strip EOS/PAD tokens, then drop the leading task token (first <...>)
        # before converting the token sequence to JSON.
        seq = processor.batch_decode(outputs.sequences)[0]
        seq = seq.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
        seq = re.sub(r"<.*?>", "", seq, count=1).strip()
        return {"result": processor.token2json(seq)}

    except Exception:
        # Boundary handler for the UI: log server-side and surface the trace.
        # NOTE(review): returning the full traceback to the browser leaks
        # internals — acceptable for a class demo, trim for production.
        tb = traceback.format_exc()
        print(tb)
        return {"error": tb}
|
|
| |
# ---------------------------------------------------------------------------
# Gradio UI: header, input/output panes, and footer, styled by `custom_css`.
# ---------------------------------------------------------------------------
custom_css = """
body { background: #f0f2f5; font-family: 'Segoe UI', Tahoma, sans-serif; }
.gradio-container { max-width: 900px; margin: 40px auto; padding: 20px; }
.header { text-align: center; margin-bottom: 30px; }
.header h1 { font-size: 2.8rem; color: #333; margin: 0; }
.header p { color: #666; margin-top: 8px; }

.input-box, .output-box {
    background: #fff;
    border-radius: 8px;
    box-shadow: 0 2px 8px rgba(0,0,0,0.1);
    padding: 20px;
}
.input-box { margin-right: 10px; }
.output-box { margin-left: 10px; }

.gr-button {
    background: #5a8dee !important;
    color: #fff !important;
    border-radius: 6px !important;
    padding: 10px 20px !important;
    font-size: 1rem !important;
    margin-top: 10px !important;
    transition: background 0.2s ease;
}
.gr-button:hover { background: #3f6fcc !important; }

.footer {
    text-align: center;
    margin-top: 30px;
    color: #999;
    font-size: 0.85rem;
}
"""

# NOTE(review): the "π" below looks like a mojibake'd emoji from a bad
# encoding round-trip — confirm the intended glyph with the author.
_HEADER_HTML = """
<div class="header">
    <h1>π Donut OCR</h1>
    <p>Industrial AI Engineering Week 8 Assignment</p>
</div>
"""

_FOOTER_HTML = """
<div class="footer">
    <p>Powered by Naver Clova Donut</p>
</div>
"""

with gr.Blocks(css=custom_css, title="Donut OCR App") as demo:
    gr.HTML(_HEADER_HTML)

    with gr.Row():
        with gr.Column(elem_classes="input-box"):
            uploaded_image = gr.Image(type="pil", label="Upload Document Image")
            ocr_button = gr.Button("Run OCR", elem_id="run-btn")
        with gr.Column(elem_classes="output-box"):
            json_output = gr.JSON(label="Output")

    # Wire the button to the inference function defined above.
    ocr_button.click(fn=ocr_donut, inputs=uploaded_image, outputs=json_output)

    gr.HTML(_FOOTER_HTML)

# Bind on all interfaces; honor a container-provided PORT, defaulting to 7860.
demo.launch(
    server_name="0.0.0.0",
    server_port=int(os.environ.get("PORT", 7860)),
    debug=True,
)
|
|