"""Gradio app that transcribes sheet-music images with a fine-tuned MiniCPM-V model."""

import os
import inspect

import torch
from PIL import Image
from transformers import AutoProcessor, MiniCPMV4_6ForConditionalGeneration
import gradio as gr

try:
    import spaces
except ImportError:
    # Lets the app still run locally without ZeroGPU.
    class _SpacesFallback:
        """No-op stand-in used when the ``spaces`` package is not installed."""

        @staticmethod
        def GPU(*args, **kwargs):
            """Pass-through replacement for the ``spaces.GPU`` decorator.

            Supports both the bare form (``@spaces.GPU``) and the
            parameterised form (``@spaces.GPU(duration=...)``); in either
            case the decorated function is returned unchanged.
            """
            if len(args) == 1 and callable(args[0]) and not kwargs:
                # Bare usage: the single positional argument is the function.
                return args[0]

            def decorator(fn):
                return fn

            return decorator

    spaces = _SpacesFallback()


ORIGINAL_MODEL_ID = "openbmb/MiniCPM-V-4.6"
FINETUNED_MODEL_ID = "jon-fernandes/noteworthy"

NOTES_PROMPT = "Transcribe the musical notes in this image. Return only the transcription."

# Browser-side helper: a click anywhere on the image component (outside of
# buttons and other interactive controls) while a <video> preview is present
# is forwarded to the first button whose text/aria-label/title mentions
# "capture", "photo" or "snapshot". Attaching the listener is retried every
# 300 ms for up to 5 s in case the component is not in the DOM yet.
CAMERA_CAPTURE_JS = """
function () {
    const attachTapCapture = () => {
        const root = document.getElementById("sheet-music-input");
        if (!root || root.dataset.tapCaptureReady === "1") {
            return Boolean(root);
        }

        root.dataset.tapCaptureReady = "1";
        root.addEventListener("click", (event) => {
            if (event.target.closest("button, input, select, textarea, a")) {
                return;
            }
            if (!root.querySelector("video")) {
                return;
            }

            const buttons = Array.from(root.querySelectorAll("button"));
            const captureButton = buttons.find((button) => {
                const text = [
                    button.textContent,
                    button.getAttribute("aria-label"),
                    button.getAttribute("title"),
                ].filter(Boolean).join(" ").toLowerCase();
                return text.includes("capture") || text.includes("photo") || text.includes("snapshot");
            });

            if (captureButton) {
                captureButton.click();
            }
        });
        return true;
    };

    if (!attachTapCapture()) {
        const timer = setInterval(() => {
            if (attachTapCapture()) {
                clearInterval(timer);
            }
        }, 300);
        setTimeout(() => {
            clearInterval(timer);
        }, 5000);
    }
}
"""

CUSTOM_CSS = """
#sheet-music-input video,
#sheet-music-input canvas,
#sheet-music-input img {
    cursor: pointer;
}
"""


def env_flag(name: str, default: bool = False) -> bool:
    """Read a boolean flag from the environment.

    Args:
        name: Environment variable to look up.
        default: Value returned when the variable is not set.

    Returns:
        ``default`` if the variable is unset; otherwise ``True`` only when
        its value (case-insensitive, surrounding whitespace ignored) is one
        of ``1``, ``true``, ``yes`` or ``on``.
    """
    value = os.environ.get(name)
    if value is None:
        return default
    return value.strip().lower() in {"1", "true", "yes", "on"}


def supports_keyword(callable_obj, keyword):
    """Return whether ``callable_obj`` declares a parameter named ``keyword``.

    Only explicitly named parameters count; a ``**kwargs`` catch-all does
    not. Callables whose signature cannot be inspected yield ``False``.
    """
    try:
        signature = inspect.signature(callable_obj)
    except (TypeError, ValueError):
        return False
    return keyword in signature.parameters


# Important for ZeroGPU:
# Do NOT warm up at startup by default. GPU is only allocated inside @spaces.GPU.
ENABLE_MODEL_WARMUP = env_flag("NOTEWORTHY_WARMUP", False)

# Maps a model label to the "<ExceptionType>: <message>" text of its load failure.
MODEL_LOAD_ERRORS = {}

print("Loading processor...")
processor = AutoProcessor.from_pretrained(
    ORIGINAL_MODEL_ID,
    trust_remote_code=True,
)


def load_local_model(label, model_id):
    """Load a model in bfloat16, move it to CUDA and switch it to eval mode.

    Args:
        label: Human-readable name used in log lines and as the key in
            ``MODEL_LOAD_ERRORS``.
        model_id: Identifier passed to ``from_pretrained``.

    Returns:
        The loaded model, or ``None`` if anything failed. Failures are
        recorded in ``MODEL_LOAD_ERRORS[label]`` instead of being raised so
        the app can still start and report the problem in the UI.
    """
    print(f"Loading {label} model...")
    try:
        model = MiniCPMV4_6ForConditionalGeneration.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16,
            attn_implementation="sdpa",
            trust_remote_code=True,
            low_cpu_mem_usage=True,
        )

        # Important for ZeroGPU:
        # Hugging Face recommends placing the model on CUDA at module level.
        # ZeroGPU uses CUDA emulation during startup and real CUDA inside @spaces.GPU.
        model = model.to("cuda").eval()

        print(f"{label} model loaded.")
        return model
    except Exception as e:
        message = f"{type(e).__name__}: {e}"
        MODEL_LOAD_ERRORS[label] = message
        print(f"Failed to load {label} model: {message}")
        return None


