# NutriVision.AI / app.py
# Author: NawazHanzla
# Commit eda3213 (verified): "Create app.py"
import os
import logging
from typing import Dict, List, Optional
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification
import requests
import gradio as gr
# Module-wide logging: basicConfig gives the root logger a stderr handler at
# INFO level (it does nothing if the root logger already has handlers).
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
class NutritionAnalyzer:
    """Recognise the food shown in an image with a pretrained ViT classifier.

    The feature extractor and model are loaded by ``initialize_models``; if
    that has not been called yet, the first ``identify_food`` call does it.
    """

    def __init__(self):
        # Hugging Face model id used for both the feature extractor and the
        # classification head.
        self.model_name = "google/vit-base-patch16-224"
        self.feature_extractor = None
        self.model = None
        # NOTE: 'Your API Key' is only a placeholder, not a usable key; set
        # NUTRITION_API_KEY in the environment for real look-ups.
        self.api_key = os.getenv('NUTRITION_API_KEY', 'Your API Key')

    def initialize_models(self):
        """Load the ViT feature extractor and image-classification model.

        Raises:
            Exception: whatever ``from_pretrained`` raises; it is logged with
                its traceback and re-raised.
        """
        # NOTE(review): recent transformers releases deprecate
        # ViTFeatureExtractor in favour of ViTImageProcessor — confirm the
        # pinned version still ships it.
        try:
            self.feature_extractor = ViTFeatureExtractor.from_pretrained(self.model_name)
            self.model = ViTForImageClassification.from_pretrained(self.model_name)
        except Exception as e:
            logger.exception(f"Model initialization failed: {str(e)}")
            raise
        logger.info("Models initialized successfully")

    def identify_food(self, image_path: str) -> str:
        """Classify the image at ``image_path`` and return a short food name.

        Args:
            image_path: Path to the input image file.

        Returns:
            str: The first comma-separated alias of the predicted class label
            (for a label such as ``"pizza, pizza pie"`` this is ``"pizza"``).

        Raises:
            Exception: any model-loading, image-decoding or inference error;
                it is logged with its traceback and re-raised.
        """
        # Lazily load the models so calling this before initialize_models()
        # does not fail with "'NoneType' object is not callable".
        if self.model is None or self.feature_extractor is None:
            self.initialize_models()
        try:
            # torch is needed for no_grad(); return_tensors="pt" below already
            # requires PyTorch to be installed.
            import torch

            # Close the file handle promptly, and normalise the pixel mode:
            # RGBA, palette or greyscale uploads may otherwise be rejected by a
            # feature extractor that expects 3 colour channels.
            with Image.open(image_path) as img:
                image = img.convert("RGB")
            inputs = self.feature_extractor(images=image, return_tensors="pt")
            # Inference only: don't build an autograd graph.
            with torch.no_grad():
                logits = self.model(**inputs).logits
            label = self.model.config.id2label[logits.argmax(-1).item()]
        except Exception as e:
            logger.exception(f"Food identification error: {str(e)}")
            raise
        return label.split(',')[0].strip()
