mahmoudalrefaey commited on
Commit
3923696
·
verified ·
1 Parent(s): be1da47

Upload gradio_app.py

Browse files
Files changed (1) hide show
  1. interface/gradio_app.py +162 -0
interface/gradio_app.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Gradio interface for FoodViT
3
+ Provides a web interface for food classification
4
+ """
5
+
6
+ import gradio as gr
7
+ import sys
8
+ import os
9
+ from PIL import Image
10
+ import numpy as np
11
+ import random
12
+
13
+ # Add parent directory to path for imports
14
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
15
+
16
+ from config import GRADIO_CONFIG, CLASS_CONFIG
17
+ from utils.predictor import predictor
18
+
19
+ SAMPLES_DIR = "assets/samples"
20
+ def get_random_examples(n=3):
21
+ files = [os.path.join(SAMPLES_DIR, f) for f in os.listdir(SAMPLES_DIR)
22
+ if f.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".gif"))]
23
+ return [[f] for f in random.sample(files, min(n, len(files)))] if files else []
24
+
25
+ def classify_food(image):
26
+ """
27
+ Classify food in the uploaded image
28
+
29
+ Args:
30
+ image: PIL Image object from Gradio
31
+
32
+ Returns:
33
+ tuple: (predicted_class, confidence, detailed_results)
34
+ """
35
+ if image is None:
36
+ return "No image uploaded", 0.0, "Please upload an image to classify."
37
+
38
+ try:
39
+ # Make prediction
40
+ result = predictor.predict(image)
41
+
42
+ if not result.get("success", False):
43
+ return "Error", 0.0, f"Prediction failed: {result.get('error', 'Unknown error')}"
44
+
45
+ # Extract results
46
+ predicted_class = result["class"]
47
+ confidence = result["confidence"]
48
+
49
+ # Create detailed results string
50
+ detailed_results = f"**Predicted Class:** {predicted_class.title()}\n\n"
51
+ detailed_results += f"**Confidence:** {confidence:.2%}\n\n"
52
+ detailed_results += "**All Class Probabilities:**\n"
53
+
54
+ for class_name, prob in result["probabilities"].items():
55
+ detailed_results += f"- {class_name.title()}: {prob:.2%}\n"
56
+
57
+ return predicted_class.title(), confidence, detailed_results
58
+
59
+ except Exception as e:
60
+ return "Error", 0.0, f"An error occurred: {str(e)}"
61
+
62
+ def create_interface():
63
+ """Create and configure the Gradio interface"""
64
+
65
+ # Initialize predictor
66
+ if not predictor.initialize():
67
+ raise RuntimeError("Failed to initialize predictor")
68
+
69
+ # Create interface
70
+ with gr.Blocks(
71
+ title=GRADIO_CONFIG["title"],
72
+ theme=gr.themes.Soft()
73
+ ) as interface:
74
+
75
+ gr.Markdown(f"# {GRADIO_CONFIG['title']}")
76
+ gr.Markdown(GRADIO_CONFIG["description"])
77
+
78
+ with gr.Row():
79
+ with gr.Column(scale=1):
80
+ # Input section
81
+ gr.Markdown("## Upload Image")
82
+ input_image = gr.Image(
83
+ type="pil",
84
+ label="Upload a food image",
85
+ height=300
86
+ )
87
+
88
+ classify_btn = gr.Button(
89
+ "Classify Food",
90
+ variant="primary",
91
+ size="lg"
92
+ )
93
+
94
+ # Example images
95
+ gr.Markdown("## Example Images")
96
+ gr.Examples(
97
+ examples=get_random_examples(3),
98
+ inputs=input_image,
99
+ label="Try these examples"
100
+ )
101
+
102
+ with gr.Column(scale=1):
103
+ # Output section
104
+ gr.Markdown("## Results")
105
+
106
+ predicted_class = gr.Textbox(
107
+ label="Predicted Food Class",
108
+ interactive=False
109
+ )
110
+
111
+ confidence_score = gr.Slider(
112
+ minimum=0,
113
+ maximum=1,
114
+ value=0,
115
+ label="Confidence Score",
116
+ interactive=False
117
+ )
118
+
119
+ detailed_results = gr.Markdown(
120
+ label="Detailed Results",
121
+ value="Upload an image and click 'Classify Food' to see results."
122
+ )
123
+
124
+ # Model information
125
+ with gr.Accordion("Model Information", open=False):
126
+ model_info = predictor.get_model_info()
127
+ if "error" not in model_info:
128
+ info_text = f"""
129
+ **Device:** {model_info['device']}
130
+ **Total Parameters:** {model_info['total_parameters']:,}
131
+ **Number of Classes:** {model_info['num_classes']}
132
+ **Classes:** {', '.join(model_info['class_names'])}
133
+ """
134
+ else:
135
+ info_text = f"Error loading model info: {model_info['error']}"
136
+
137
+ gr.Markdown(info_text)
138
+
139
+ # Connect button to function
140
+ classify_btn.click(
141
+ fn=classify_food,
142
+ inputs=input_image,
143
+ outputs=[predicted_class, confidence_score, detailed_results]
144
+ )
145
+
146
+ # Auto-classify when image is uploaded
147
+ input_image.change(
148
+ fn=classify_food,
149
+ inputs=input_image,
150
+ outputs=[predicted_class, confidence_score, detailed_results]
151
+ )
152
+
153
+ return interface
154
+
155
+ def launch_interface():
156
+ """Launch the Gradio interface"""
157
+ interface = create_interface()
158
+ # Launch with default configuration for Hugging Face Spaces
159
+ interface.launch(ssr_mode=False)
160
+
161
+ if __name__ == "__main__":
162
+ launch_interface()