finetuned_model = load_local_model("fine-tuned", FINETUNED_MODEL_ID)

print("Startup complete.")


def _get_model_device(model):
    """Return the device of the model's first parameter (CUDA if it has none)."""
    try:
        return next(model.parameters()).device
    except StopIteration:
        return torch.device("cuda")


def _move_model_inputs(inputs, device, dtype=torch.bfloat16):
    """Copy ``inputs`` with every tensor moved to ``device``.

    Args:
        inputs: Mapping of processor outputs.
        device: Target device for all tensors.
        dtype: Dtype that floating-point tensors are cast to. Defaults to
            bfloat16, the dtype the model is loaded with in this file.

    Returns:
        A new dict. Floating-point tensors are cast to ``dtype`` and moved,
        other tensors are only moved, and non-tensor values are passed
        through untouched.
    """
    moved = {}
    for key, value in inputs.items():
        if isinstance(value, torch.Tensor):
            if torch.is_floating_point(value):
                value = value.to(dtype=dtype)
            moved[key] = value.to(device)
        else:
            moved[key] = value
    return moved


# Image-processing options sent with every chat-template call.
_PROCESSOR_KWARGS = {
    "downsample_mode": "4x",
    "max_slice_nums": 9,
    "use_image_id": True,
}


def _apply_chat_template(messages, extra_processor_kwargs, **template_kwargs):
    """Run the processor's chat template with this app's fixed options."""
    return processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
        processor_kwargs={**extra_processor_kwargs, **_PROCESSOR_KWARGS},
        **template_kwargs,
    )


def _build_model_inputs(image: Image.Image):
    """Turn an image plus ``NOTES_PROMPT`` into model inputs.

    Two message layouts are tried in order: the image embedded in the
    message content, then an image placeholder with the image supplied via
    ``processor_kwargs["images"]``. Each layout is first attempted with
    ``enable_thinking=False``; only a ``TypeError`` triggers one retry
    without that argument.

    Returns:
        A plain dict built from the first mapping-like processor result.

    Raises:
        RuntimeError: If no attempt produced a mapping. The message joins
            the last four collected error strings.
    """
    input_variants = [
        (
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "image", "image": image},
                        {"type": "text", "text": NOTES_PROMPT},
                    ],
                }
            ],
            {},
        ),
        (
            [
                {
                    "role": "user",
                    "content": [
                        {"type": "image"},
                        {"type": "text", "text": NOTES_PROMPT},
                    ],
                }
            ],
            {"images": [image]},
        ),
    ]

    errors = []
    for messages, extra_processor_kwargs in input_variants:
        for template_kwargs in ({"enable_thinking": False}, {}):
            try:
                inputs = _apply_chat_template(
                    messages,
                    extra_processor_kwargs,
                    **template_kwargs,
                )
                if hasattr(inputs, "items"):
                    return dict(inputs)
                errors.append(f"Unexpected input type: {type(inputs).__name__}")
                # A non-mapping result is not retried; move to the next layout.
                break
            except TypeError as e:
                # Retry this layout once without ``enable_thinking``.
                errors.append(str(e))
            except Exception as e:
                errors.append(str(e))
                break

    raise RuntimeError("; ".join(errors[-4:]))


def generate_model_text(model, image: Image.Image, max_new_tokens: int):
    """Greedily generate text for ``image`` and return it decoded and stripped.

    Args:
        model: Loaded model, or ``None`` if loading failed.
        image: Image to transcribe.
        max_new_tokens: Generation length limit.

    Returns:
        The first decoded sequence with special tokens removed and
        surrounding whitespace stripped.

    Raises:
        RuntimeError: If ``model`` is ``None`` (carrying the recorded load
            error when available) or if the model inputs cannot be built.
    """
    if model is None:
        raise RuntimeError(
            MODEL_LOAD_ERRORS.get("fine-tuned", "Fine-tuned model failed to load.")
        )

    device = _get_model_device(model)
    inputs = _move_model_inputs(_build_model_inputs(image), device)

    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            num_beams=1,
            downsample_mode="4x",
        )

    # When the output is longer than the prompt, treat the leading
    # prompt-length tokens as the echoed prompt and drop them.
    input_ids = inputs.get("input_ids")
    if (
        isinstance(input_ids, torch.Tensor)
        and isinstance(generated_ids, torch.Tensor)
        and generated_ids.shape[-1] > input_ids.shape[-1]
    ):
        generated_ids = generated_ids[:, input_ids.shape[-1]:]

    return processor.tokenizer.batch_decode(
        generated_ids,
        skip_special_tokens=True,
    )[0].strip()


