import gradio as gr
import torch
from transformers import Qwen3VLForConditionalGeneration, AutoProcessor
from PIL import Image
import numpy as np
import base64
from io import BytesIO

# Load the model and processor
model_name = "Qwen/Qwen3-VL-4B-Instruct"
print(f"Loading model: {model_name}")
print(f"CUDA available: {torch.cuda.is_available()}")

# Device and dtype settings: bfloat16 on GPU, float32 on CPU
if torch.cuda.is_available():
    device = "cuda"
    dtype = torch.bfloat16
else:
    device = "cpu"
    dtype = torch.float32
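
# Note: device_map="auto" lets transformers/accelerate place the model on the
# available GPU(s); on CPU the model is moved explicitly after loading instead.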
model = Qwen3VLForConditionalGeneration.from_pretrained(
    model_name,
    torch_dtype=dtype,
    device_map="auto" if torch.cuda.is_available() else None,
)
processor = AutoProcessor.from_pretrained(model_name)

# Move the model explicitly when running on CPU (device_map is not used there)
if not torch.cuda.is_available():
    model = model.to(device)

print("Model loaded successfully!")
def transcribe_handwriting(image):
    """Run OCR on a handwriting image and return the transcribed text."""
    if image is None:
        return "画像をアップロードしてください。"

    # Convert numpy arrays (grayscale or color) to PIL Images
    if isinstance(image, np.ndarray):
        if len(image.shape) == 2:
            image = Image.fromarray(image).convert('RGB')
        else:
            image = Image.fromarray(image)

    # Normalize PIL Images to RGB
    if isinstance(image, Image.Image):
        if image.mode == 'RGBA':
            # Composite onto a white background, using the alpha channel as the paste mask
            background = Image.new('RGB', image.size, (255, 255, 255))
            background.paste(image, mask=image.split()[3])
            image = background
        elif image.mode != 'RGB':
            image = image.convert('RGB')

    # Build the chat message: one image plus the transcription instruction
    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "image",
                    "image": image,
                },
                {
                    "type": "text",
                    "text": "この画像に書かれている手書きの文字を正確に読み取って、テキストとして出力してください。日本語とアルファベットの両方に対応してください。文字以外の説明は不要です。読み取った文字のみを出力してください。",
                },
            ],
        }
    ]

    # Prepare the inputs (Qwen3-VL chat template handles image + text)
    inputs = processor.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    )
    # Drop token_type_ids if present; generate() does not accept it
    inputs.pop("token_type_ids", None)
    inputs = inputs.to(model.device)

    # Inference: greedy decoding, no gradients needed
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=512,
            do_sample=False,
        )

    # Strip the prompt tokens so only the newly generated text is decoded
    generated_ids_trimmed = [
        out_ids[len(in_ids):] for in_ids, out_ids in zip(inputs["input_ids"], generated_ids)
    ]
    output_text = processor.batch_decode(
        generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )

    return output_text[0] if output_text else "文字を認識できませんでした。"


def process_canvas(base64_data):
    """Decode the base64 PNG sent from the drawing canvas and run OCR on it."""
    if not base64_data or base64_data == "":
        return "手書きしてください。"

    try:
        # Strip the "data:image/png;base64," prefix, if present
        if "," in base64_data:
            base64_data = base64_data.split(",")[1]

        # Decode the base64 payload into a PIL Image
        image_data = base64.b64decode(base64_data)
        image = Image.open(BytesIO(image_data))

        return transcribe_handwriting(image)
    except Exception as e:
        return f"エラーが発生しました: {str(e)}"


# Custom HTML canvas and the JavaScript that handles drawing on it
canvas_html = """
<div id="canvas-container" style="display: flex; flex-direction: column; align-items: center; gap: 10px;">
    <canvas id="sketch-canvas" width="600" height="400"
            style="border: 2px solid #333; background: white; cursor: crosshair; touch-action: none;"></canvas>
    <button id="clear-btn" type="button"
            style="padding: 8px 20px; background: #ff4444; color: white; border: none; border-radius: 5px; cursor: pointer;">
        クリア
    </button>
</div>
| <script> | |
| (function() { | |
| const canvas = document.getElementById('sketch-canvas'); | |
| const ctx = canvas.getContext('2d'); | |
| const clearBtn = document.getElementById('clear-btn'); | |
| let isDrawing = false; | |
| let lastX = 0; | |
| let lastY = 0; | |
| // 初期化 | |
| ctx.fillStyle = 'white'; | |
| ctx.fillRect(0, 0, canvas.width, canvas.height); | |
| ctx.strokeStyle = '#000000'; | |
| ctx.lineWidth = 3; | |
| ctx.lineCap = 'round'; | |
| ctx.lineJoin = 'round'; | |
| function getPos(e) { | |
| const rect = canvas.getBoundingClientRect(); | |
| const scaleX = canvas.width / rect.width; | |
| const scaleY = canvas.height / rect.height; | |
| if (e.touches) { | |
| return { | |
| x: (e.touches[0].clientX - rect.left) * scaleX, | |
| y: (e.touches[0].clientY - rect.top) * scaleY | |
| }; | |
| } | |
| return { | |
| x: (e.clientX - rect.left) * scaleX, | |
| y: (e.clientY - rect.top) * scaleY | |
| }; | |
| } | |
| function startDrawing(e) { | |
| isDrawing = true; | |
| const pos = getPos(e); | |
| lastX = pos.x; | |
| lastY = pos.y; | |
| e.preventDefault(); | |
| } | |
| function draw(e) { | |
| if (!isDrawing) return; | |
| e.preventDefault(); | |
| const pos = getPos(e); | |
| ctx.beginPath(); | |
| ctx.moveTo(lastX, lastY); | |
| ctx.lineTo(pos.x, pos.y); | |
| ctx.stroke(); | |
| lastX = pos.x; | |
| lastY = pos.y; | |
| } | |
| function stopDrawing(e) { | |
| isDrawing = false; | |
| e.preventDefault(); | |
| } | |
| // Mouse events | |
| canvas.addEventListener('mousedown', startDrawing); | |
| canvas.addEventListener('mousemove', draw); | |
| canvas.addEventListener('mouseup', stopDrawing); | |
| canvas.addEventListener('mouseout', stopDrawing); | |
| // Touch events | |
| canvas.addEventListener('touchstart', startDrawing); | |
| canvas.addEventListener('touchmove', draw); | |
| canvas.addEventListener('touchend', stopDrawing); | |
| // Clear button | |
| clearBtn.addEventListener('click', function() { | |
| ctx.fillStyle = 'white'; | |
| ctx.fillRect(0, 0, canvas.width, canvas.height); | |
| }); | |
| })(); | |
| </script> | |
| """ | |

# Client-side JS: read the canvas content as a base64 PNG data URL
get_canvas_js = """
async (current_value) => {
    const canvas = document.getElementById('sketch-canvas');
    if (canvas) {
        return canvas.toDataURL('image/png');
    }
    return '';
}
"""

# Build the Gradio interface
with gr.Blocks(title="手書き文字認識システム") as demo:
    gr.Markdown(
        """
        # 手書き文字認識システム

        **Qwen3-VL-4B-Instruct** を使用した手書き文字のOCR(光学文字認識)システムです。
        日本語とアルファベットの両方に対応しています。
        """
    )
    with gr.Tab("画像アップロード"):
        gr.Markdown("手書き文字が書かれた画像をアップロードしてください。")
        with gr.Row():
            with gr.Column():
                upload_image = gr.Image(
                    label="画像をアップロード",
                    type="pil",
                    height=400,
                )
                upload_btn = gr.Button("文字を認識", variant="primary")
            with gr.Column():
                upload_output = gr.Textbox(
                    label="認識結果",
                    lines=10,
                )

        upload_btn.click(
            fn=transcribe_handwriting,
            inputs=upload_image,
            outputs=upload_output,
        )
    with gr.Tab("手書き入力"):
        gr.Markdown("マウスやタッチで文字を書いてください。")

        # Custom drawing canvas
        canvas = gr.HTML(canvas_html)

        # Hidden textbox that receives the canvas data
        canvas_data = gr.Textbox(visible=False, elem_id="canvas-data")

        with gr.Row():
            sketch_btn = gr.Button("文字を認識", variant="primary")

        sketch_output = gr.Textbox(
            label="認識結果",
            lines=10,
        )

        # On click, fetch the canvas data via JS, then run recognition on it
        sketch_btn.click(
            fn=process_canvas,
            inputs=[canvas_data],
            outputs=[sketch_output],
            js=get_canvas_js,
        )
    gr.Markdown(
        """
        ---
        ### 使い方のヒント
        - **画像アップロード**: スキャンした手書きメモや写真をアップロード
        - **手書き入力**: マウスで直接文字を書いて認識をテスト
        - 認識精度は文字の明瞭さに依存します

        *Powered by Qwen3-VL-4B-Instruct*
        """
    )

if __name__ == "__main__":
    demo.launch()