Spaces:
Running on Zero
Running on Zero
| import os | |
| import inspect | |
| import torch | |
| from PIL import Image | |
| from transformers import AutoProcessor, MiniCPMV4_6ForConditionalGeneration | |
| import gradio as gr | |
try:
    import spaces
except ImportError:
    # Lets the app still run locally without ZeroGPU: `spaces.GPU` becomes a
    # no-op decorator.
    class _SpacesFallback:
        @staticmethod
        def GPU(*args, **kwargs):
            """Stand-in for ``spaces.GPU`` that leaves functions unchanged.

            Supports both ``@spaces.GPU`` (called directly with the function)
            and ``@spaces.GPU(...)`` (called with options, returning a
            decorator).
            """
            if len(args) == 1 and not kwargs and callable(args[0]):
                # Bare `@spaces.GPU`: the decorated function is the argument.
                return args[0]

            def decorator(fn):
                return fn

            return decorator

    spaces = _SpacesFallback()
# Hub repo the processor (chat template + image preprocessing) is loaded from.
ORIGINAL_MODEL_ID = "openbmb/MiniCPM-V-4.6"
# Hub repo the fine-tuned weights used for inference are loaded from.
FINETUNED_MODEL_ID = "jon-fernandes/noteworthy"
# Instruction sent together with every image.
NOTES_PROMPT = "Transcribe the musical notes in this image. Return only the transcription."
# Browser-side script, passed to gr.Blocks as `js` when that keyword is
# supported. When the #sheet-music-input element shows a live <video> (webcam
# mode), a click anywhere inside it that is not on a control clicks the
# component's own button whose text/aria-label/title mentions "capture",
# "photo" or "snapshot", so tapping the preview takes the picture.
# The element is polled every 300 ms for up to 5 s until it exists;
# dataset.tapCaptureReady guards against attaching the listener twice.
CAMERA_CAPTURE_JS = """
function () {
const attachTapCapture = () => {
const root = document.getElementById("sheet-music-input");
if (!root || root.dataset.tapCaptureReady === "1") {
return Boolean(root);
}
root.dataset.tapCaptureReady = "1";
root.addEventListener("click", (event) => {
if (event.target.closest("button, input, select, textarea, a")) {
return;
}
if (!root.querySelector("video")) {
return;
}
const buttons = Array.from(root.querySelectorAll("button"));
const captureButton = buttons.find((button) => {
const text = [
button.textContent,
button.getAttribute("aria-label"),
button.getAttribute("title"),
].filter(Boolean).join(" ").toLowerCase();
return text.includes("capture") || text.includes("photo") || text.includes("snapshot");
});
if (captureButton) {
captureButton.click();
}
});
return true;
};
if (!attachTapCapture()) {
const timer = setInterval(() => {
if (attachTapCapture()) {
clearInterval(timer);
}
}, 300);
setTimeout(() => {
clearInterval(timer);
}, 5000);
}
}
"""
# Pointer cursor over the image input's video/canvas/img, matching the
# tap-to-capture behaviour of CAMERA_CAPTURE_JS.
CUSTOM_CSS = """
#sheet-music-input video,
#sheet-music-input canvas,
#sheet-music-input img {
cursor: pointer;
}
"""
def env_flag(name: str, default: bool = False) -> bool:
    """Read environment variable ``name`` as an on/off switch.

    Returns ``default`` when the variable is unset. Otherwise the value is
    trimmed and lower-cased, and the result is True only for "1", "true",
    "yes" or "on"; every other value (including an empty string) is False.
    """
    raw = os.environ.get(name)
    if raw is not None:
        return raw.strip().lower() in ("1", "true", "yes", "on")
    return default
def supports_keyword(callable_obj, keyword):
    """Tell whether ``keyword`` is a named parameter of ``callable_obj``.

    Used to feature-detect optional arguments across library versions.
    Callables whose signature cannot be inspected report False, and a
    ``**kwargs`` catch-all does not count as accepting arbitrary names.
    """
    try:
        params = inspect.signature(callable_obj).parameters
    except (TypeError, ValueError):
        return False
    return keyword in params
# ZeroGPU: a GPU is only attached while a @spaces.GPU function runs, so an
# import-time warmup is opt-in (NOTEWORTHY_WARMUP=1/true/yes/on) and off by
# default.
ENABLE_MODEL_WARMUP = env_flag("NOTEWORTHY_WARMUP", default=False)

# Model label -> "ExcType: message" for each model that failed to load.
MODEL_LOAD_ERRORS = dict()