class NutritionAPIHandler:
    """Retrieve nutrition facts for a food name from the API Ninjas nutrition endpoint."""

    BASE_URL = "https://api.api-ninjas.com/v1/nutrition"
    # Seconds to wait on the HTTP call. Without a timeout, requests waits
    # indefinitely, so a stalled API would block the UI request forever.
    DEFAULT_TIMEOUT = 10.0

    def __init__(self, api_key: str, timeout: float = DEFAULT_TIMEOUT):
        """Create a session that sends ``api_key`` in the X-Api-Key header.

        Args:
            api_key: Key for the nutrition API.
            timeout: Per-request timeout in seconds passed to ``requests``.
        """
        self.api_key = api_key
        self.timeout = timeout
        self.session = requests.Session()
        self.session.headers.update({'X-Api-Key': self.api_key})

    def get_nutrition_data(self, food_name: str) -> Optional[Dict]:
        """Fetch nutrition data for ``food_name``.

        This is a best-effort lookup: every failure is logged and reported
        as ``None`` rather than raised.

        Args:
            food_name: Name of the food item to query.

        Returns:
            Optional[Dict]: The first item of the JSON list response, or None
            if the name is blank, the request fails, the body is not JSON, or
            the response contains no items.
        """
        if not food_name or not food_name.strip():
            return None
        try:
            response = self.session.get(
                self.BASE_URL,
                params={'query': food_name},
                timeout=self.timeout,
            )
            response.raise_for_status()
            payload = response.json()  # parse once
        except Exception as e:
            logger.error(f"API Error: {str(e)}")
            return None
        item = self._first_item(payload)
        if item is None and payload:
            # Non-empty but not a list of dicts, e.g. an error object.
            logger.warning(f"Unexpected nutrition API payload for {food_name!r}")
        return item

    @staticmethod
    def _first_item(payload: object) -> Optional[Dict]:
        """Return the first element of a list payload if it is a dict, else None."""
        if isinstance(payload, list) and payload and isinstance(payload[0], dict):
            return payload[0]
        return None
class NutritionFormatter:
    """Render nutrition data as HTML: four macro tiles followed by a nutrient table."""

    # (row label, key in the nutrition dict, daily value used as the
    # denominator of the "Daily Value%" column).
    _TABLE_NUTRIENTS = (
        ('Saturated Fat (g)', 'fat_saturated_g', 20),
        ('Fiber (g)', 'fiber_g', 25),
        ('Sugar (g)', 'sugar_g', 50),
        ('Sodium (mg)', 'sodium_mg', 2300),
        ('Potassium (mg)', 'potassium_mg', 4700),
        ('Cholesterol (mg)', 'cholesterol_mg', 300),
    )

    @staticmethod
    def _escape(value) -> str:
        """HTML-escape the string form of ``value``; API fields are untrusted text."""
        import html

        return html.escape(str(value))

    @staticmethod
    def create_nutrition_table(data: Dict) -> str:
        """Generate the HTML nutrition-facts panel.

        Args:
            data: Nutrition dict. Missing keys render as 0, and the name
                renders as 'Unknown Food'.

        Returns:
            str: HTML markup, or a "No nutrition data available" paragraph
            when ``data`` is empty or None.
        """
        if not data:
            return "<p>No nutrition data available</p>"
        esc = NutritionFormatter._escape
        return f"""
<div class="nutrition-container">
    <h3>Nutrition Facts for {esc(data.get('name', 'Unknown Food'))}</h3>
    <div class="macros">
        <div class="macro calories">
            <h4>Calories</h4>
            <div class="value">{esc(data.get('calories', 0))}</div>
        </div>
        <div class="macro protein">
            <h4>Protein (g)</h4>
            <div class="value">{esc(data.get('protein_g', 0))}</div>
        </div>
        <div class="macro carbs">
            <h4>Carbs (g)</h4>
            <div class="value">{esc(data.get('carbohydrates_total_g', 0))}</div>
        </div>
        <div class="macro fat">
            <h4>Fat (g)</h4>
            <div class="value">{esc(data.get('fat_total_g', 0))}</div>
        </div>
    </div>
    <table>
        <tr><th>Nutrient</th><th>Amount</th><th>Daily Value%</th></tr>
        {NutritionFormatter._create_table_rows(data)}
    </table>
</div>
"""

    @staticmethod
    def _create_table_rows(data: Dict) -> str:
        """Build one ``<tr>`` per entry of ``_TABLE_NUTRIENTS``.

        Numeric amounts get a percentage of the daily value. Non-numeric
        amounts (None, or text placeholders) are shown as-is with "N/A",
        instead of raising TypeError and losing the whole panel.
        """
        rows = []
        for label, key, daily_value in NutritionFormatter._TABLE_NUTRIENTS:
            value = data.get(key, 0)
            if isinstance(value, (int, float)) and not isinstance(value, bool):
                ratio = (value / daily_value) * 100 if daily_value else 0
                percent = f"{ratio:.1f}%"
            else:
                percent = "N/A"
            rows.append(
                f"<tr><td>{label}</td><td>{NutritionFormatter._escape(value)}</td>"
                f"<td>{percent}</td></tr>"
            )
        return ''.join(rows)
