noteworthy / app.py
jon-fernandes's picture
Update for using ZeroGPU
c75b065 verified
raw
history blame contribute delete
11.2 kB
import os
import inspect
import torch
from PIL import Image
from transformers import AutoProcessor, MiniCPMV4_6ForConditionalGeneration
import gradio as gr
try:
    import spaces
except ImportError:
    # Lets the app still run locally without ZeroGPU.
    class _SpacesFallback:
        """No-op stand-in for the `spaces` package when it is not installed."""

        @staticmethod
        def GPU(*args, **kwargs):
            """Mimic `spaces.GPU` without reserving any hardware.

            Works both as a bare decorator (`@spaces.GPU`) and as a decorator
            factory (`@spaces.GPU(duration=...)`). In either form the
            decorated function is returned unchanged and all options are
            ignored.
            """
            # Bare usage: the decorated function arrives as the only argument.
            if len(args) == 1 and not kwargs and callable(args[0]):
                return args[0]

            def decorator(fn):
                return fn

            return decorator

    spaces = _SpacesFallback()
# Hub repo id that the processor is loaded from.
ORIGINAL_MODEL_ID = "openbmb/MiniCPM-V-4.6"
# Hub repo id that the model weights used for inference are loaded from.
FINETUNED_MODEL_ID = "jon-fernandes/noteworthy"
# Text instruction sent together with every image.
NOTES_PROMPT = "Transcribe the musical notes in this image. Return only the transcription."
# Browser-side script handed to Gradio as the `js` option.
# It looks up the element with id "sheet-music-input" and, once per element
# (guarded by a data attribute), installs a click handler. The handler ignores
# clicks on interactive controls, does nothing unless a <video> is present, and
# otherwise clicks the first button whose text, aria-label or title mentions
# "capture", "photo" or "snapshot". If the element is not there yet, the lookup
# is retried every 300 ms and abandoned after 5 seconds.
CAMERA_CAPTURE_JS = """
function () {
const attachTapCapture = () => {
const root = document.getElementById("sheet-music-input");
if (!root || root.dataset.tapCaptureReady === "1") {
return Boolean(root);
}
root.dataset.tapCaptureReady = "1";
root.addEventListener("click", (event) => {
if (event.target.closest("button, input, select, textarea, a")) {
return;
}
if (!root.querySelector("video")) {
return;
}
const buttons = Array.from(root.querySelectorAll("button"));
const captureButton = buttons.find((button) => {
const text = [
button.textContent,
button.getAttribute("aria-label"),
button.getAttribute("title"),
].filter(Boolean).join(" ").toLowerCase();
return text.includes("capture") || text.includes("photo") || text.includes("snapshot");
});
if (captureButton) {
captureButton.click();
}
});
return true;
};
if (!attachTapCapture()) {
const timer = setInterval(() => {
if (attachTapCapture()) {
clearInterval(timer);
}
}, 300);
setTimeout(() => {
clearInterval(timer);
}, 5000);
}
}
"""
# Stylesheet handed to Gradio as the `css` option: shows a pointer cursor over
# the video, canvas and image elements inside the "sheet-music-input" component.
CUSTOM_CSS = """
#sheet-music-input video,
#sheet-music-input canvas,
#sheet-music-input img {
cursor: pointer;
}
"""
def env_flag(name: str, default: bool = False) -> bool:
    """Interpret the environment variable *name* as a boolean switch.

    Returns *default* when the variable is not set. Otherwise the value is
    stripped and lower-cased, and counts as true only when it is one of
    "1", "true", "yes" or "on".
    """
    try:
        raw = os.environ[name]
    except KeyError:
        return default
    truthy = ("1", "true", "yes", "on")
    return raw.strip().lower() in truthy
def supports_keyword(callable_obj, keyword):
    """Report whether *callable_obj* declares a parameter named *keyword*.

    Objects whose signature cannot be introspected are treated as not
    supporting the keyword.
    """
    try:
        declared = inspect.signature(callable_obj).parameters
    except (TypeError, ValueError):
        return False
    else:
        return keyword in declared