print("Loading processor...")
# The processor comes from the base model repo; the fine-tuned repo is only
# used for the model weights.
processor = AutoProcessor.from_pretrained(ORIGINAL_MODEL_ID, trust_remote_code=True)
def load_local_model(label, model_id):
    """Load a MiniCPM-V 4.6 checkpoint for inference on CUDA.

    Args:
        label: Human-readable name used in log lines and as the key in
            MODEL_LOAD_ERRORS.
        model_id: Identifier passed to ``from_pretrained``.

    Returns:
        The model in bfloat16, moved to ``"cuda"`` and switched to eval mode,
        or None if anything failed. On failure ``"ExcType: message"`` is
        stored in MODEL_LOAD_ERRORS[label] and printed. The exception is not
        re-raised, so the app still starts and can report the error.
    """
    print(f"Loading {label} model...")
    load_options = {
        "torch_dtype": torch.bfloat16,
        "attn_implementation": "sdpa",
        "trust_remote_code": True,
        "low_cpu_mem_usage": True,
    }
    try:
        loaded = MiniCPMV4_6ForConditionalGeneration.from_pretrained(model_id, **load_options)
        # ZeroGPU: Hugging Face recommends placing the model on CUDA at module
        # level; startup runs on emulated CUDA and real CUDA is only present
        # inside @spaces.GPU functions.
        loaded = loaded.to("cuda").eval()
        print(f"{label} model loaded.")
        return loaded
    except Exception as exc:
        reason = f"{type(exc).__name__}: {exc}"
        MODEL_LOAD_ERRORS[label] = reason
        print(f"Failed to load {label} model: {reason}")
        return None
# Loaded once at import time. None if loading failed; the reason is then in
# MODEL_LOAD_ERRORS["fine-tuned"], which predict_finetuned reports to the user.
finetuned_model = load_local_model("fine-tuned", FINETUNED_MODEL_ID)
print("Startup complete.")
| def _get_model_device(model): | |
| try: | |
| return next(model.parameters()).device | |
| except StopIteration: | |
| return torch.device("cuda") | |
| def _move_model_inputs(inputs, device): | |
| moved = {} | |
| for key, value in inputs.items(): | |
| if isinstance(value, torch.Tensor): | |
| if torch.is_floating_point(value): | |
| value = value.to(dtype=torch.bfloat16) | |
| moved[key] = value.to(device) | |
| else: | |
| moved[key] = value | |
| return moved | |
def _build_model_inputs(image: Image.Image):
    """Turn ``image`` plus NOTES_PROMPT into model inputs via the processor.

    Two conversation layouts are tried in order:

    1. the PIL image embedded directly in the message content;
    2. a bare ``{"type": "image"}`` placeholder, with the image passed as
       ``processor_kwargs["images"]``.

    For each layout ``apply_chat_template`` is first called with
    ``enable_thinking=False``. Only if that attempt raises ``TypeError`` is the
    same layout retried without the flag. The first result that has an
    ``items`` attribute is returned as a plain ``dict``. A result without
    ``items`` is recorded as a failure and the next layout is tried.

    Raises:
        RuntimeError: if no attempt succeeds; the message is the last (up to)
            four recorded failure messages joined with "; ".
    """

    def conversation(image_part):
        # Built fresh for each layout so the layouts never share dict objects.
        return [
            {
                "role": "user",
                "content": [image_part, {"type": "text", "text": NOTES_PROMPT}],
            }
        ]

    def render(messages, extra, **template_flags):
        # processor_kwargs is rebuilt on every call, so a retry never sees a
        # dict touched by the previous attempt.
        return processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            **template_flags,
            processor_kwargs={
                **extra,
                "downsample_mode": "4x",
                "max_slice_nums": 9,
                "use_image_id": True,
            },
        )

    failures = []

    def as_plain_dict(produced):
        # Mapping-like results are accepted; anything else is recorded as a
        # failure and signalled with None.
        if hasattr(produced, "items"):
            return dict(produced)
        failures.append(f"Unexpected input type: {type(produced).__name__}")
        return None

    layouts = [
        (conversation({"type": "image", "image": image}), {}),
        (conversation({"type": "image"}), {"images": [image]}),
    ]
    for messages, extra in layouts:
        try:
            result = as_plain_dict(render(messages, extra, enable_thinking=False))
        except TypeError as flag_error:
            # Presumably a template/processor that rejects enable_thinking:
            # retry this layout once without it.
            failures.append(str(flag_error))
            try:
                result = as_plain_dict(render(messages, extra))
            except Exception as retry_error:
                failures.append(str(retry_error))
                result = None
        except Exception as other_error:
            failures.append(str(other_error))
            result = None
        if result is not None:
            return result

    raise RuntimeError("; ".join(failures[-4:]))