# Applied in order as plain substring replacements.
_TOKEN_REPLACEMENTS = (
    ("note-", ""),
    ("barline", "|"),
    ("whole", "semibreve"),
    ("half", "minim"),
    ("quarter", "crotchet"),
    ("eighth", "quaver"),
    ("sixteenth", "semiquaver"),
    ("thirtysecond", "demisemiquaver"),
)


def postprocess_finetuned(text: str) -> str:
    """Rewrite raw model output into the notation shown to the user.

    Removes ``note-`` prefixes, turns ``barline`` into ``|`` and replaces
    the duration names whole/half/quarter/eighth/sixteenth/thirtysecond
    with semibreve/minim/crotchet/quaver/semiquaver/demisemiquaver.
    """
    for old, new in _TOKEN_REPLACEMENTS:
        text = text.replace(old, new)
    return text


def _load_rgb_image(path):
    """Open ``path`` and return an RGB copy, closing the file afterwards."""
    with Image.open(path) as source:
        # convert() returns a new, fully loaded image, so it stays usable
        # after the underlying file has been closed.
        return source.convert("RGB")


def warmup_models():
    """Optionally run one short generation so the first request is faster.

    Does nothing unless ``ENABLE_MODEL_WARMUP`` is set, the example image
    exists and the fine-tuned model loaded. Warmup is best-effort: any
    failure is logged and never propagates.
    """
    if not ENABLE_MODEL_WARMUP:
        print("Model warmup disabled.")
        return

    warmup_path = "examples/000100005-1_1_1.png"
    if not os.path.exists(warmup_path):
        print("Skipping model warmup; example image is missing.")
        return

    if finetuned_model is None:
        print("Skipping warmup; fine-tuned model failed to load.")
        return

    print("Warming up fine-tuned model...")
    try:
        image = _load_rgb_image(warmup_path)
        generate_model_text(finetuned_model, image, max_new_tokens=8)
    except Exception as e:
        print(f"Warmup failed: {type(e).__name__}: {e}")
    else:
        print("Model warmup complete.")


@spaces.GPU(duration=180)
def predict_finetuned(image_path):
    """Transcribe the image at ``image_path`` for the Gradio UI.

    This is a generator that yields exactly one string: the post-processed
    transcription, a prompt to upload an image, or a bracketed error
    message. Errors are reported as text rather than raised.
    """
    if not image_path:
        yield "Please upload an image."
        return

    if finetuned_model is None:
        yield f"[Error: fine-tuned model failed to load: {MODEL_LOAD_ERRORS.get('fine-tuned', 'unknown error')}]"
        return

    try:
        image = _load_rgb_image(image_path)
        text = generate_model_text(
            finetuned_model,
            image,
            max_new_tokens=1024,
        )
        yield postprocess_finetuned(text)
    except Exception as e:
        yield f"[Error: {type(e).__name__}: {e}]"
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


warmup_models()


blocks_kwargs = {
    "title": "Noteworthy — Sheet Music Transcription",
    "theme": gr.themes.Soft(),
}

# css/js go to gr.Blocks when it declares them; otherwise they are kept
# aside and offered to launch() below.
deferred_ui_options = {}
for option, value in (("css", CUSTOM_CSS), ("js", CAMERA_CAPTURE_JS)):
    if supports_keyword(gr.Blocks, option):
        blocks_kwargs[option] = value
    else:
        deferred_ui_options[option] = value


with gr.Blocks(**blocks_kwargs) as demo:
    gr.Markdown(
        """
        # Noteworthy Sheet Music Transcription

        Take a photo or upload sheet music, then click **Transcribe Music**.
        """
    )

    image_input = gr.Image(
        type="filepath",
        label="Sheet Music Image",
        show_label=False,
        sources=["upload", "webcam", "clipboard"],
        webcam_options=gr.WebcamOptions(
            mirror=False,
            constraints={"facingMode": "environment"},
        ),
        placeholder="Upload sheet music, then click Transcribe Music.",
        elem_id="sheet-music-input",
    )

    gr.Examples(
        examples=[
            ["examples/000100005-1_1_1.png"],
            ["examples/000100014-1_1_1.png"],
            ["examples/000100059-1_1_1.png"],
        ],
        inputs=image_input,
    )

    notes_btn = gr.Button(
        "Transcribe Music",
        variant="primary",
        size="lg",
    )

    finetuned_output = gr.Textbox(
        label="Noteworthy Fine-tuned",
        lines=20,
    )

    notes_btn.click(
        fn=predict_finetuned,
        inputs=[image_input],
        outputs=[finetuned_output],
        api_name="transcribe_music",
    )


demo.queue(max_size=20)

launch_kwargs = {
    "server_name": os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
    "server_port": int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
    "share": env_flag("GRADIO_SHARE"),
}

if supports_keyword(demo.launch, "mcp_server"):
    launch_kwargs["mcp_server"] = True

# Only forwarded when launch() explicitly declares the parameter, so
# versions without it behave exactly as before.
for option, value in deferred_ui_options.items():
    if supports_keyword(demo.launch, option):
        launch_kwargs[option] = value

demo.launch(**launch_kwargs)