mahmoudalrefaey commited on
Commit
2417fa0
·
verified ·
1 Parent(s): e535ad1

Upload 7 files

Browse files
Files changed (7) hide show
  1. .gitignore +79 -0
  2. INSTALLATION.md +185 -0
  3. README.md +194 -0
  4. app.py +125 -0
  5. config.py +50 -0
  6. predict.py +210 -0
  7. requirements.txt +12 -0
.gitignore ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ env/
12
+ build/
13
+ develop-eggs/
14
+ dist/
15
+ downloads/
16
+ eggs/
17
+ .eggs/
18
+ lib/
19
+ lib64/
20
+ parts/
21
+ sdist/
22
+ var/
23
+ *.egg-info/
24
+ .installed.cfg
25
+ *.egg
26
+
27
+ # Virtual environments
28
+ .venv/
29
+ venv/
30
+ ENV/
31
+ env.bak/
32
+ foodvit_env/
33
+
34
+ # PyInstaller
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ .hypothesis/
53
+ .pytest_cache/
54
+
55
+ # Jupyter Notebook
56
+ .ipynb_checkpoints
57
+
58
+ # pyenv
59
+ .python-version
60
+
61
+ # mypy
62
+ .mypy_cache/
63
+ .dmypy.json
64
+
65
+ # VS Code
66
+ .vscode/
67
+
68
+ # Mac
69
+ .DS_Store
70
+
71
+ # Windows
72
+ Thumbs.db
73
+ Desktop.ini
74
+
75
+ # Model weights (optional: comment out if you want to track model files)
76
+ model/*.pth
77
+
78
+ # Sample images (optional: comment out if you want to track sample images)
79
+ assets/samples/*
INSTALLATION.md ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Installation Guide for FoodViT
2
+
3
+ ## Prerequisites
4
+
5
+ - Python 3.8 or higher
6
+ - pip package manager
7
+ - At least 4GB RAM (8GB recommended)
8
+ - GPU support optional but recommended for faster inference
9
+
10
+ ## Installation Steps
11
+
12
+ ### 1. Clone or Download the Project
13
+
14
+ Make sure you have all the project files in your directory:
15
+ - `app.py` - Main application
16
+ - `predict.py` - Command line tool
17
+ - `config.py` - Configuration
18
+ - `requirements.txt` - Dependencies
19
+ - `model/bestViT_PT.pth` - Trained model
20
+ - All utility and interface files
21
+
22
+ ### 2. Create a Virtual Environment (Recommended)
23
+
24
+ ```bash
25
+ # Create virtual environment
26
+ python -m venv foodvit_env
27
+
28
+ # Activate virtual environment
29
+ # On Windows:
30
+ foodvit_env\Scripts\activate
31
+ # On macOS/Linux:
32
+ source foodvit_env/bin/activate
33
+ ```
34
+
35
+ ### 3. Install Dependencies
36
+
37
+ ```bash
38
+ # Install PyTorch first (choose appropriate version for your system)
39
+ # For CPU only:
40
+ pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu
41
+
42
+ # For CUDA (if you have NVIDIA GPU):
43
+ # pip install torch torchvision --index-url https://download.pytorch.org/whl/cu118
44
+
45
+ # Install other dependencies
46
+ pip install -r requirements.txt
47
+ ```
48
+
49
+ ### 4. Troubleshooting Dependency Issues
50
+
51
+ If you encounter dependency conflicts, try this step-by-step approach:
52
+
53
+ ```bash
54
+ # 1. Install core dependencies first
55
+ pip install torch torchvision
56
+ pip install transformers==4.28.0
57
+ pip install huggingface-hub==0.15.1
58
+ pip install accelerate==0.20.3
59
+
60
+ # 2. Install image processing libraries
61
+ pip install Pillow opencv-python albumentations
62
+
63
+ # 3. Install Gradio
64
+ pip install gradio==3.35.2
65
+
66
+ # 4. Install other utilities
67
+ pip install numpy scikit-learn datasets
68
+ ```
69
+
70
+ ### 5. Alternative: Use Conda
71
+
72
+ If you prefer conda:
73
+
74
+ ```bash
75
+ # Create conda environment
76
+ conda create -n foodvit python=3.9
77
+ conda activate foodvit
78
+
79
+ # Install PyTorch
80
+ conda install pytorch torchvision -c pytorch
81
+
82
+ # Install other packages
83
+ pip install transformers==4.28.0 huggingface-hub==0.15.1
84
+ pip install gradio==3.35.2
85
+ pip install -r requirements.txt
86
+ ```
87
+
88
+ ## Testing the Installation
89
+
90
+ ### 1. Run Basic Tests
91
+
92
+ ```bash
93
+ python simple_test.py
94
+ ```
95
+
96
+ This should show all tests passing.
97
+
98
+ ### 2. Test the Web Interface
99
+
100
+ ```bash
101
+ python app.py
102
+ ```
103
+
104
+ Then open your browser to `http://localhost:7860`
105
+
106
+ ### 3. Test Command Line Tool
107
+
108
+ ```bash
109
+ # Test help
110
+ python predict.py --help
111
+
112
+ # Test with a sample image (if you have one)
113
+ python predict.py path/to/your/image.jpg
114
+ ```
115
+
116
+ ## Common Issues and Solutions
117
+
118
+ ### Issue: "cannot import name 'split_torch_state_dict_into_shards'"
119
+
120
+ **Solution**: This is a version compatibility issue. Try:
121
+
122
+ ```bash
123
+ pip uninstall huggingface-hub transformers accelerate
124
+ pip install huggingface-hub==0.15.1 transformers==4.28.0 accelerate==0.20.3
125
+ ```
126
+
127
+ ### Issue: CUDA/GPU not working
128
+
129
+ **Solution**:
130
+ 1. Check if you have NVIDIA GPU
131
+ 2. Install appropriate CUDA version
132
+ 3. Install PyTorch with CUDA support
133
+ 4. Or set device to 'cpu' in `config.py`
134
+
135
+ ### Issue: Model file not found
136
+
137
+ **Solution**: Ensure `model/bestViT_PT.pth` exists in the project directory.
138
+
139
+ ### Issue: Memory errors
140
+
141
+ **Solution**:
142
+ 1. Close other applications
143
+ 2. Use CPU instead of GPU
144
+ 3. Reduce batch size in configuration
145
+
146
+ ## System Requirements
147
+
148
+ ### Minimum Requirements
149
+ - Python 3.8+
150
+ - 4GB RAM
151
+ - 500MB disk space
152
+
153
+ ### Recommended Requirements
154
+ - Python 3.9+
155
+ - 8GB RAM
156
+ - NVIDIA GPU with CUDA support
157
+ - 1GB disk space
158
+
159
+ ## Verification
160
+
161
+ After successful installation, you should be able to:
162
+
163
+ 1. ✅ Run `python simple_test.py` without errors
164
+ 2. ✅ Start the web interface with `python app.py`
165
+ 3. ✅ Use command line tool with `python predict.py --help`
166
+ 4. ✅ Upload images and get predictions in the web interface
167
+
168
+ ## Getting Help
169
+
170
+ If you encounter issues:
171
+
172
+ 1. Check the error messages carefully
173
+ 2. Ensure all dependencies are installed correctly
174
+ 3. Try the troubleshooting steps above
175
+ 4. Check if your Python version is compatible
176
+ 5. Verify the model file exists and is not corrupted
177
+
178
+ ## Next Steps
179
+
180
+ Once installation is complete:
181
+
182
+ 1. **Web Interface**: Run `python app.py` and visit `http://localhost:7860`
183
+ 2. **Command Line**: Use `python predict.py` for batch processing
184
+ 3. **Customization**: Edit `config.py` to modify settings
185
+ 4. **Development**: Use the modular structure for extending functionality
README.md ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # FoodViT - Food Classification Application
2
+
3
+ A production-ready food classification application using Vision Transformer (ViT) that can classify images into three categories: **pizza**, **steak**, and **sushi**.
4
+
5
+ ## 🍕 Features
6
+
7
+ - **Web Interface**: Beautiful Gradio web interface for easy image upload and classification
8
+ - **Command Line Tool**: Batch prediction capabilities for processing multiple images
9
+ - **High Accuracy**: Trained Vision Transformer model with excellent performance
10
+ - **Production Ready**: Modular, well-structured codebase with proper error handling
11
+ - **Dynamic Example Images**: Example images are randomly selected from `assets/samples/` at each app launch
12
+ - **Easy Deployment**: Simple setup and configuration
13
+
14
+ ## 📁 Project Structure
15
+
16
+ ```
17
+ FoodViT/
18
+ ├── app.py # Main application entry point
19
+ ├── predict.py # Command-line prediction script
20
+ ├── config.py # Configuration settings
21
+ ├── requirements.txt # Python dependencies
22
+ ├── README.md # This file
23
+ ├── INSTALLATION.md # Installation and troubleshooting guide
24
+ ├── model/
25
+ │ └── bestViT_PT.pth # Trained PyTorch model
26
+ ├── utils/
27
+ │ ├── model_loader.py # Model loading utilities
28
+ │ ├── image_processor.py # Image preprocessing
29
+ │ └── predictor.py # Prediction logic
30
+ ├── interface/
31
+ │ └── gradio_app.py # Gradio web interface
32
+ └── assets/
33
+ └── samples/ # Example images for Gradio interface
34
+ ```
35
+
36
+ ## 🚀 Quick Start
37
+
38
+ ### 1. Installation
39
+
40
+ ```bash
41
+ # Clone the repository
42
+ git clone <repository-url>
43
+ cd FoodViT
44
+
45
+ # Install dependencies
46
+ pip install -r requirements.txt
47
+ ```
48
+
49
+ ### 2. Run the Web Interface
50
+
51
+ ```bash
52
+ # Start the Gradio web interface
53
+ python app.py
54
+ ```
55
+
56
+ The interface will be available at `http://localhost:7860`
57
+
58
+ ### 3. Command Line Usage
59
+
60
+ ```bash
61
+ # Predict a single image
62
+ python predict.py path/to/image.jpg
63
+
64
+ # Predict all images in a directory
65
+ python predict.py path/to/image/directory
66
+
67
+ # Get detailed prediction information
68
+ python predict.py path/to/image.jpg --detailed
69
+
70
+ # Save results to JSON file
71
+ python predict.py path/to/image/directory --output results.json
72
+ ```
73
+
74
+ ## 🎯 Usage Examples
75
+
76
+ ### Web Interface
77
+
78
+ 1. Open your browser and go to `http://localhost:7860`
79
+ 2. Upload an image of pizza, steak, or sushi
80
+ 3. View the prediction results with confidence scores
81
+ 4. Try the example images provided (randomly selected from `assets/samples/`)
82
+
83
+ ### Command Line
84
+
85
+ ```bash
86
+ # Single image prediction
87
+ python predict.py pizza.jpg
88
+ # Output: ✅ pizza.jpg: Pizza (95.23%)
89
+
90
+ # Batch prediction with details
91
+ python predict.py test_images/ --detailed --output results.json
92
+ ```
93
+
94
+ ## ⚙️ Configuration
95
+
96
+ Edit `config.py` to customize:
97
+
98
+ - **Model settings**: Model path, device, image size
99
+ - **Class configuration**: Class names and mappings
100
+ - **Gradio interface**: Title, description, theme
101
+ - **Application settings**: Host, port, debug mode
102
+
103
+ ## 🔧 Advanced Usage
104
+
105
+ ### Custom Model Loading
106
+
107
+ ```python
108
+ from utils.model_loader import ModelLoader
109
+
110
+ # Load custom model
111
+ loader = ModelLoader()
112
+ loader.load_model()
113
+ model = loader.get_model()
114
+ ```
115
+
116
+ ### Image Preprocessing
117
+
118
+ ```python
119
+ from utils.image_processor import ImageProcessor
120
+
121
+ # Preprocess custom image
122
+ processor = ImageProcessor()
123
+ tensor = processor.preprocess_image("path/to/image.jpg")
124
+ ```
125
+
126
+ ### Direct Prediction
127
+
128
+ ```python
129
+ from utils.predictor import FoodPredictor
130
+
131
+ # Initialize and predict
132
+ predictor = FoodPredictor()
133
+ predictor.initialize()
134
+ result = predictor.predict("path/to/image.jpg")
135
+ print(f"Predicted: {result['class']} ({result['confidence']:.2%})")
136
+ ```
137
+
138
+ ## 📊 Model Information
139
+
140
+ - **Architecture**: Vision Transformer (ViT-Base)
141
+ - **Input Size**: 224x224 pixels
142
+ - **Classes**: 3 (pizza, steak, sushi)
143
+ - **Training Data**: Pizza-Steak-Sushi dataset
144
+ - **Framework**: PyTorch with Transformers
145
+
146
+ ## 🛠️ Development
147
+
148
+ ### Project Structure
149
+
150
+ - **`utils/`**: Core utilities for model loading, image processing, and prediction
151
+ - **`interface/`**: Web interface components
152
+ - **`model/`**: Trained model files
153
+ - **`assets/samples/`**: Example images and static assets
154
+
155
+ ### Adding New Features
156
+
157
+ 1. **New Model**: Update `config.py` and `utils/model_loader.py`
158
+ 2. **New Classes**: Modify `config.py` CLASS_CONFIG
159
+ 3. **New Interface**: Create new files in `interface/`
160
+ 4. **New Utilities**: Add to `utils/` directory
161
+
162
+ ## 🧹 Project Cleanliness & GitHub Readiness
163
+
164
+ - All unnecessary files and caches have been removed
165
+ - Example images are dynamically loaded
166
+ - No test or debug files in the repo
167
+ - Ready for production and version control
168
+
169
+ ## 🐛 Troubleshooting
170
+
171
+ See `INSTALLATION.md` for detailed troubleshooting, dependency, and environment tips.
172
+
173
+ ## 📝 License
174
+
175
+ This project is licensed under the MIT License - see the LICENSE file for details.
176
+
177
+ ## 🤝 Contributing
178
+
179
+ 1. Fork the repository
180
+ 2. Create a feature branch
181
+ 3. Make your changes
182
+ 4. Add tests if applicable
183
+ 5. Submit a pull request
184
+
185
+ ## 📞 Support
186
+
187
+ For questions and support:
188
+ - Open an issue on GitHub
189
+ - Check the troubleshooting section
190
+ - Review the configuration options
191
+
192
+ ---
193
+
194
+ **Enjoy classifying your food images! 🍕🥩🍣**
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main application file for FoodViT
3
+ Entry point for the food classification application
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import argparse
9
+ from pathlib import Path
10
+
11
+ # Add current directory to path for imports
12
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
13
+
14
+ from config import APP_CONFIG
15
+ from interface.gradio_app import launch_interface
16
+ from utils.predictor import predictor
17
+
18
+ def check_dependencies():
19
+ """Check if all required dependencies are available"""
20
+ required_packages = [
21
+ 'torch',
22
+ 'transformers',
23
+ 'gradio',
24
+ 'PIL',
25
+ 'cv2',
26
+ 'albumentations',
27
+ 'numpy'
28
+ ]
29
+
30
+ missing_packages = []
31
+ for package in required_packages:
32
+ try:
33
+ __import__(package)
34
+ except ImportError:
35
+ missing_packages.append(package)
36
+
37
+ if missing_packages:
38
+ print(f"Missing required packages: {', '.join(missing_packages)}")
39
+ print("Please install them using: pip install -r requirements.txt")
40
+ return False
41
+
42
+ return True
43
+
44
+ def check_model_file():
45
+ """Check if the model file exists"""
46
+ model_path = Path("model/bestViT_PT.pth")
47
+ if not model_path.exists():
48
+ print(f"Model file not found: {model_path}")
49
+ print("Please ensure the trained model file is in the model/ directory")
50
+ return False
51
+ return True
52
+
53
+ def main():
54
+ """Main function to run the application"""
55
+
56
+ # Parse command line arguments
57
+ parser = argparse.ArgumentParser(description="FoodViT - Food Classification Application")
58
+ parser.add_argument(
59
+ "--port",
60
+ type=int,
61
+ default=APP_CONFIG["port"],
62
+ help="Port to run the server on"
63
+ )
64
+ parser.add_argument(
65
+ "--host",
66
+ type=str,
67
+ default=APP_CONFIG["host"],
68
+ help="Host to run the server on"
69
+ )
70
+ parser.add_argument(
71
+ "--share",
72
+ action="store_true",
73
+ help="Create a public link for the interface"
74
+ )
75
+ parser.add_argument(
76
+ "--debug",
77
+ action="store_true",
78
+ help="Enable debug mode"
79
+ )
80
+
81
+ args = parser.parse_args()
82
+
83
+ print("=" * 50)
84
+ print("FoodViT - Food Classification Application")
85
+ print("=" * 50)
86
+
87
+ # Check dependencies
88
+ print("Checking dependencies...")
89
+ if not check_dependencies():
90
+ sys.exit(1)
91
+ print("✓ All dependencies available")
92
+
93
+ # Check model file
94
+ print("Checking model file...")
95
+ if not check_model_file():
96
+ sys.exit(1)
97
+ print("✓ Model file found")
98
+
99
+ # Initialize predictor
100
+ print("Initializing model...")
101
+ if not predictor.initialize():
102
+ print("✗ Failed to initialize model")
103
+ sys.exit(1)
104
+ print("✓ Model initialized successfully")
105
+
106
+ # Get model info
107
+ model_info = predictor.get_model_info()
108
+ if "error" not in model_info:
109
+ print(f"✓ Model loaded on {model_info['device']}")
110
+ print(f"✓ Total parameters: {model_info['total_parameters']:,}")
111
+
112
+ print("\nStarting Gradio interface...")
113
+ print(f"Server will be available at: http://{args.host}:{args.port}")
114
+
115
+ try:
116
+ # Launch the interface
117
+ launch_interface()
118
+ except KeyboardInterrupt:
119
+ print("\nApplication stopped by user")
120
+ except Exception as e:
121
+ print(f"Error running application: {e}")
122
+ sys.exit(1)
123
+
124
+ if __name__ == "__main__":
125
+ main()
config.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Configuration file for FoodViT project
3
+ Contains all model and application settings
4
+ """
5
+
6
+ import os
7
+ import torch
8
+
9
+ # Model Configuration
10
+ MODEL_CONFIG = {
11
+ "model_path": "model/bestViT_PT.pth",
12
+ "feature_extractor_name": "google/vit-base-patch16-224",
13
+ "num_labels": 3,
14
+ "image_size": 224,
15
+ "device": "cuda" if torch.cuda.is_available() else "cpu"
16
+ }
17
+
18
+ # Class Configuration
19
+ CLASS_CONFIG = {
20
+ "class_names": ["pizza", "steak", "sushi"],
21
+ "id2label": {0: "pizza", 1: "steak", 2: "sushi"},
22
+ "label2id": {"pizza": 0, "steak": 1, "sushi": 2}
23
+ }
24
+
25
+ # Image Processing Configuration
26
+ IMAGE_CONFIG = {
27
+ "target_size": (224, 224),
28
+ "normalize_mean": [0.5, 0.5, 0.5],
29
+ "normalize_std": [0.5, 0.5, 0.5]
30
+ }
31
+
32
+ # Gradio Interface Configuration
33
+ GRADIO_CONFIG = {
34
+ "title": "FoodViT - Food Classification",
35
+ "description": "Upload an image to classify it as pizza, steak, or sushi",
36
+ "examples": [
37
+ ["assets/example_pizza.jpg"],
38
+ ["assets/example_steak.jpg"],
39
+ ["assets/example_sushi.jpg"]
40
+ ],
41
+ "theme": "default"
42
+ }
43
+
44
+ # Application Configuration
45
+ APP_CONFIG = {
46
+ "debug": False,
47
+ "host": "127.0.0.1",
48
+ "port": 7860,
49
+ "share": False
50
+ }
predict.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Command-line prediction script for FoodViT
3
+ Allows batch prediction and testing of the model
4
+ """
5
+
6
+ import os
7
+ import sys
8
+ import argparse
9
+ from pathlib import Path
10
+ from PIL import Image
11
+
12
+ # Add current directory to path for imports
13
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
14
+
15
+ from utils.predictor import predictor
16
+ from config import CLASS_CONFIG
17
+
18
+ def predict_single_image(image_path):
19
+ """
20
+ Predict food class for a single image
21
+
22
+ Args:
23
+ image_path: Path to the image file
24
+
25
+ Returns:
26
+ dict: Prediction results
27
+ """
28
+ try:
29
+ # Check if file exists
30
+ if not os.path.exists(image_path):
31
+ return {"error": f"Image file not found: {image_path}"}
32
+
33
+ # Load image
34
+ image = Image.open(image_path)
35
+
36
+ # Make prediction
37
+ result = predictor.predict(image)
38
+
39
+ return result
40
+
41
+ except Exception as e:
42
+ return {"error": f"Error processing {image_path}: {str(e)}"}
43
+
44
+ def predict_batch_images(image_dir):
45
+ """
46
+ Predict food classes for all images in a directory
47
+
48
+ Args:
49
+ image_dir: Directory containing images
50
+
51
+ Returns:
52
+ list: List of prediction results
53
+ """
54
+ results = []
55
+
56
+ # Supported image extensions
57
+ image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif'}
58
+
59
+ try:
60
+ # Get all image files in directory
61
+ image_files = [
62
+ f for f in os.listdir(image_dir)
63
+ if Path(f).suffix.lower() in image_extensions
64
+ ]
65
+
66
+ if not image_files:
67
+ print(f"No image files found in {image_dir}")
68
+ return results
69
+
70
+ print(f"Found {len(image_files)} image files")
71
+
72
+ # Process each image
73
+ for i, filename in enumerate(image_files, 1):
74
+ image_path = os.path.join(image_dir, filename)
75
+ print(f"Processing {i}/{len(image_files)}: {filename}")
76
+
77
+ result = predict_single_image(image_path)
78
+ result['filename'] = filename
79
+ results.append(result)
80
+
81
+ return results
82
+
83
+ except Exception as e:
84
+ print(f"Error processing directory {image_dir}: {str(e)}")
85
+ return results
86
+
87
+ def print_results(results, detailed=False):
88
+ """
89
+ Print prediction results in a formatted way
90
+
91
+ Args:
92
+ results: Single result dict or list of results
93
+ detailed: Whether to print detailed information
94
+ """
95
+ if isinstance(results, dict):
96
+ results = [results]
97
+
98
+ for result in results:
99
+ if "error" in result:
100
+ filename = result.get('filename', 'Unknown')
101
+ print(f"❌ {filename}: {result['error']}")
102
+ continue
103
+
104
+ if not result.get("success", False):
105
+ filename = result.get('filename', 'Unknown')
106
+ print(f"❌ {filename}: Prediction failed")
107
+ continue
108
+
109
+ # Extract information
110
+ filename = result.get('filename', 'Image')
111
+ predicted_class = result["class"]
112
+ confidence = result["confidence"]
113
+
114
+ # Print basic result
115
+ print(f"✅ {filename}: {predicted_class.title()} ({confidence:.2%})")
116
+
117
+ # Print detailed information if requested
118
+ if detailed:
119
+ print(f" Class ID: {result['class_id']}")
120
+ print(" All probabilities:")
121
+ for class_name, prob in result["probabilities"].items():
122
+ print(f" - {class_name.title()}: {prob:.2%}")
123
+ print()
124
+
125
+ def main():
126
+ """Main function for command-line prediction"""
127
+
128
+ parser = argparse.ArgumentParser(description="FoodViT - Command Line Prediction")
129
+ parser.add_argument(
130
+ "input",
131
+ help="Image file path or directory containing images"
132
+ )
133
+ parser.add_argument(
134
+ "--detailed",
135
+ action="store_true",
136
+ help="Show detailed prediction information"
137
+ )
138
+ parser.add_argument(
139
+ "--output",
140
+ type=str,
141
+ help="Output file to save results (JSON format)"
142
+ )
143
+
144
+ args = parser.parse_args()
145
+
146
+ print("FoodViT - Command Line Prediction")
147
+ print("=" * 40)
148
+
149
+ # Initialize predictor
150
+ print("Initializing model...")
151
+ if not predictor.initialize():
152
+ print("Failed to initialize model")
153
+ sys.exit(1)
154
+ print("✓ Model initialized successfully")
155
+
156
+ # Check if input is file or directory
157
+ input_path = Path(args.input)
158
+
159
+ if input_path.is_file():
160
+ # Single image prediction
161
+ print(f"Predicting single image: {args.input}")
162
+ result = predict_single_image(args.input)
163
+ print_results([result], args.detailed)
164
+ results = [result]
165
+
166
+ elif input_path.is_dir():
167
+ # Batch prediction
168
+ print(f"Predicting images in directory: {args.input}")
169
+ results = predict_batch_images(args.input)
170
+ print_results(results, args.detailed)
171
+
172
+ else:
173
+ print(f"Error: {args.input} is not a valid file or directory")
174
+ sys.exit(1)
175
+
176
+ # Save results if output file specified
177
+ if args.output and results:
178
+ try:
179
+ import json
180
+ # Convert numpy types to native Python types for JSON serialization
181
+ json_results = []
182
+ for result in results:
183
+ json_result = {}
184
+ for key, value in result.items():
185
+ if key == 'probabilities':
186
+ json_result[key] = {k: float(v) for k, v in value.items()}
187
+ elif isinstance(value, (int, float, str, bool)):
188
+ json_result[key] = value
189
+ else:
190
+ json_result[key] = str(value)
191
+ json_results.append(json_result)
192
+
193
+ with open(args.output, 'w') as f:
194
+ json.dump(json_results, f, indent=2)
195
+ print(f"Results saved to: {args.output}")
196
+
197
+ except Exception as e:
198
+ print(f"Error saving results: {e}")
199
+
200
+ # Print summary
201
+ successful_predictions = [r for r in results if r.get("success", False)]
202
+ failed_predictions = len(results) - len(successful_predictions)
203
+
204
+ print(f"\nSummary:")
205
+ print(f"Total images: {len(results)}")
206
+ print(f"Successful predictions: {len(successful_predictions)}")
207
+ print(f"Failed predictions: {failed_predictions}")
208
+
209
+ if __name__ == "__main__":
210
+ main()
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=1.9.0,<2.0.0
2
+ torchvision>=0.10.0,<0.15.0
3
+ transformers>=4.20.0,<4.30.0
4
+ gradio>=3.50.0,<4.0.0
5
+ Pillow>=8.0.0,<10.0.0
6
+ opencv-python>=4.5.0,<4.8.0
7
+ albumentations>=1.3.0,<1.4.0
8
+ numpy>=1.21.0,<1.25.0
9
+ scikit-learn>=1.0.0,<1.3.0
10
+ datasets>=2.0.0,<2.14.0
11
+ accelerate>=0.20.0,<0.21.0
12
+ huggingface-hub>=0.15.0,<0.16.0