def _strip_prompt_tokens(generated_ids, input_ids):
    """Remove the echoed prompt from ``generated_ids`` only if it is present.

    Some ``generate`` implementations return prompt + continuation, others
    only the continuation. Comparing the actual leading tokens, rather than
    just the lengths, avoids chopping off the start of a continuation-only
    output that happens to be longer than the prompt. Non-tensor arguments
    are returned unchanged.
    """
    if not isinstance(input_ids, torch.Tensor) or not isinstance(generated_ids, torch.Tensor):
        return generated_ids
    prompt_len = input_ids.shape[-1]
    if generated_ids.shape[-1] <= prompt_len:
        return generated_ids
    echoed = generated_ids[:, :prompt_len]
    if not torch.equal(echoed, input_ids.to(echoed.device)):
        return generated_ids
    return generated_ids[:, prompt_len:]


def generate_model_text(model, image: "Image.Image", max_new_tokens: int):
    """Greedily generate a transcription for ``image`` and return it as text.

    Args:
        model: A loaded model (as returned by ``load_local_model``) or None.
        image: PIL image to transcribe (callers pass RGB-converted images).
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The first decoded sequence with special tokens skipped and
        surrounding whitespace stripped.

    Raises:
        RuntimeError: if ``model`` is None. The message is the recorded
            fine-tuned load error when available.
    """
    if model is None:
        raise RuntimeError(
            MODEL_LOAD_ERRORS.get("fine-tuned", "Fine-tuned model failed to load.")
        )
    device = _get_model_device(model)
    inputs = _move_model_inputs(_build_model_inputs(image), device)
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding: deterministic output
            num_beams=1,
            downsample_mode="4x",
        )
    generated_ids = _strip_prompt_tokens(generated_ids, inputs.get("input_ids"))
    return processor.tokenizer.batch_decode(
        generated_ids,
        skip_special_tokens=True,
    )[0].strip()
def postprocess_finetuned(text: str) -> str:
    """Rewrite the fine-tuned model's raw tokens into friendlier notation.

    * A ``note-`` prefix at the start of a word is dropped
      (``note-A4_quarter`` -> ``A4_crotchet``). It is left alone inside
      longer words such as ``gracenote-``, which a plain substring replace
      used to mangle into ``grace...``.
    * ``barline`` -> ``|``.
    * American duration names are replaced by British ones (whole ->
      semibreve ... thirty-second -> demisemiquaver). Both the joined
      (``thirtysecond``) and underscored (``thirty_second``) spellings are
      handled, and 64th/128th durations are covered as well.
    """
    import re

    text = re.sub(r"\bnote-", "", text)
    text = text.replace("barline", "|")
    duration_names = (
        # Must come before "eighth", which it contains.
        ("hundred_twenty_eighth", "semihemidemisemiquaver"),
        ("whole", "semibreve"),
        ("half", "minim"),
        ("quarter", "crotchet"),
        ("eighth", "quaver"),
        ("sixteenth", "semiquaver"),
        ("thirty_second", "demisemiquaver"),
        ("thirtysecond", "demisemiquaver"),
        ("sixty_fourth", "hemidemisemiquaver"),
        ("sixtyfourth", "hemidemisemiquaver"),
    )
    for american, british in duration_names:
        text = text.replace(american, british)
    return text
def warmup_models():
    """Optionally run one short (8-token) generation on a bundled example.

    Runs only when NOTEWORTHY_WARMUP is enabled, the example image exists and
    the fine-tuned model loaded. It is best-effort: any failure, including an
    unreadable image, is printed and startup continues. The completion message
    is only printed when the warmup actually succeeded.
    """
    if not ENABLE_MODEL_WARMUP:
        print("Model warmup disabled.")
        return
    warmup_path = "examples/000100005-1_1_1.png"
    if not os.path.exists(warmup_path):
        print("Skipping model warmup; example image is missing.")
        return
    if finetuned_model is None:
        print("Skipping warmup; fine-tuned model failed to load.")
        return
    print("Warming up fine-tuned model...")
    try:
        with Image.open(warmup_path) as source:
            image = source.convert("RGB")
        # NOTE(review): this runs at import time, outside any @spaces.GPU
        # function; on ZeroGPU the CUDA work presumably fails and is reported
        # below -- confirm before enabling warmup there.
        generate_model_text(finetuned_model, image, max_new_tokens=8)
    except Exception as e:
        print(f"Warmup failed: {type(e).__name__}: {e}")
        return
    print("Model warmup complete.")
