codewithharsha commited on
Commit
55d489d
·
verified ·
1 Parent(s): b1406d1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -0
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ import numpy as np
4
+ from PIL import Image
5
+ import os
6
+ import io
7
+
8
+ # --- Configuration ---
9
+ PHOTO_SIZE = 224
10
+ MODEL_FILE_NAME = 'vgg_model50.h5' # Make sure this matches your uploaded model file
11
+ CLASS_NAMES = ["Non-Autstic", "Autstic"] # Match indices 0 and 1 from your notebook
12
+
13
+ # --- Load the Keras model ---
14
+ # Ensure the model file is in the same directory in the Space repository
15
+ model_path = os.path.join(os.path.dirname(__file__), MODEL_FILE_NAME)
16
+ if not os.path.exists(model_path):
17
+ # Basic error handling if model is missing
18
+ raise FileNotFoundError(f"Model file '{MODEL_FILE_NAME}' not found. Please upload it to the Space.")
19
+
20
+ try:
21
+ model = tf.keras.models.load_model(model_path)
22
+ print("Model loaded successfully.")
23
+ except Exception as e:
24
+ print(f"Error loading model: {e}")
25
+ # You might want more robust error handling or display in the Gradio interface
26
+ model = None
27
+
28
+ def preprocess_image(pil_image):
29
+ """
30
+ Preprocesses the PIL image object for the VGG16 model.
31
+ Resizes to (PHOTO_SIZE, PHOTO_SIZE), normalizes to [0, 1].
32
+ """
33
+ try:
34
+ # Gradio provides a PIL image object directly
35
+ img = pil_image.convert('RGB') # Ensure image is RGB
36
+ img = img.resize((PHOTO_SIZE, PHOTO_SIZE))
37
+ np_image = np.array(img).astype('float32') / 255.0
38
+ # Expand dimensions to create batch size of 1
39
+ np_image = np.expand_dims(np_image, axis=0)
40
+ print(f"Image preprocessed successfully. Shape: {np_image.shape}")
41
+ return np_image
42
+ except Exception as e:
43
+ print(f"Error preprocessing image: {e}")
44
+ return None
45
+
46
+ def predict_autism(image_input):
47
+ """
48
+ Takes a PIL Image input from Gradio, preprocesses, predicts, and returns the class name.
49
+ """
50
+ if model is None:
51
+ return "Error: Model not loaded." # Or raise an error
52
+
53
+ print(f"Received image of type: {type(image_input)}") # Should be PIL Image
54
+
55
+ # Preprocess the image (Gradio image input provides PIL image)
56
+ processed_image = preprocess_image(image_input)
57
+ if processed_image is None:
58
+ return "Error: Image preprocessing failed."
59
+
60
+ # Make prediction
61
+ print("Making prediction...")
62
+ prediction = model.predict(processed_image)
63
+ predicted_class_index = np.argmax(prediction, axis=1)[0]
64
+ predicted_class_name = CLASS_NAMES[predicted_class_index]
65
+ confidence = float(np.max(prediction)) # Get the confidence score
66
+
67
+ print(f"Prediction result index: {predicted_class_index}, Class: {predicted_class_name}, Confidence: {confidence:.4f}")
68
+
69
+ # Return prediction as a dictionary (Gradio handles JSON conversion for API)
70
+ # The key 'label' often works well with Gradio output components
71
+ # Or return just the string if using a simple Textbox output
72
+ # return {CLASS_NAMES[0]: float(1-confidence), CLASS_NAMES[1]: confidence} # Example for Label output
73
+ return predicted_class_name # Simpler for Textbox output
74
+
75
+ # --- Create Gradio Interface ---
76
+ # Input: Image Upload
77
+ # Output: Textbox to display the predicted class
78
+ # Allow flagging for feedback (optional but good practice)
79
+ # Add title and description
80
+ # Provide example images if available in your Space repo (e.g., in an 'examples' folder)
81
+ # examples_folder = os.path.join(os.path.dirname(__file__), "examples")
82
+ # example_images = [os.path.join(examples_folder, img) for img in os.listdir(examples_folder)] if os.path.exists(examples_folder) else None
83
+
84
+ iface = gr.Interface(
85
+ fn=predict_autism,
86
+ inputs=gr.Image(type="pil", label="Upload Image"), # Input is PIL format
87
+ outputs=gr.Textbox(label="Prediction Result"), # Output is simple text
88
+ # outputs=gr.Label(num_top_classes=2), # Alternative: Label output shows confidences
89
+ title="Autism Classification from Facial Images (VGG16)",
90
+ description="Upload a facial image to classify as Autistic or Non-Autistic using a VGG16 model.",
91
+ allow_flagging="never",
92
+ # examples=example_images # Uncomment if you add example images
93
+ )
94
+
95
+ # --- Launch the Gradio app ---
96
+ # When run on Hugging Face Spaces, it automatically uses the Space's URL
97
+ if __name__ == "__main__":
98
+ iface.launch() # share=True is not needed on Spaces