handwriting-ocr / app.py
fumiyaaa's picture
Upload app.py
48db211 verified
import gradio as gr
import torch
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
from PIL import Image
import numpy as np
import base64
from io import BytesIO
# モデルとプロセッサの読み込み
model_name = "Qwen/Qwen3-VL-4B-Instruct"
print(f"Loading model: {model_name}")
print(f"CUDA available: {torch.cuda.is_available()}")
# デバイスとdtypeの設定
if torch.cuda.is_available():
device = "cuda"
dtype = torch.bfloat16
else:
device = "cpu"
dtype = torch.float32
model = Qwen3VLForConditionalGeneration.from_pretrained(
model_name,
torch_dtype=dtype,
device_map="auto" if torch.cuda.is_available() else None,
)
processor = AutoProcessor.from_pretrained(model_name)
# CPUの場合はモデルを明示的に移動
if not torch.cuda.is_available():
model = model.to(device)
print("Model loaded successfully!")
def transcribe_handwriting(image):
"""手書き文字画像をOCRで文字起こしする"""
if image is None:
return "画像をアップロードしてください。"
# numpy配列の場合はPIL Imageに変換
if isinstance(image, np.ndarray):
if len(image.shape) == 2:
image = Image.fromarray(image).convert('RGB')
else:
image = Image.fromarray(image)
# PIL Imageの場合
if isinstance(image, Image.Image):
if image.mode == 'RGBA':
background = Image.new('RGB', image.size, (255, 255, 255))
background.paste(image, mask=image.split()[3])
image = background
elif image.mode == 'L':
image = image.convert('RGB')
elif image.mode != 'RGB':
image = image.convert('RGB')
# メッセージの構築
messages = [
{
"role": "user",
"content": [
{
"type": "image",
"image": image,
},
{
"type": "text",
"text": "この画像に書かれている手書きの文字を正確に読み取って、テキストとして出力してください。日本語とアルファベットの両方に対応してください。文字以外の説明は不要です。読み取った文字のみを出力してください。",
},
],
}
]
# 入力の準備(Qwen3-VL用)
inputs = processor.apply_chat_template(
messages,
tokenize=True,
add_generation_prompt=True,
return_dict=True,
return_tensors="pt",
)
inputs.pop("token_type_ids", None)
inputs = inputs.to(model.device)
# 推論
with torch.no_grad():
generated_ids = model.generate(
**inputs,
max_new_tokens=512,
do_sample=False,
)
generated_ids_trimmed = [
out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
]
output_text = processor.batch_decode(
generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
)
return output_text[0] if output_text else "文字を認識できませんでした。"
def process_canvas(base64_data):
"""Canvasからのbase64データを処理"""
if not base64_data or base64_data == "":
return "手書きしてください。"
try:
# data:image/png;base64,... の形式から実際のbase64部分を取得
if "," in base64_data:
base64_data = base64_data.split(",")[1]
# base64デコード
image_data = base64.b64decode(base64_data)
image = Image.open(BytesIO(image_data))
return transcribe_handwriting(image)
except Exception as e:
return f"エラーが発生しました: {str(e)}"
# カスタムHTML Canvasとdrawing JavaScript
canvas_html = """
<div id="canvas-container" style="display: flex; flex-direction: column; align-items: center; gap: 10px;">
<canvas id="sketch-canvas" width="600" height="400"
style="border: 2px solid #333; background: white; cursor: crosshair; touch-action: none;"></canvas>
<button id="clear-btn" type="button"
style="padding: 8px 20px; background: #ff4444; color: white; border: none; border-radius: 5px; cursor: pointer;">
クリア
</button>
</div>
<script>
(function() {
const canvas = document.getElementById('sketch-canvas');
const ctx = canvas.getContext('2d');
const clearBtn = document.getElementById('clear-btn');
let isDrawing = false;
let lastX = 0;
let lastY = 0;
// 初期化
ctx.fillStyle = 'white';
ctx.fillRect(0, 0, canvas.width, canvas.height);
ctx.strokeStyle = '#000000';
ctx.lineWidth = 3;
ctx.lineCap = 'round';
ctx.lineJoin = 'round';
function getPos(e) {
const rect = canvas.getBoundingClientRect();
const scaleX = canvas.width / rect.width;
const scaleY = canvas.height / rect.height;
if (e.touches) {
return {
x: (e.touches[0].clientX - rect.left) * scaleX,
y: (e.touches[0].clientY - rect.top) * scaleY
};
}
return {
x: (e.clientX - rect.left) * scaleX,
y: (e.clientY - rect.top) * scaleY
};
}
function startDrawing(e) {
isDrawing = true;
const pos = getPos(e);
lastX = pos.x;
lastY = pos.y;
e.preventDefault();
}
function draw(e) {
if (!isDrawing) return;
e.preventDefault();
const pos = getPos(e);
ctx.beginPath();
ctx.moveTo(lastX, lastY);
ctx.lineTo(pos.x, pos.y);
ctx.stroke();
lastX = pos.x;
lastY = pos.y;
}
function stopDrawing(e) {
isDrawing = false;
e.preventDefault();
}
// Mouse events
canvas.addEventListener('mousedown', startDrawing);
canvas.addEventListener('mousemove', draw);
canvas.addEventListener('mouseup', stopDrawing);
canvas.addEventListener('mouseout', stopDrawing);
// Touch events
canvas.addEventListener('touchstart', startDrawing);
canvas.addEventListener('touchmove', draw);
canvas.addEventListener('touchend', stopDrawing);
// Clear button
clearBtn.addEventListener('click', function() {
ctx.fillStyle = 'white';
ctx.fillRect(0, 0, canvas.width, canvas.height);
});
})();
</script>
"""
# JavaScriptでCanvasからbase64を取得
get_canvas_js = """
async (current_value) => {
const canvas = document.getElementById('sketch-canvas');
if (canvas) {
return canvas.toDataURL('image/png');
}
return '';
}
"""
# Gradioインターフェースの構築
with gr.Blocks(title="手書き文字認識システム") as demo:
gr.Markdown(
"""
# 手書き文字認識システム
**Qwen3-VL-4B-Instruct** を使用した手書き文字のOCR(光学文字認識)システムです。
日本語とアルファベットの両方に対応しています。
"""
)
with gr.Tab("画像アップロード"):
gr.Markdown("手書き文字が書かれた画像をアップロードしてください。")
with gr.Row():
with gr.Column():
upload_image = gr.Image(
label="画像をアップロード",
type="pil",
height=400,
)
upload_btn = gr.Button("文字を認識", variant="primary")
with gr.Column():
upload_output = gr.Textbox(
label="認識結果",
lines=10,
)
upload_btn.click(
fn=transcribe_handwriting,
inputs=upload_image,
outputs=upload_output,
)
with gr.Tab("手書き入力"):
gr.Markdown("マウスやタッチで文字を書いてください。")
# カスタムCanvas
canvas = gr.HTML(canvas_html)
# 隠しテキストボックス(Canvas dataを受け取る)
canvas_data = gr.Textbox(visible=False, elem_id="canvas-data")
with gr.Row():
sketch_btn = gr.Button("文字を認識", variant="primary")
sketch_output = gr.Textbox(
label="認識結果",
lines=10,
)
# ボタンクリック時にJSでcanvasデータを取得してから処理
sketch_btn.click(
fn=process_canvas,
inputs=[canvas_data],
outputs=[sketch_output],
js=get_canvas_js,
)
gr.Markdown(
"""
---
### 使い方のヒント
- **画像アップロード**: スキャンした手書きメモや写真をアップロード
- **手書き入力**: マウスで直接文字を書いて認識をテスト
- 認識精度は文字の明瞭さに依存します
*Powered by Qwen3-VL-4B-Instruct*
"""
)
if __name__ == "__main__":
demo.launch()