Spaces:
Sleeping
Sleeping
import html
import logging
import os
from typing import Dict, List, Optional

import gradio as gr
import requests
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification
# Module-level logger shared by every class in this file.
logger = logging.getLogger(__name__)
# Configure the root logger to emit INFO and above (basicConfig does nothing
# if the root logger already has handlers attached).
logging.basicConfig(level=logging.INFO)
class NutritionAnalyzer:
    """Identify the food shown in an image with a pretrained ViT image classifier.

    The preprocessor and model are loaded by ``initialize_models``;
    ``identify_food`` loads them on first use if that has not happened yet.
    """

    # Fallback used when NUTRITION_API_KEY is unset; it is not a real key.
    _PLACEHOLDER_API_KEY = 'Your API Key'

    def __init__(self):
        self.model_name = "google/vit-base-patch16-224"
        self.feature_extractor = None  # set by initialize_models()
        self.model = None  # set by initialize_models()
        self.api_key = os.getenv('NUTRITION_API_KEY', self._PLACEHOLDER_API_KEY)
        if self.api_key == self._PLACEHOLDER_API_KEY:
            # Without a real key every nutrition lookup will most likely fail,
            # which otherwise only shows up as "No nutrition data found".
            logger.warning(
                "NUTRITION_API_KEY is not set; using a placeholder API key, "
                "nutrition lookups will likely fail"
            )

    def initialize_models(self):
        """Load the ViT feature extractor and classification model for ``model_name``.

        Raises:
            Exception: whatever ``from_pretrained`` raises; it is logged with
                its traceback and re-raised.
        """
        try:
            self.feature_extractor = ViTFeatureExtractor.from_pretrained(self.model_name)
            self.model = ViTForImageClassification.from_pretrained(self.model_name)
            logger.info("Models initialized successfully")
        except Exception as e:
            logger.exception(f"Model initialization failed: {e}")
            raise

    def identify_food(self, image_path: str) -> str:
        """Identify the food item in an image using the Vision Transformer.

        Args:
            image_path: Path to the input image file.

        Returns:
            str: The first name of the top-scoring class label.

        Raises:
            Exception: any error from loading the image or running the model;
                it is logged with its traceback and re-raised.
        """
        # Lazy-load so calling this before initialize_models() does not fail
        # with "'NoneType' object is not callable".
        if self.model is None or self.feature_extractor is None:
            self.initialize_models()
        try:
            import torch  # needed for no_grad; return_tensors="pt" already requires it

            # Close the file promptly and normalise to 3-channel RGB: alpha,
            # palette or grayscale images may otherwise be rejected by the
            # preprocessor.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            inputs = self.feature_extractor(images=image, return_tensors="pt")
            with torch.no_grad():  # inference only: skip autograd bookkeeping
                outputs = self.model(**inputs)
            label_id = outputs.logits.argmax(-1).item()
            # Labels may list comma-separated synonyms; keep only the first.
            return self.model.config.id2label[label_id].split(',')[0].strip()
        except Exception as e:
            logger.exception(f"Food identification error: {e}")
            raise
class NutritionAPIHandler:
    """Fetch nutrition data for a food name from the API Ninjas nutrition endpoint."""

    BASE_URL = "https://api.api-ninjas.com/v1/nutrition"

    def __init__(self, api_key: str, timeout: float = 10.0):
        """
        Args:
            api_key: Sent as the ``X-Api-Key`` header on every request.
            timeout: Seconds to wait for the API. Without a timeout,
                ``requests`` can block indefinitely on an unresponsive server.
        """
        self.api_key = api_key
        self.timeout = timeout
        self.session = requests.Session()
        self.session.headers.update({'X-Api-Key': self.api_key})

    def get_nutrition_data(self, food_name: str) -> Optional[Dict]:
        """Fetch nutrition data from the API.

        Args:
            food_name: Name of the food item to query.

        Returns:
            Optional[Dict]: The first item of the JSON list response, or None
            if the name is empty, the request or JSON decoding fails, or the
            response is not a non-empty list of dicts.
        """
        if not food_name:
            return None  # nothing to look up; avoid a pointless request
        try:
            response = self.session.get(
                self.BASE_URL, params={'query': food_name}, timeout=self.timeout
            )
            response.raise_for_status()
            payload = response.json()  # parse once
        except (requests.RequestException, ValueError) as e:
            # ValueError covers invalid JSON bodies.
            logger.error(f"API Error: {e}")
            return None
        if not isinstance(payload, list):
            # e.g. an error object instead of a list of items
            logger.warning(f"Unexpected API response for {food_name!r}: {payload!r}")
            return None
        first = payload[0] if payload else None
        return first if isinstance(first, dict) else None
