# SVHN-V1-ResNet / app.py
# Author: LearnWaterFlow
# Last commit: e42d6df ("Update app.py", verified)
import gradio as gr
from transformers import pipeline
from PIL import Image
# Image-classification pipeline backed by the SVHN ResNet checkpoint on the Hub.
# trust_remote_code=True runs model code downloaded from that repository, so
# only point `model` at a repository you trust.
pipe = pipeline(
    task="image-classification",
    model="Dawntasy/SVHN-V1-ResNet",
    trust_remote_code=True,
)
def predict(img):
    """Classify a single digit image with the SVHN pipeline.

    Args:
        img: A PIL image (from ``gr.Image(type="pil")``), the dict payload
            produced by ``gr.Sketchpad(type="pil")`` -- whose ``"composite"``
            entry holds the flattened drawing and ``"background"`` the base
            image -- or ``None``.

    Returns:
        A ``{label: score}`` mapping built from the pipeline output, suitable
        for ``gr.Label``; ``None`` when there is nothing to classify (no
        input, or a sketchpad payload whose image entries are missing or
        ``None``).
    """
    if img is None:
        return None

    if isinstance(img, dict):
        # Prefer the flattened drawing; fall back to the bare background.
        # Either key may be absent or hold None (e.g. a cleared canvas), so
        # bail out rather than calling .convert() on a dict or on None.
        composite = img.get("composite")
        img = composite if composite is not None else img.get("background")
        if img is None:
            return None

    if "A" in img.getbands() or "transparency" in img.info:
        # A plain convert("RGB") just drops alpha, so fully transparent
        # pixels keep whatever RGB they store (commonly black). Flatten onto
        # white first so dark strokes on a transparent canvas stay visible.
        # NOTE(review): assumes sketchpad canvases are transparent with dark
        # strokes -- confirm against the Gradio version in use.
        rgba = img.convert("RGBA")
        white = Image.new("RGBA", rgba.size, (255, 255, 255, 255))
        img = Image.alpha_composite(white, rgba)

    results = pipe(img.convert("RGB"))
    return {res["label"]: res["score"] for res in results}
with gr.Blocks() as demo:
    # Page header.
    gr.Markdown("# 🔢SVHN-V1-ResNet Demo")
    gr.Markdown("This model was trained on **Street View House Numbers**. Draw in the sketchpad for live predictions or upload a photo!")

    with gr.Tabs():
        # Tab 1: free-hand drawing, re-classified whenever the canvas changes.
        with gr.TabItem("Live Sketchpad"):
            with gr.Row():
                with gr.Column():
                    sketch_input = gr.Sketchpad(label="Draw a Digit (0-9)", type="pil")
                with gr.Column():
                    sketch_output = gr.Label(label="Live Prediction", num_top_classes=5)
            sketch_input.change(fn=predict, inputs=sketch_input, outputs=sketch_output)

        # Tab 2: an uploaded image is classified only when the button is pressed.
        with gr.TabItem("Upload Image"):
            with gr.Row():
                with gr.Column():
                    file_input = gr.Image(label="Upload Cropped Digit", type="pil")
                    upload_button = gr.Button("Classify Upload")
                with gr.Column():
                    file_output = gr.Label(label="Result", num_top_classes=5)
            upload_button.click(fn=predict, inputs=file_input, outputs=file_output)
if __name__ == "__main__":
    # NOTE(review): recent Gradio releases accept `theme` in launch(); older
    # 4.x/5.x releases expected it on gr.Blocks(...) instead -- confirm
    # against the pinned gradio version.
    app_theme = gr.themes.Soft()
    demo.launch(theme=app_theme)