# Important for ZeroGPU:
# Do NOT warm up at startup by default. GPU is only allocated inside @spaces.GPU.
# Setting NOTEWORTHY_WARMUP to 1/true/yes/on opts in to the startup warmup.
ENABLE_MODEL_WARMUP = env_flag("NOTEWORTHY_WARMUP", False)
# Maps a model label (e.g. "fine-tuned") to the "<ExceptionType>: <message>"
# text recorded when loading that model failed.
MODEL_LOAD_ERRORS = {}
print("Loading processor...")
# The processor is loaded from ORIGINAL_MODEL_ID, not from FINETUNED_MODEL_ID.
processor = AutoProcessor.from_pretrained(
ORIGINAL_MODEL_ID,
trust_remote_code=True,
)
def load_local_model(label, model_id):
    """Load *model_id* in bfloat16, move it to CUDA and switch it to eval mode.

    Failures never propagate. The error is stored in ``MODEL_LOAD_ERRORS``
    under *label*, printed, and ``None`` is returned instead of a model.
    """
    print(f"Loading {label} model...")
    try:
        load_options = {
            "torch_dtype": torch.bfloat16,
            "attn_implementation": "sdpa",
            "trust_remote_code": True,
            "low_cpu_mem_usage": True,
        }
        loaded = MiniCPMV4_6ForConditionalGeneration.from_pretrained(
            model_id,
            **load_options,
        )
        # Important for ZeroGPU:
        # Hugging Face recommends placing the model on CUDA at module level.
        # ZeroGPU uses CUDA emulation during startup and real CUDA inside @spaces.GPU.
        loaded = loaded.to("cuda")
        loaded = loaded.eval()
        print(f"{label} model loaded.")
        return loaded
    except Exception as exc:
        # Best-effort by design: remember why loading failed so the UI can report it.
        failure = f"{type(exc).__name__}: {exc}"
        MODEL_LOAD_ERRORS[label] = failure
        print(f"Failed to load {label} model: {failure}")
        return None
# Loaded once at import time; stays None when loading failed (the reason is
# then recorded in MODEL_LOAD_ERRORS under the "fine-tuned" label).
finetuned_model = load_local_model("fine-tuned", FINETUNED_MODEL_ID)
print("Startup complete.")
def _get_model_device(model):
    """Return the device of the model's first parameter.

    A model without any parameters is reported as living on "cuda".
    """
    for parameter in model.parameters():
        return parameter.device
    return torch.device("cuda")
def _move_model_inputs(inputs, device):
    """Return a new dict with every tensor in *inputs* placed on *device*.

    Floating-point tensors are cast to bfloat16 before being moved. Other
    tensors keep their dtype, and non-tensor values are passed through
    untouched.
    """

    def _relocate(item):
        if not isinstance(item, torch.Tensor):
            return item
        if torch.is_floating_point(item):
            item = item.to(dtype=torch.bfloat16)
        return item.to(device)

    return {name: _relocate(item) for name, item in inputs.items()}
def _build_model_inputs(image: Image.Image):
    """Turn *image* plus the notes prompt into processor outputs for the model.

    Two message layouts are tried in order: the image embedded directly in
    the chat message, then an image placeholder with the image supplied
    through the processor's ``images`` keyword. Each layout is first rendered
    with ``enable_thinking=False``; if that raises ``TypeError`` it is
    rendered once more without that flag.

    Returns:
        A plain ``dict`` built from the first result that exposes ``items``.

    Raises:
        RuntimeError: when no attempt succeeds; the message joins the last
            four recorded failure messages with "; ".
    """

    def _user_turn(image_part):
        return [
            {
                "role": "user",
                "content": [
                    image_part,
                    {"type": "text", "text": NOTES_PROMPT},
                ],
            }
        ]

    def _render(messages, extra_processor_kwargs, **template_flags):
        return processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            **template_flags,
            processor_kwargs={
                **extra_processor_kwargs,
                "downsample_mode": "4x",
                "max_slice_nums": 9,
                "use_image_id": True,
            },
        )

    layouts = (
        (_user_turn({"type": "image", "image": image}), {}),
        (_user_turn({"type": "image"}), {"images": [image]}),
    )

    failures = []
    for messages, extra_processor_kwargs in layouts:
        try:
            rendered = _render(messages, extra_processor_kwargs, enable_thinking=False)
            if hasattr(rendered, "items"):
                return dict(rendered)
            failures.append(f"Unexpected input type: {type(rendered).__name__}")
        except TypeError as first_error:
            failures.append(str(first_error))
            # Second chance for this layout, this time without `enable_thinking`.
            try:
                rendered = _render(messages, extra_processor_kwargs)
                if hasattr(rendered, "items"):
                    return dict(rendered)
                failures.append(f"Unexpected input type: {type(rendered).__name__}")
            except Exception as retry_error:
                failures.append(str(retry_error))
        except Exception as first_error:
            failures.append(str(first_error))

    raise RuntimeError("; ".join(failures[-4:]))
