Spaces:
Sleeping
Sleeping
| from transformers import AutoImageProcessor, AutoModelForImageClassification | |
| from PIL import Image | |
| import streamlit as st | |
| import torch | |
| from streamlit_drawable_canvas import st_canvas | |
st.set_page_config(page_title="Draw Something!", layout="centered")

# On the first run of a session no label exists yet, so seed the prompt text.
st.session_state.setdefault("prediction", "Draw something!")

# Centered page heading showing whatever label is currently stored.
heading_html = "<h1 style='text-align: center;'>{}</h1>".format(st.session_state["prediction"])
st.markdown(heading_html, unsafe_allow_html=True)
@st.cache_resource
def _load_sketch_classifier(checkpoint="kmewhort/resnet34-sketch-classifier"):
    """Load the image processor and classification model for *checkpoint*.

    Streamlit re-executes this whole script after every canvas interaction;
    ``st.cache_resource`` makes the (slow) ``from_pretrained`` calls happen
    once per server process instead of on every rerun.

    Returns:
        A ``(processor, model)`` tuple.
    """
    return (
        AutoImageProcessor.from_pretrained(checkpoint),
        AutoModelForImageClassification.from_pretrained(checkpoint),
    )


processor, model = _load_sketch_classifier()
# Square 700x700 free-hand sketch pad: 5px black strokes on a white background.
canvas = st_canvas(
    drawing_mode="freedraw",
    width=700,
    height=700,
    background_color="#FFFFFF",
    stroke_color="#000000",
    stroke_width=5,
)
def predict_drawing():
    """Classify the current canvas drawing and store the label in session state.

    Reads the module-level ``canvas`` result and writes either the predicted
    class label or the "Draw something!" prompt (when there is nothing to
    classify) to ``st.session_state["prediction"]``.
    """
    if canvas.image_data is None:
        st.session_state["prediction"] = "Draw something!"
        return

    # Assumes the canvas result is an RGBA array (TODO confirm); convert("RGBA")
    # normalises any other mode so the composite below always works.
    drawing = Image.fromarray(canvas.image_data.astype("uint8")).convert("RGBA")
    # Flatten onto an opaque white sheet before going to greyscale. A direct
    # RGBA -> "L" conversion discards alpha, so transparent pixels would be
    # graded by whatever RGB values they happen to carry rather than as white.
    sheet = Image.new("RGBA", drawing.size, (255, 255, 255, 255))
    grayscale = Image.alpha_composite(sheet, drawing).convert("L")

    # A uniform image means nothing has been drawn; don't ask the model to
    # label an empty canvas.
    darkest, lightest = grayscale.getextrema()
    if darkest == lightest:
        st.session_state["prediction"] = "Draw something!"
        return

    # The processor is fed a 3-channel image built from the greyscale sketch.
    inputs = processor(images=grayscale.convert("RGB"), return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    predicted_class_idx = logits.argmax(-1).item()
    st.session_state["prediction"] = model.config.id2label[predicted_class_idx]
if canvas.image_data is not None:
    label_before = st.session_state["prediction"]
    predict_drawing()
    # The page heading near the top of the script was rendered before this
    # prediction existed, so it still shows the previous label. Rerun once so
    # it catches up; on the rerun the same pixels (presumably) give the same
    # label, the comparison fails, and no further rerun is triggered.
    if st.session_state["prediction"] != label_before:
        st.rerun()
# Inject a style rule setting overflow: hidden on Streamlit's main section
# (presumably so the large canvas does not make the page scroll).
css = (
    "\n"
    "<style>\n"
    "section.stMain {\n"
    "overflow: hidden;\n"
    "}\n"
    "</style>\n"
)
st.markdown(css, unsafe_allow_html=True)