File size: 7,921 Bytes
eda3213
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
import html
import logging
import os
from typing import Dict, List, Optional

import gradio as gr
import requests
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification

# Logging setup: root handler at INFO level, plus a logger named after this module.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)

class NutritionAnalyzer:
    """Main class handling food recognition and nutrition analysis.

    Wraps the pretrained ViT image classifier named in ``model_name`` and
    carries the nutrition API key read from the environment.
    """

    def __init__(self):
        self.model_name = "google/vit-base-patch16-224"
        # Both are populated by initialize_models(); None until then.
        self.feature_extractor = None
        self.model = None
        # NOTE(review): the fallback is a placeholder, not a working key; set
        # NUTRITION_API_KEY in the environment for real API calls.
        self.api_key = os.getenv('NUTRITION_API_KEY', 'Your API Key')

    def initialize_models(self):
        """Load the ViT feature extractor and image-classification model.

        Raises:
            Exception: Whatever ``from_pretrained`` raises; it is logged
                (with traceback) and re-raised.
        """
        try:
            self.feature_extractor = ViTFeatureExtractor.from_pretrained(self.model_name)
            self.model = ViTForImageClassification.from_pretrained(self.model_name)
            logger.info("Models initialized successfully")
        except Exception as e:
            logger.exception(f"Model initialization failed: {str(e)}")
            raise

    def identify_food(self, image_path: str) -> str:
        """Identify the food item in an image using the Vision Transformer.

        Args:
            image_path: Path to the input image file.

        Returns:
            str: The top-1 class label (first comma-separated entry).

        Raises:
            RuntimeError: If ``initialize_models()`` has not been called.
            Exception: Any image-loading or inference error (logged, re-raised).
        """
        # Fail fast with a clear message instead of "'NoneType' is not callable".
        if self.model is None or self.feature_extractor is None:
            raise RuntimeError(
                "Models are not loaded; call initialize_models() first"
            )
        try:
            # Local import: only inference needs it, and return_tensors="pt"
            # already requires PyTorch to be installed.
            import torch

            # The context manager closes the file handle. convert("RGB") yields
            # an independent 3-channel copy, so RGBA/grayscale/palette uploads
            # don't break preprocessing.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            inputs = self.feature_extractor(images=image, return_tensors="pt")
            # Inference only: skip autograd bookkeeping to save memory and time.
            with torch.no_grad():
                outputs = self.model(**inputs)
            label = self.model.config.id2label[outputs.logits.argmax(-1).item()]
            # Labels may list comma-separated synonyms; keep the first one.
            return label.split(',')[0]
        except Exception as e:
            logger.exception(f"Food identification error: {str(e)}")
            raise

class NutritionAPIHandler:
    """Handles nutrition data retrieval from the API Ninjas nutrition endpoint."""

    BASE_URL = "https://api.api-ninjas.com/v1/nutrition"
    # Seconds to wait on connect/read. Without a timeout, a stalled server
    # would block the request (and the UI worker calling it) indefinitely.
    DEFAULT_TIMEOUT = 10.0

    def __init__(self, api_key: str, timeout: float = DEFAULT_TIMEOUT):
        """Create a session that sends the API key on every request.

        Args:
            api_key: Value sent in the ``X-Api-Key`` header.
            timeout: Per-request timeout in seconds passed to ``requests``.
        """
        self.api_key = api_key
        self.timeout = timeout
        self.session = requests.Session()
        self.session.headers.update({'X-Api-Key': self.api_key})

    def get_nutrition_data(self, food_name: str) -> Optional[Dict]:
        """Fetch nutrition data from the API.

        Args:
            food_name: Name of the food item to query.

        Returns:
            Optional[Dict]: The first matched item, or None when the request
            fails, the body is not valid JSON, or nothing matched.
        """
        try:
            response = self.session.get(
                self.BASE_URL, params={'query': food_name}, timeout=self.timeout
            )
            response.raise_for_status()
            # Parse once; ValueError covers JSON decode errors across
            # requests versions.
            payload = response.json()
        except (requests.RequestException, ValueError) as e:
            logger.error(f"API Error: {str(e)}")
            return None

        # The endpoint is expected to return a JSON list of matched items.
        if not isinstance(payload, list):
            logger.error(
                f"API Error: unexpected response type {type(payload).__name__}"
            )
            return None
        if not payload:
            return None
        first = payload[0]
        return first if isinstance(first, dict) else None

