Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import tensorflow as tf | |
| import numpy as np | |
| from PIL import Image | |
| import io | |
| import os | |
| import requests | |
| import tempfile | |
# Function to download the model from Hugging Face
def _download_file(url, dest_path, timeout):
    """Stream ``url`` into ``dest_path``.

    The body is written to ``dest_path + ".part"`` and only moved into place
    with ``os.replace`` once fully received, so an interrupted transfer never
    leaves a truncated file at ``dest_path``.

    Returns:
        True on HTTP 200 and a successful write, False otherwise (the reason
        is printed).
    """
    tmp_path = dest_path + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            if response.status_code != 200:
                print(f"Failed to download {url}: {response.status_code}")
                return False
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(tmp_path, dest_path)
        return True
    except (requests.RequestException, OSError) as exc:
        print(f"Failed to download {url}: {exc}")
        # Best-effort cleanup of the partial file; the failure is already reported.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        return False


def download_model_from_hf(model_path, local_dir, timeout=60):
    """Download a TensorFlow SavedModel directory from a Hugging Face repo.

    Args:
        model_path: ``resolve`` URL of the SavedModel directory, e.g.
            https://huggingface.co/nivashuggingface/digit-recognition/resolve/main/saved_model
            i.e. ``<scheme>://<host>/<owner>/<repo>/resolve/<revision>[/<subpath>]``.
        local_dir: Directory to write ``saved_model.pb`` and ``variables/`` into.
            Created if missing.
        timeout: Per-request timeout in seconds passed to ``requests.get``.

    Returns:
        True if every file was downloaded, False if the URL is malformed or any
        download fails (the reason is printed).
    """
    from urllib.parse import urlparse

    parsed = urlparse(model_path)
    segments = [s for s in parsed.path.split("/") if s]
    # Expect at least <owner>/<repo>/resolve/<revision>.
    if parsed.scheme not in ("http", "https") or len(segments) < 4 or segments[2] != "resolve":
        print(f"Failed to download model: unrecognised URL {model_path!r}")
        return False

    # Files of the SavedModel live directly under the resolve URL of its
    # directory; rebuild it without query/fragment or duplicate slashes.
    base_url = f"{parsed.scheme}://{parsed.netloc}/{'/'.join(segments)}"

    os.makedirs(os.path.join(local_dir, "variables"), exist_ok=True)

    # saved_model.pb is fetched LAST: the module-level code checks for its
    # presence to decide whether a download is needed, so it must only
    # appear once the variables are in place.
    relative_files = [
        "variables/variables.data-00000-of-00001",
        "variables/variables.index",
        "saved_model.pb",
    ]
    for rel in relative_files:
        dest = os.path.join(local_dir, *rel.split("/"))
        if not _download_file(f"{base_url}/{rel}", dest, timeout):
            return False
    return True
# Where the SavedModel lives on the Hub and where it is cached locally.
MODEL_PATH = "https://huggingface.co/nivashuggingface/digit-recognition/resolve/main/saved_model"
LOCAL_MODEL_DIR = os.path.join(tempfile.gettempdir(), "digit_recognition_model")

# Download only when the cached copy is missing. download_model_from_hf reports
# failure through its return value; stop here with a clear message instead of
# letting tf.saved_model.load fail later on an empty or partial directory.
if not os.path.exists(os.path.join(LOCAL_MODEL_DIR, "saved_model.pb")):
    print("Downloading model from Hugging Face...")
    if not download_model_from_hf(MODEL_PATH, LOCAL_MODEL_DIR):
        raise RuntimeError(
            f"Could not download the model from {MODEL_PATH} into {LOCAL_MODEL_DIR}"
        )

# Load the model from the local directory.
print(f"Loading model from {LOCAL_MODEL_DIR}")
model = tf.saved_model.load(LOCAL_MODEL_DIR)
def preprocess_image(img):
    """Turn an input image into a model-ready batch of one.

    The image is converted to single-channel "L" mode and resized to 28x28.
    Its pixel values are divided by 255 as float32, and a leading batch axis
    and a trailing channel axis are added, giving shape (1, 28, 28, 1).
    """
    grey = img.convert('L').resize((28, 28))
    pixels = np.array(grey).astype('float32') / 255.0
    # Batch axis in front, channel axis at the end.
    return pixels[np.newaxis, ..., np.newaxis]
def predict_digit(img):
    """Classify a digit image and format the result for the output textbox.

    Args:
        img: PIL image from the Gradio input, or None when the user submits
            without providing an image.

    Returns:
        A multi-line string with the predicted digit and a per-class
        confidence, or an error message. Errors are returned rather than
        raised so they are shown in the UI.
    """
    if img is None:
        return "Please draw or upload a digit first."
    try:
        processed_img = preprocess_image(img)
        predictions = model(processed_img)
        # NOTE(review): assumes the model returns a (1, num_classes) tensor of
        # per-class scores — confirm against the exported signature.
        scores = np.asarray(predictions, dtype=np.float64)[0]
        # If the network already ends in softmax, its output is a probability
        # vector; applying softmax again would flatten every confidence toward
        # a uniform distribution. Only normalise values that look like logits.
        looks_like_probs = bool(np.all(scores >= 0)) and np.isclose(scores.sum(), 1.0, atol=1e-3)
        if looks_like_probs:
            confidence_scores = scores
        else:
            shifted = np.exp(scores - scores.max())  # subtract max for numerical stability
            confidence_scores = shifted / shifted.sum()
        predicted_digit = int(np.argmax(confidence_scores))

        lines = [f"Predicted Digit: {predicted_digit}", "", "Confidence Scores:"]
        lines.extend(f"Digit {i}: {score:.2%}" for i, score in enumerate(confidence_scores))
        return "\n".join(lines) + "\n"
    except Exception as e:
        # UI boundary: surface any failure in the output box instead of crashing.
        return f"Error during prediction: {str(e)}"
# Example images are optional: offer only those that actually exist (paths are
# relative to the working directory) so a missing examples/ folder does not
# prevent the interface from being built.
_EXAMPLE_FILES = ["examples/0.png", "examples/1.png", "examples/2.png"]
_examples = [[path] for path in _EXAMPLE_FILES if os.path.isfile(path)] or None

# Create Gradio interface
# NOTE(review): the description asks users to draw, but gr.Image is an
# image input; depending on the Gradio version it may not offer a drawing
# canvas (gr.Sketchpad does) — confirm against the pinned Gradio release.
# NOTE(review): newer Gradio releases deprecate allow_flagging in favour of
# flagging_mode — confirm against the pinned Gradio release.
iface = gr.Interface(
    fn=predict_digit,
    inputs=gr.Image(type="pil", label="Draw a digit (0-9)"),
    outputs=gr.Textbox(label="Prediction Results"),
    title="Digit Recognition with CNN",
    description="""
Draw a digit (0-9) in the box below. The model will predict which digit you drew.
Instructions:
1. Click and drag to draw a digit
2. Make sure the digit is clear and centered
3. The model will show the predicted digit and confidence scores
""",
    examples=_examples,
    theme=gr.themes.Soft(),
    allow_flagging="never",
)
def _serve():
    """Start the Gradio web server for ``iface``."""
    iface.launch()


# Launch the interface only when executed as a script, not on import.
if __name__ == "__main__":
    _serve()