@spaces.GPU()
def _transcribe_on_gpu(image_path):
    """Transcribe the image at ``image_path`` with the fine-tuned model.

    This is the only function that runs the model. On ZeroGPU a GPU is only
    available inside ``@spaces.GPU`` functions, so it carries that decorator.
    Errors are returned as a bracketed message rather than raised, so the UI
    always shows a result.
    """
    try:
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        text = generate_model_text(
            finetuned_model,
            image,
            max_new_tokens=1024,
        )
        return postprocess_finetuned(text)
    except Exception as e:
        return f"[Error: {type(e).__name__}: {e}]"
    finally:
        # Return cached CUDA blocks so the next request starts clean.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def predict_finetuned(image_path):
    """Transcribe the musical notes in a sheet-music image.

    Args:
        image_path: Path to the sheet-music image file.

    Yields:
        A single string: the transcription, or a user-facing message when the
        image is missing or transcription failed.
    """
    if image_path is None:
        yield "Please upload an image."
        return
    if finetuned_model is None:
        yield f"[Error: fine-tuned model failed to load: {MODEL_LOAD_ERRORS.get('fine-tuned', 'unknown error')}]"
        return
    # The checks above run without requesting a GPU; only the model call does.
    yield _transcribe_on_gpu(image_path)
warmup_models()

# Options for gr.Blocks. Every optional one is probed rather than assumed, as
# was already done for css/js, because not every Gradio release accepts them
# on gr.Blocks. Options that are not supported are simply left out here.
blocks_kwargs = {
    "title": "Noteworthy — Sheet Music Transcription",
}
if supports_keyword(gr.Blocks, "theme"):
    blocks_kwargs["theme"] = gr.themes.Soft()
if supports_keyword(gr.Blocks, "css"):
    blocks_kwargs["css"] = CUSTOM_CSS
if supports_keyword(gr.Blocks, "js"):
    blocks_kwargs["js"] = CAMERA_CAPTURE_JS
with gr.Blocks(**blocks_kwargs) as demo:
    # Header text shown above the input.
    intro_text = """
# Noteworthy
Sheet Music Transcription
Take a photo or upload sheet music, then click **Transcribe Music**.
"""
    gr.Markdown(intro_text)

    # Webcam uses the rear ("environment") camera, unmirrored, so the captured
    # page is not flipped.
    rear_camera = gr.WebcamOptions(
        mirror=False,
        constraints={"facingMode": "environment"},
    )
    # elem_id must match the id used by CAMERA_CAPTURE_JS / CUSTOM_CSS.
    image_input = gr.Image(
        type="filepath",
        label="Sheet Music Image",
        show_label=False,
        sources=["upload", "webcam", "clipboard"],
        webcam_options=rear_camera,
        placeholder="Upload sheet music, then click Transcribe Music.",
        elem_id="sheet-music-input",
    )

    example_paths = [
        "examples/000100005-1_1_1.png",
        "examples/000100014-1_1_1.png",
        "examples/000100059-1_1_1.png",
    ]
    gr.Examples(
        examples=[[path] for path in example_paths],
        inputs=image_input,
    )

    notes_btn = gr.Button("Transcribe Music", variant="primary", size="lg")
    finetuned_output = gr.Textbox(label="Noteworthy Fine-tuned", lines=20)

    # api_name also names the endpoint exposed through the API / MCP server.
    notes_btn.click(
        predict_finetuned,
        inputs=[image_input],
        outputs=[finetuned_output],
        api_name="transcribe_music",
    )
demo.queue(max_size=20)

launch_kwargs = {
    "server_name": os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
    "server_port": int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
    "share": env_flag("GRADIO_SHARE"),
}
# Theme/CSS/JS that gr.Blocks did not accept are passed to launch() instead
# when this Gradio version's launch() takes them; otherwise they are skipped
# rather than lost silently in only one place.
if "theme" not in blocks_kwargs and supports_keyword(demo.launch, "theme"):
    launch_kwargs["theme"] = gr.themes.Soft()
if "css" not in blocks_kwargs and supports_keyword(demo.launch, "css"):
    launch_kwargs["css"] = CUSTOM_CSS
if "js" not in blocks_kwargs and supports_keyword(demo.launch, "js"):
    launch_kwargs["js"] = CAMERA_CAPTURE_JS
if supports_keyword(demo.launch, "mcp_server"):
    launch_kwargs["mcp_server"] = True
demo.launch(**launch_kwargs)