class NutritionAnalyzerApp:
    """Glue for the Gradio UI: image -> food name -> nutrition lookup -> HTML."""

    def __init__(self):
        self.analyzer = NutritionAnalyzer()
        # Load models up front so a loading failure aborts start-up rather
        # than surfacing on the first user request.
        self.analyzer.initialize_models()
        self.api_handler = NutritionAPIHandler(self.analyzer.api_key)
        self.example_images = self._load_example_images()

    def _load_example_images(self):
        """Return ``[[path], ...]`` for the bundled example images that exist.

        Paths are relative to the current working directory. Missing files
        are logged as warnings. Returns None if none of them exist.
        """
        example_paths = [
            "examples/apple.jpg",
            "examples/pizza.jpg",
            "examples/salad.jpg"
        ]
        valid_examples = []
        for path in example_paths:
            if os.path.exists(path):
                valid_examples.append([path])
            else:
                logger.warning(f"Example image not found: {path}")
        return valid_examples if valid_examples else None

    def analyze_image(self, image_path: str) -> str:
        """Run the full pipeline for one image and return HTML for display.

        Args:
            image_path: Path to the input image file. It may be None or empty
                when the user submits without uploading an image.

        Returns:
            str: Nutrition HTML, a "no data" message, or an error paragraph.
                Exceptions are never propagated to the UI.
        """
        if not image_path:
            return "<p class='error'>Please upload a food image first.</p>"
        try:
            food_item = self.analyzer.identify_food(image_path)
            nutrition_data = self.api_handler.get_nutrition_data(food_item)
            return self._handle_output(food_item, nutrition_data)
        except Exception as e:
            logger.exception(f"Processing error: {str(e)}")
            return "<p class='error'>Error processing request. Please try again.</p>"

    def _handle_output(self, food_item: str, data: Optional[Dict]) -> str:
        """Render ``data`` as a nutrition table, or a not-found message naming ``food_item``."""
        if not data:
            import html

            return f"<p>No nutrition data found for {html.escape(food_item)}</p>"
        return NutritionFormatter.create_nutrition_table(data)
# Build and launch the Gradio UI when run as a script.
if __name__ == "__main__":
    app = NutritionAnalyzerApp()
    css = """
    .nutrition-container { max-width: 600px; margin: 20px auto; }
    .macros { display: grid; grid-template-columns: repeat(4, 1fr); gap: 10px; margin: 20px 0; }
    .macro { padding: 15px; border-radius: 8px; text-align: center; }
    .calories { background: #ffd70033; }
    .protein { background: #00ff0033; }
    .carbs { background: #0000ff33; }
    .fat { background: #ff000033; }
    .value { font-size: 1.5em; font-weight: bold; }
    table { width: 100%; margin-top: 20px; }
    th, td { padding: 10px; text-align: left; }
    .error { color: red; font-weight: bold; }
    """
    with gr.Blocks(css=css) as demo:
        # The title emoji had been mis-encoded as "๐ŸŽ"; it is the red apple.
        gr.Markdown("# 🍎 AI Nutrition Analyzer")
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(type="filepath", label="Upload Food Image")
                submit_btn = gr.Button("Analyze Nutrition")
                # The app collects example images on start-up; offer whichever
                # exist as one-click inputs (None means none were found).
                if app.example_images:
                    gr.Examples(examples=app.example_images, inputs=image_input)
            with gr.Column():
                output = gr.HTML(label="Nutrition Analysis")
        submit_btn.click(
            fn=app.analyze_image,
            inputs=image_input,
            outputs=output
        )
    demo.launch()