elhamb commited on
Commit
848ad3f
·
verified ·
1 Parent(s): ea08835

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -11
app.py CHANGED
@@ -1,11 +1,92 @@
1
- from transformers import AutoModelForQuestionAnswering, AutoTokenizer, pipeline
2
- import gradio as grad
3
- import ast
4
- mdl_name = "deepset/roberta-base-squad2"
5
- my_pipeline = pipeline('question-answering', model=mdl_name, tokenizer=mdl_name)
6
- def answer_question(question,context):
7
- text= "{"+"'question': '"+question+"','context': '"+context+"'}"
8
- di=ast.literal_eval(text)
9
- response = my_pipeline(di)
10
- return response
11
- grad.Interface(answer_question, inputs=["text","text"], outputs="text").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import tensorflow as tf
3
+ from tensorflow import keras
4
+ import numpy as np
5
+ from PIL import Image
6
+
7
+ # --- Configuration ---
8
+ MODEL_PATH = "cats-vs-dogs-finetuned.keras"
9
+ IMAGE_SIZE = (150, 150) # Adjust this to match the input size your model expects!
10
+ CLASS_LABELS = ['Cat', 'Dog']
11
+
12
+ # --- Load the Model ---
13
+ # We load the Keras model. Hugging Face Spaces will automatically find this file
14
+ # if you upload it to your repository.
15
+ try:
16
+ model = keras.models.load_model(MODEL_PATH)
17
+ print(f"Model loaded successfully from {MODEL_PATH}")
18
+ except Exception as e:
19
+ # If the model fails to load (e.g., during initial setup before it's uploaded),
20
+ # we use a placeholder function. This helps the app start.
21
+ print(f"Error loading model: {e}. Using a placeholder function.")
22
+ model = None
23
+
24
+ # --- Prediction Function ---
25
+ def predict_image(input_img_pil):
26
+ """
27
+ Predicts the class (Cat or Dog) given a PIL Image object.
28
+
29
+ Args:
30
+ input_img_pil: A PIL Image object received from Gradio's Image input.
31
+
32
+ Returns:
33
+ A dictionary of class labels and their probabilities (for Gradio's Label output).
34
+ """
35
+ if model is None:
36
+ # Placeholder behavior if model loading failed
37
+ return {"Error": 1.0}
38
+
39
+ # 1. Preprocessing: Resize and convert to NumPy array
40
+ img_resized = input_img_pil.resize(IMAGE_SIZE)
41
+ img_array = keras.preprocessing.image.img_to_array(img_resized)
42
+
43
+ # 2. Rescaling and Batch dimension:
44
+ # Keras models usually expect input shapes like (Batch_Size, Height, Width, Channels)
45
+ # and often expect pixel values to be normalized (e.g., 0-1 range).
46
+ # Please adjust the normalization based on how your model was trained!
47
+ img_array = img_array / 255.0 # Common normalization step
48
+ img_array = np.expand_dims(img_array, axis=0) # Add batch dimension
49
+
50
+ # 3. Prediction
51
+ predictions = model.predict(img_array)[0] # Get the single prediction result
52
+
53
+ # 4. Format the output for Gradio's Label component
54
+ # The output is expected to be a dictionary: {'label': probability, ...}
55
+
56
+ # Assuming predictions is a 2-element array: [prob_cat, prob_dog]
57
+ output_dict = {
58
+ CLASS_LABELS[0]: float(predictions[0]),
59
+ CLASS_LABELS[1]: float(predictions[1])
60
+ }
61
+
62
+ return output_dict
63
+
64
+
65
+ # --- Gradio Interface Setup ---
66
+
67
+ # Define the input component (Image) and output component (Label)
68
+ image_input = gr.Image(type="pil", label="Upload a Cat or Dog Image")
69
+ label_output = gr.Label(num_top_classes=2, label="Prediction")
70
+
71
+ # Example images for users to try (place these in your Space if you use them)
72
+ examples = [
73
+ # To use these, you would need to upload files named 'example_cat.jpg' and 'example_dog.jpg'
74
+ # 'example_cat.jpg',
75
+ # 'example_dog.jpg'
76
+ ]
77
+
78
+ # Create the Gradio interface
79
+ demo = gr.Interface(
80
+ fn=predict_image,
81
+ inputs=image_input,
82
+ outputs=label_output,
83
+ title="Keras Cat vs Dog Classifier",
84
+ description="Upload an image of a cat or dog to see the model's prediction. The model is loaded from cat-vs-dog.keras.",
85
+ theme=gr.themes.Soft(),
86
+ # Optional: Add examples if you upload them
87
+ # examples=examples
88
+ )
89
+
90
+ # Launch the app
91
+ if __name__ == "__main__":
92
+ demo.launch()