shingguy1 commited on
Commit
f8ba963
·
verified ·
1 Parent(s): 13c9633

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +3 -80
src/streamlit_app.py CHANGED
@@ -1,87 +1,14 @@
1
- import streamlit as st
2
- import torch
3
- import os
4
- from transformers import ConvNextForImageClassification, T5ForConditionalGeneration, T5Tokenizer
5
- from PIL import Image
6
- import torchvision.transforms as transforms
7
-
8
- # Streamlit page configuration
9
- st.set_page_config(page_title="Food Calorie Estimator", page_icon="🍽️", layout="centered")
10
-
11
- # Get HF_TOKEN from environment (for private repositories)
12
- hf_token = os.getenv("HF_TOKEN")
13
-
14
- # Debug: Check if HF_TOKEN is retrieved
15
- st.write("HF_TOKEN exists:", bool(hf_token))
16
-
17
- # Set cache directory to /tmp/cache
18
- cache_dir = "/tmp/cache"
19
- try:
20
- os.makedirs(cache_dir, exist_ok=True)
21
- os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
22
- st.write("Cache directory created:", os.path.exists(cache_dir))
23
- except Exception as e:
24
- st.error(f"Failed to create cache directory: {e}")
25
- cache_dir = None # Fallback to default if creation fails
26
-
27
- # Manual preprocessing transform based on shingguy1/food-calorie-convnext preprocessor_config.json
28
- manual_transform = transforms.Compose([
29
- transforms.Resize(224),
30
- transforms.CenterCrop(196), # crop_pct: 0.875 * 224 = 196
31
- transforms.ToTensor(),
32
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
33
- transforms.ConvertImageDtype(torch.float32)
34
- ])
35
-
36
- # Title and description
37
- st.title("🍽️ Food Calorie Estimator")
38
- st.markdown("""
39
- Upload an image of your food, and our AI will identify the food type and estimate its calorie content!
40
- This app uses a fine-tuned ConvNeXt model for food classification and a T5 model for calorie estimation.
41
- """)
42
-
43
- # Sidebar with model information
44
- st.sidebar.header("About the Models")
45
- st.sidebar.markdown("""
46
- - **Food Classification**: ConvNeXt (`shingguy1/food-calorie-convnext`) trained on Food-101 dataset.
47
- - **Calorie Estimation**: T5 (`shingguy1/food-calorie-t5`) fine-tuned on a synthetic dataset of 40 food items for calorie prediction.
48
- - **Classes**: Pizza, Hamburger, Sushi, Salad, Pasta, Ice Cream, Fried Rice, Tacos, Steak, Chocolate Cake
49
- - **Hosted on**: Hugging Face Spaces
50
- """)
51
-
52
- # Initialize models
53
- @st.cache_resource
54
- def load_models():
55
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
56
- try:
57
- # Load ConvNeXt model
58
- model_convnext = ConvNextForImageClassification.from_pretrained('shingguy1/food-calorie-convnext', cache_dir=cache_dir, token=hf_token).to(device)
59
-
60
- # Load T5 model and tokenizer
61
- try:
62
- tokenizer = T5Tokenizer.from_pretrained('shingguy1/food-calorie-t5', cache_dir=cache_dir, token=hf_token)
63
- model_t5 = T5ForConditionalGeneration.from_pretrained('shingguy1/food-calorie-t5', cache_dir=cache_dir, token=hf_token).to(device)
64
- except Exception as e:
65
- st.error(f"Failed to load T5 model: {e}. Calorie estimation will be skipped.")
66
- tokenizer, model_t5 = None, None
67
-
68
- return model_convnext, tokenizer, model_t5, device
69
- except Exception as e:
70
- st.error(f"Error loading models: {e}")
71
- return None, None, None, None
72
-
73
- model_convnext, tokenizer, model_t5, device = load_models()
74
-
75
  # Image upload
76
  uploaded_file = st.file_uploader("Choose a food image...", type=["jpg", "jpeg", "png"])
77
 
78
  if uploaded_file is not None:
79
  st.write(f"Uploaded file name: {uploaded_file.name}")
80
  st.write(f"File size: {uploaded_file.size} bytes")
 
81
  try:
82
  image = Image.open(uploaded_file).convert('RGB')
83
  st.image(image, caption="Uploaded Image", use_column_width=True)
84
- st.write("Image loaded successfully.")
85
 
86
  # Process image and predict
87
  if model_convnext is not None:
@@ -117,8 +44,4 @@ if uploaded_file is not None:
117
  else:
118
  st.error("ConvNeXt model not loaded.")
119
  except Exception as e:
120
- st.error(f"Error processing uploaded image: {str(e)}")
121
-
122
- # Footer
123
- st.markdown("---")
124
- st.markdown("Built with ❤️ using Streamlit and Hugging Face by shingguy1")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  # Image upload
2
  uploaded_file = st.file_uploader("Choose a food image...", type=["jpg", "jpeg", "png"])
3
 
4
  if uploaded_file is not None:
5
  st.write(f"Uploaded file name: {uploaded_file.name}")
6
  st.write(f"File size: {uploaded_file.size} bytes")
7
+ st.write(f"File type: {uploaded_file.type}")
8
  try:
9
  image = Image.open(uploaded_file).convert('RGB')
10
  st.image(image, caption="Uploaded Image", use_column_width=True)
11
+ st.write("Image loaded successfully from memory.")
12
 
13
  # Process image and predict
14
  if model_convnext is not None:
 
44
  else:
45
  st.error("ConvNeXt model not loaded.")
46
  except Exception as e:
47
+ st.error(f"Error processing uploaded image: {str(e)}")