class NutritionFormatter:
    """Format nutrition data dicts into HTML fragments.

    All methods are static; the class only groups them. The data comes from
    an external API, so every value is HTML-escaped before it is embedded.
    """

    # (row label, key in the nutrition dict, daily reference amount in the
    # label's unit) used to compute the "Daily Value%" column.
    DAILY_VALUES = (
        ('Saturated Fat (g)', 'fat_saturated_g', 20),
        ('Fiber (g)', 'fiber_g', 25),
        ('Sugar (g)', 'sugar_g', 50),
        ('Sodium (mg)', 'sodium_mg', 2300),
        ('Potassium (mg)', 'potassium_mg', 4700),
        ('Cholesterol (mg)', 'cholesterol_mg', 300),
    )

    @staticmethod
    def _escape(value) -> str:
        """Return ``str(value)`` escaped for safe inclusion in HTML."""
        return html.escape(str(value))

    @staticmethod
    def create_nutrition_table(data: Optional[Dict]) -> str:
        """Generate an HTML summary with macro tiles and a nutrient table.

        Args:
            data: Dictionary containing nutrition data (may be None or empty).

        Returns:
            str: Formatted HTML, or a short "no data" paragraph.
        """
        if not data:
            return "<p>No nutrition data available</p>"
        esc = NutritionFormatter._escape
        return f"""
<div class="nutrition-container">
    <h3>Nutrition Facts for {esc(data.get('name', 'Unknown Food'))}</h3>
    <div class="macros">
        <div class="macro calories">
            <h4>Calories</h4>
            <div class="value">{esc(data.get('calories', 0))}</div>
        </div>
        <div class="macro protein">
            <h4>Protein (g)</h4>
            <div class="value">{esc(data.get('protein_g', 0))}</div>
        </div>
        <div class="macro carbs">
            <h4>Carbs (g)</h4>
            <div class="value">{esc(data.get('carbohydrates_total_g', 0))}</div>
        </div>
        <div class="macro fat">
            <h4>Fat (g)</h4>
            <div class="value">{esc(data.get('fat_total_g', 0))}</div>
        </div>
    </div>
    <table>
        <tr><th>Nutrient</th><th>Amount</th><th>Daily Value%</th></tr>
        {NutritionFormatter._create_table_rows(data)}
    </table>
</div>
"""

    @staticmethod
    def _create_table_rows(data: Dict) -> str:
        """Return ``<tr>`` rows (amount and %DV) for each entry in DAILY_VALUES.

        Missing keys count as 0. Values that are not numbers (None, strings)
        are shown as-is, or as "N/A" for None, with "N/A" for %DV instead of
        raising a TypeError.
        """
        rows = []
        for label, key, daily_value in NutritionFormatter.DAILY_VALUES:
            value = data.get(key, 0)
            is_number = isinstance(value, (int, float)) and not isinstance(value, bool)
            if is_number and daily_value:
                percent = f"{(value / daily_value) * 100:.1f}%"
            else:
                percent = "N/A"
            shown = "N/A" if value is None else NutritionFormatter._escape(value)
            rows.append(
                f"<tr><td>{label}</td><td>{shown}</td>"
                f"<td>{percent}</td></tr>"
            )
        return ''.join(rows)
class NutritionAnalyzerApp:
    """Gradio application for nutrition analysis.

    Chains the image classifier (food name) with the nutrition API lookup
    and renders the result as HTML.
    """

    def __init__(self):
        self.analyzer = NutritionAnalyzer()
        self.analyzer.initialize_models()
        self.api_handler = NutritionAPIHandler(self.analyzer.api_key)
        self.example_images = self._load_example_images()

    def _load_example_images(self) -> Optional[List[List[str]]]:
        """Return the bundled example images that exist on disk.

        Paths are relative to the current working directory. Each path is
        wrapped in its own list (one value per input component). Missing files
        are logged as warnings. Returns None when no example exists.
        """
        example_paths = [
            "examples/apple.jpg",
            "examples/pizza.jpg",
            "examples/salad.jpg",
        ]
        valid_examples = []
        for path in example_paths:
            if os.path.exists(path):
                valid_examples.append([path])
            else:
                logger.warning(f"Example image not found: {path}")
        return valid_examples if valid_examples else None

    def analyze_image(self, image_path: Optional[str]) -> str:
        """Full processing pipeline for image analysis.

        Args:
            image_path: Path to the input image file; Gradio passes None when
                nothing has been uploaded.

        Returns:
            str: Formatted nutrition information, or an HTML error message.
        """
        if not image_path:
            # A missing upload is a user mistake, not a processing failure.
            return "<p class='error'>Please upload a food image first.</p>"
        try:
            food_item = self.analyzer.identify_food(image_path)
            nutrition_data = self.api_handler.get_nutrition_data(food_item)
            return self._handle_output(food_item, nutrition_data)
        except Exception as e:
            # UI boundary: log the full traceback, show a generic message.
            logger.exception(f"Processing error: {e}")
            return "<p class='error'>Error processing request. Please try again.</p>"

    def _handle_output(self, food_item: str, data: Optional[Dict]) -> str:
        """Render nutrition data as HTML, or a "no data" message naming the food."""
        if not data:
            return f"<p>No nutrition data found for {html.escape(str(food_item))}</p>"
        return NutritionFormatter.create_nutrition_table(data)
# Initialize and run Gradio app
if __name__ == "__main__":
    app = NutritionAnalyzerApp()
    css = """
    .nutrition-container { max-width: 600px; margin: 20px auto; }
    .macros { display: grid; grid-template-columns: repeat(4, 1fr); gap: 10px; margin: 20px 0; }
    .macro { padding: 15px; border-radius: 8px; text-align: center; }
    .calories { background: #ffd70033; }
    .protein { background: #00ff0033; }
    .carbs { background: #0000ff33; }
    .fat { background: #ff000033; }
    .value { font-size: 1.5em; font-weight: bold; }
    table { width: 100%; margin-top: 20px; }
    th, td { padding: 10px; text-align: left; }
    .error { color: red; font-weight: bold; }
    """
    with gr.Blocks(css=css) as demo:
        gr.Markdown("# AI Nutrition Analyzer")
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="filepath", label="Upload Food Image")
                submit_btn = gr.Button("Analyze Nutrition")
                # example_images is None when none of the example files exist.
                if app.example_images:
                    gr.Examples(examples=app.example_images, inputs=image_input)
            with gr.Column():
                output = gr.HTML(label="Nutrition Analysis")
        submit_btn.click(
            fn=app.analyze_image,
            inputs=image_input,
            outputs=output,
        )
    demo.launch()