def generate_model_text(model, image: Image.Image, max_new_tokens: int):
    """Run *model* on *image* and return the decoded, stripped text.

    Generation runs under ``torch.inference_mode`` with sampling disabled and
    a single beam. When the output is a tensor longer than the prompt's
    ``input_ids``, the leading prompt-length slice is dropped before decoding.

    Raises:
        RuntimeError: if *model* is ``None``; the message is the recorded
            "fine-tuned" load error when there is one.
    """
    if model is None:
        reason = MODEL_LOAD_ERRORS.get("fine-tuned", "Fine-tuned model failed to load.")
        raise RuntimeError(reason)

    target_device = _get_model_device(model)
    prepared = _build_model_inputs(image)
    inputs = _move_model_inputs(prepared, target_device)

    generation_options = {
        "max_new_tokens": max_new_tokens,
        "do_sample": False,
        "num_beams": 1,
        "downsample_mode": "4x",
    }
    with torch.inference_mode():
        generated_ids = model.generate(**inputs, **generation_options)

    prompt_ids = inputs.get("input_ids")
    if isinstance(prompt_ids, torch.Tensor) and isinstance(generated_ids, torch.Tensor):
        prompt_length = prompt_ids.shape[-1]
        if generated_ids.shape[-1] > prompt_length:
            # Keep only what comes after the prompt-length prefix.
            generated_ids = generated_ids[:, prompt_length:]

    decoded = processor.tokenizer.batch_decode(
        generated_ids,
        skip_special_tokens=True,
    )
    return decoded[0].strip()
def postprocess_finetuned(text: str) -> str:
    """Rewrite the model's token vocabulary into the display form.

    Drops the "note-" prefix, renders "barline" as "|", and swaps the
    whole/half/quarter/... duration names for semibreve/minim/crotchet/...
    Substitutions are plain substring replacements applied in the order
    listed below.
    """
    substitutions = (
        ("note-", ""),
        ("barline", "|"),
        ("whole", "semibreve"),
        ("half", "minim"),
        ("quarter", "crotchet"),
        ("eighth", "quaver"),
        ("sixteenth", "semiquaver"),
        ("thirtysecond", "demisemiquaver"),
    )
    for token, replacement in substitutions:
        text = text.replace(token, replacement)
    return text
def warmup_models():
    """Run one short generation on a bundled example image, if warmup is enabled.

    Does nothing beyond printing a reason when warmup is disabled via
    ``ENABLE_MODEL_WARMUP``, the example image is missing, or the fine-tuned
    model failed to load. Warmup is best-effort: any failure while opening the
    image or generating is printed and swallowed so startup can continue.
    """
    if not ENABLE_MODEL_WARMUP:
        print("Model warmup disabled.")
        return
    warmup_path = "examples/000100005-1_1_1.png"
    if not os.path.exists(warmup_path):
        print("Skipping model warmup; example image is missing.")
        return
    if finetuned_model is None:
        print("Skipping warmup; fine-tuned model failed to load.")
        return
    print("Warming up fine-tuned model...")
    try:
        # The context manager closes the file handle; convert() yields an
        # independent in-memory image that stays usable afterwards.
        with Image.open(warmup_path) as source:
            image = source.convert("RGB")
        generate_model_text(finetuned_model, image, max_new_tokens=8)
    except Exception as e:
        print(f"Warmup failed: {type(e).__name__}: {e}")
    else:
        # Only claim success when the generation actually went through.
        print("Model warmup complete.")
