shingguy1 commited on
Commit
fda8905
·
verified ·
1 Parent(s): f8ba963

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +96 -1
src/streamlit_app.py CHANGED
@@ -1,3 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Image upload
2
  uploaded_file = st.file_uploader("Choose a food image...", type=["jpg", "jpeg", "png"])
3
 
@@ -44,4 +118,25 @@ if uploaded_file is not None:
44
  else:
45
  st.error("ConvNeXt model not loaded.")
46
  except Exception as e:
47
- st.error(f"Error processing uploaded image: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import os
4
+ from transformers import ConvNextForImageClassification, T5ForConditionalGeneration, T5Tokenizer
5
+ from PIL import Image
6
+ import torchvision.transforms as transforms
7
+
8
+ # Streamlit page configuration
9
+ st.set_page_config(page_title="Food Calorie Estimator", page_icon="🍽️", layout="centered")
10
+
11
+ # Get HF_TOKEN from environment (for private repositories)
12
+ hf_token = os.getenv("HF_TOKEN")
13
+
14
+ # Debug: Check if HF_TOKEN is retrieved
15
+ st.write("HF_TOKEN exists:", bool(hf_token))
16
+
17
+ # Set cache directory to /tmp/cache
18
+ cache_dir = "/tmp/cache"
19
+ try:
20
+ os.makedirs(cache_dir, exist_ok=True)
21
+ os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
22
+ st.write("Cache directory created:", os.path.exists(cache_dir))
23
+ except Exception as e:
24
+ st.error(f"Failed to create cache directory: {e}")
25
+ cache_dir = None # Fallback to default if creation fails
26
+
27
+ # Manual preprocessing transform based on shingguy1/food-calorie-convnext preprocessor_config.json
28
+ manual_transform = transforms.Compose([
29
+ transforms.Resize(224),
30
+ transforms.CenterCrop(196), # crop_pct: 0.875 * 224 = 196
31
+ transforms.ToTensor(),
32
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
33
+ transforms.ConvertImageDtype(torch.float32)
34
+ ])
35
+
36
+ # Title and description
37
+ st.title("🍽️ Food Calorie Estimator")
38
+ st.markdown("""
39
+ Upload an image of your food, and our AI will identify the food type and estimate its calorie content!
40
+ This app uses a fine-tuned ConvNeXt model for food classification and a T5 model for calorie estimation.
41
+ """)
42
+
43
+ # Sidebar with model information
44
+ st.sidebar.header("About the Models")
45
+ st.sidebar.markdown("""
46
+ - **Food Classification**: ConvNeXt (`shingguy1/food-calorie-convnext`) trained on Food-101 dataset.
47
+ - **Calorie Estimation**: T5 (`shingguy1/food-calorie-t5`) fine-tuned on a synthetic dataset of 40 food items for calorie prediction.
48
+ - **Classes**: Pizza, Hamburger, Sushi, Salad, Pasta, Ice Cream, Fried Rice, Tacos, Steak, Chocolate Cake
49
+ - **Hosted on**: Hugging Face Spaces
50
+ """)
51
+
52
+ # Initialize models
53
+ @st.cache_resource
54
+ def load_models():
55
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
56
+ try:
57
+ # Load ConvNeXt model
58
+ model_convnext = ConvNextForImageClassification.from_pretrained('shingguy1/food-calorie-convnext', cache_dir=cache_dir, token=hf_token).to(device)
59
+
60
+ # Load T5 model and tokenizer
61
+ try:
62
+ tokenizer = T5Tokenizer.from_pretrained('shingguy1/food-calorie-t5', cache_dir=cache_dir, token=hf_token)
63
+ model_t5 = T5ForConditionalGeneration.from_pretrained('shingguy1/food-calorie-t5', cache_dir=cache_dir, token=hf_token).to(device)
64
+ except Exception as e:
65
+ st.error(f"Failed to load T5 model: {e}. Calorie estimation will be skipped.")
66
+ tokenizer, model_t5 = None, None
67
+
68
+ return model_convnext, tokenizer, model_t5, device
69
+ except Exception as e:
70
+ st.error(f"Error loading models: {e}")
71
+ return None, None, None, None
72
+
73
+ model_convnext, tokenizer, model_t5, device = load_models()
74
+
75
  # Image upload
76
  uploaded_file = st.file_uploader("Choose a food image...", type=["jpg", "jpeg", "png"])
77
 
 
118
  else:
119
  st.error("ConvNeXt model not loaded.")
120
  except Exception as e:
121
+ st.error(f"Error processing uploaded image: {str(e)}")
122
+
123
+ # Text input fallback
124
+ st.sidebar.header("Alternative Input")
125
+ food_name = st.sidebar.text_input("Or enter a food name (e.g., 'pizza'):", "")
126
+ if food_name:
127
+ try:
128
+ if tokenizer is not None and model_t5 is not None:
129
+ input_text = f"estimate calories: {food_name.lower()}"
130
+ inputs_t5 = tokenizer(input_text, return_tensors='pt', max_length=64, truncation=True).to(device)
131
+ with torch.no_grad():
132
+ outputs_t5 = model_t5.generate(**inputs_t5, max_length=64)
133
+ calorie_estimate = tokenizer.decode(outputs_t5[0], skip_special_tokens=True)
134
+ st.sidebar.success(f"Estimated Calories for {food_name}: **{calorie_estimate}**")
135
+ else:
136
+ st.sidebar.warning("Calorie estimation skipped due to T5 model loading failure.")
137
+ except Exception as e:
138
+ st.sidebar.error(f"Error estimating calories: {str(e)}")
139
+
140
+ # Footer
141
+ st.markdown("---")
142
+ st.markdown("Built with ❤️ using Streamlit and Hugging Face by shingguy1")