class NutritionFormatter:
    """Formats nutrition data into HTML fragments for display."""

    # (row label, data key, daily-value reference amount in the label's unit)
    _DAILY_VALUES = (
        ('Saturated Fat (g)', 'fat_saturated_g', 20),
        ('Fiber (g)', 'fiber_g', 25),
        ('Sugar (g)', 'sugar_g', 50),
        ('Sodium (mg)', 'sodium_mg', 2300),
        ('Potassium (mg)', 'potassium_mg', 4700),
        ('Cholesterol (mg)', 'cholesterol_mg', 300),
    )

    @staticmethod
    def _display(value) -> str:
        """Render a value as HTML-safe text; None becomes 'N/A'."""
        return "N/A" if value is None else html.escape(str(value))

    @staticmethod
    def create_nutrition_table(data: Dict) -> str:
        """Generate an HTML summary (macro tiles + nutrient table).

        All values come from an external API, so they are HTML-escaped
        before interpolation.

        Args:
            data: Dictionary containing nutrition data.

        Returns:
            str: Formatted HTML, or a placeholder paragraph if ``data`` is empty.
        """
        if not data:
            return "<p>No nutrition data available</p>"

        fmt = NutritionFormatter._display
        name = html.escape(str(data.get('name') or 'Unknown Food'))
        return f"""
        <div class="nutrition-container">
            <h3>Nutrition Facts for {name}</h3>
            <div class="macros">
                <div class="macro calories">
                    <h4>Calories</h4>
                    <div class="value">{fmt(data.get('calories', 0))}</div>
                </div>
                <div class="macro protein">
                    <h4>Protein (g)</h4>
                    <div class="value">{fmt(data.get('protein_g', 0))}</div>
                </div>
                <div class="macro carbs">
                    <h4>Carbs (g)</h4>
                    <div class="value">{fmt(data.get('carbohydrates_total_g', 0))}</div>
                </div>
                <div class="macro fat">
                    <h4>Fat (g)</h4>
                    <div class="value">{fmt(data.get('fat_total_g', 0))}</div>
                </div>
            </div>
            <table>
                <tr><th>Nutrient</th><th>Amount</th><th>Daily Value%</th></tr>
                {NutritionFormatter._create_table_rows(data)}
            </table>
        </div>
        """

    @staticmethod
    def _create_table_rows(data: Dict) -> str:
        """Build one ``<tr>`` per tracked nutrient with its % daily value.

        Missing keys count as 0. Non-numeric values (e.g. a text notice or
        None) are shown as-is with 'N/A' instead of raising TypeError.
        """
        rows = []
        for name, key, dv in NutritionFormatter._DAILY_VALUES:
            value = data.get(key, 0)
            # bool is an int subclass but not a quantity.
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                dv_percent = (value / dv) * 100 if dv else 0
                dv_text = f"{dv_percent:.1f}%"
            else:
                dv_text = "N/A"
            rows.append(
                f"<tr><td>{name}</td><td>{NutritionFormatter._display(value)}</td>"
                f"<td>{dv_text}</td></tr>"
            )
        return ''.join(rows)

class NutritionAnalyzerApp:
    """Gradio application wiring food recognition to nutrition lookup."""

    # Optional example images, relative to the current working directory.
    EXAMPLE_PATHS = (
        "examples/apple.jpg",
        "examples/pizza.jpg",
        "examples/salad.jpg",
    )

    def __init__(self):
        self.analyzer = NutritionAnalyzer()
        self.analyzer.initialize_models()
        self.api_handler = NutritionAPIHandler(self.analyzer.api_key)
        self.example_images = self._load_example_images()

    def _load_example_images(self):
        """Return existing example paths as ``[[path], ...]`` rows, or None.

        Each row holds one value per UI input (here a single image input).
        Missing files are logged and skipped.
        """
        valid_examples = []
        for path in self.EXAMPLE_PATHS:
            if os.path.exists(path):
                valid_examples.append([path])
            else:
                logger.warning(f"Example image not found: {path}")
        return valid_examples or None

    def analyze_image(self, image_path: Optional[str]) -> str:
        """Full processing pipeline for image analysis.

        Args:
            image_path: Path to the input image file. This is None or empty
                when the button is clicked without an upload.

        Returns:
            str: HTML with nutrition information, or an HTML error message.
        """
        # Give a specific hint instead of running the model on a missing file.
        if not image_path:
            return "<p class='error'>Please upload an image first.</p>"
        try:
            food_item = self.analyzer.identify_food(image_path)
            nutrition_data = self.api_handler.get_nutrition_data(food_item)
            return self._handle_output(food_item, nutrition_data)
        except Exception as e:
            # UI boundary: log with traceback, show a generic message.
            logger.exception(f"Processing error: {str(e)}")
            return "<p class='error'>Error processing request. Please try again.</p>"

    def _handle_output(self, food_item: str, data: Optional[Dict]) -> str:
        """Render nutrition data, or a not-found message for ``food_item``."""
        if not data:
            return f"<p>No nutrition data found for {html.escape(food_item)}</p>"
        return NutritionFormatter.create_nutrition_table(data)

# Initialize and run the Gradio app. Everything, including the UI
# definition, stays under the guard so importing this module has no side
# effects (and no NameError on `css`/`app`).
if __name__ == "__main__":
    app = NutritionAnalyzerApp()
    
    css = """
    .nutrition-container { max-width: 600px; margin: 20px auto; }
    .macros { display: grid; grid-template-columns: repeat(4, 1fr); gap: 10px; margin: 20px 0; }
    .macro { padding: 15px; border-radius: 8px; text-align: center; }
    .calories { background: #ffd70033; }
    .protein { background: #00ff0033; }
    .carbs { background: #0000ff33; }
    .fat { background: #ff000033; }
    .value { font-size: 1.5em; font-weight: bold; }
    table { width: 100%; margin-top: 20px; }
    th, td { padding: 10px; text-align: left; }
    .error { color: red; font-weight: bold; }
    """

    with gr.Blocks(css=css) as demo:
        gr.Markdown("# 🍎 AI Nutrition Analyzer")
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="filepath", label="Upload Food Image")
                submit_btn = gr.Button("Analyze Nutrition")
                # Example images were loaded but never shown; expose them when present.
                if app.example_images:
                    gr.Examples(examples=app.example_images, inputs=image_input)
            with gr.Column():
                output = gr.HTML(label="Nutrition Analysis")

        submit_btn.click(
            fn=app.analyze_image,
            inputs=image_input,
            outputs=output
        )

    demo.launch()