@spaces.GPU(duration=180)
def _transcribe_on_gpu(image_path):
    """Transcribe the image at *image_path* and return the text to display.

    Only this function carries the ``spaces.GPU`` decorator, so cheap input
    checks in the caller do not request a GPU. Any exception is turned into
    an "[Error: ...]" string rather than raised, and the CUDA cache is
    emptied afterwards when CUDA is available.
    """
    try:
        # Close the file handle promptly; convert() returns a separate image.
        with Image.open(image_path) as source:
            image = source.convert("RGB")
        text = generate_model_text(
            finetuned_model,
            image,
            max_new_tokens=1024,
        )
        return postprocess_finetuned(text)
    except Exception as e:
        return f"[Error: {type(e).__name__}: {e}]"
    finally:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


def predict_finetuned(image_path):
    """Gradio handler: yield the transcription for the image at *image_path*.

    Yields exactly one string: a prompt to upload an image when *image_path*
    is ``None``, an error note when the fine-tuned model is unavailable, or
    otherwise the transcription (or error text) produced on the GPU.
    """
    # Validate before entering the GPU-decorated helper.
    if image_path is None:
        yield "Please upload an image."
        return
    if finetuned_model is None:
        yield f"[Error: fine-tuned model failed to load: {MODEL_LOAD_ERRORS.get('fine-tuned', 'unknown error')}]"
        return
    yield _transcribe_on_gpu(image_path)
# Optional warmup (a no-op unless enabled through the environment).
warmup_models()

blocks_kwargs = {
    "title": "Noteworthy — Sheet Music Transcription",
    "theme": gr.themes.Soft(),
}
# css/js are feature-detected on gr.Blocks. When the installed Gradio does not
# declare them there, keep them aside and offer them to launch() further down
# instead of silently dropping the custom styling and tap-to-capture script.
_launch_only_options = {}
for _option, _value in (("css", CUSTOM_CSS), ("js", CAMERA_CAPTURE_JS)):
    if supports_keyword(gr.Blocks, _option):
        blocks_kwargs[_option] = _value
    else:
        _launch_only_options[_option] = _value

with gr.Blocks(**blocks_kwargs) as demo:
    gr.Markdown(
        """
# Noteworthy
Sheet Music Transcription
Take a photo or upload sheet music, then click **Transcribe Music**.
"""
    )
    # elem_id ties this component to CUSTOM_CSS and CAMERA_CAPTURE_JS.
    image_input = gr.Image(
        type="filepath",
        label="Sheet Music Image",
        show_label=False,
        sources=["upload", "webcam", "clipboard"],
        webcam_options=gr.WebcamOptions(
            mirror=False,
            constraints={"facingMode": "environment"},
        ),
        placeholder="Upload sheet music, then click Transcribe Music.",
        elem_id="sheet-music-input",
    )
    gr.Examples(
        examples=[
            ["examples/000100005-1_1_1.png"],
            ["examples/000100014-1_1_1.png"],
            ["examples/000100059-1_1_1.png"],
        ],
        inputs=image_input,
    )
    notes_btn = gr.Button(
        "Transcribe Music",
        variant="primary",
        size="lg",
    )
    finetuned_output = gr.Textbox(
        label="Noteworthy Fine-tuned",
        lines=20,
    )
    notes_btn.click(
        fn=predict_finetuned,
        inputs=[image_input],
        outputs=[finetuned_output],
        api_name="transcribe_music",
    )

demo.queue(max_size=20)
launch_kwargs = {
    "server_name": os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
    "server_port": int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
    "share": env_flag("GRADIO_SHARE"),
}
if supports_keyword(demo.launch, "mcp_server"):
    launch_kwargs["mcp_server"] = True
# Forward whatever gr.Blocks could not take, but only if launch() declares it.
for _option, _value in _launch_only_options.items():
    if supports_keyword(demo.launch, _option):
        launch_kwargs[_option] = _value
demo.launch(**launch_kwargs)