shingguy1 committed on
Commit
9129b6f
·
verified ·
1 Parent(s): 94cfed3

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +33 -37
src/streamlit_app.py CHANGED
@@ -5,36 +5,35 @@ from PIL import Image
5
  import torchvision.transforms as transforms
6
  from transformers import (
7
  ConvNextForImageClassification,
8
- Blip2Processor,
9
- Blip2ForConditionalGeneration
10
  )
11
 
12
- # Set Streamlit page config
13
  st.set_page_config(page_title="🍽️ Food Nutrition Estimator", page_icon="🥗", layout="centered")
 
 
14
 
15
- # Use Hugging Face token (for private models if needed)
16
  hf_token = os.getenv("HF_TOKEN")
17
-
18
- # Set Hugging Face cache directory
19
  cache_dir = "/tmp/cache"
20
  os.makedirs(cache_dir, exist_ok=True)
21
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
22
 
23
- # Manual transform to match ConvNeXt's preprocessor config
24
  manual_transform = transforms.Compose([
25
  transforms.Resize(224),
26
- transforms.CenterCrop(196), # crop_pct: 0.875 * 224
27
  transforms.ToTensor(),
28
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
29
  transforms.ConvertImageDtype(torch.float32)
30
  ])
31
 
32
- # Sidebar information
33
  st.sidebar.header("Models Used")
34
  st.sidebar.markdown("""
35
- - 🤖 **Classifier**: `shingguy1/food-calorie-convnext` (ConvNeXt)
36
- - 🧠 **Captioner**: `Salesforce/blip2-flan-t5-xl` (BLIP-2)
37
- - 📝 **Description**: Automatically generates nutritional facts based on food image
38
  """)
39
 
40
  # Load models
@@ -42,31 +41,27 @@ st.sidebar.markdown("""
42
  def load_models():
43
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
44
 
45
- # Load ConvNeXt
46
  model_convnext = ConvNextForImageClassification.from_pretrained(
47
- "shingguy1/food-calorie-convnext", cache_dir=cache_dir, token=hf_token
 
 
48
  ).to(device)
49
 
50
- # Load BLIP-2
51
- blip_processor = Blip2Processor.from_pretrained(
52
- "Salesforce/blip2-flan-t5-xl", cache_dir=cache_dir
53
- )
54
- blip_model = Blip2ForConditionalGeneration.from_pretrained(
55
- "Salesforce/blip2-flan-t5-xl",
56
  cache_dir=cache_dir,
57
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
58
- device_map="auto" if torch.cuda.is_available() else None
59
  )
60
 
61
- return model_convnext, blip_processor, blip_model, device
62
-
63
- model_convnext, blip_processor, blip_model, device = load_models()
64
 
65
- # Main interface
66
- st.title("🍽️ Food Nutrition Estimator")
67
- st.markdown("Upload a food image and get a nutrition description generated by AI!")
68
 
69
- # File uploader
70
  uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])
71
 
72
  if uploaded_file is not None:
@@ -74,7 +69,7 @@ if uploaded_file is not None:
74
  image = Image.open(uploaded_file).convert("RGB")
75
  st.image(image, caption="Uploaded Image", use_column_width=True)
76
 
77
- # Classification with ConvNeXt
78
  input_tensor = manual_transform(image).unsqueeze(0).to(device)
79
  with torch.no_grad():
80
  outputs = model_convnext(pixel_values=input_tensor)
@@ -82,15 +77,16 @@ if uploaded_file is not None:
82
  pred_label = model_convnext.config.id2label[pred_idx]
83
  st.success(f"🍴 Predicted Food: **{pred_label}**")
84
 
85
- # Caption generation with BLIP-2
86
- st.subheader("🧾 Nutritional Facts (via BLIP-2)")
87
- prompt = f"Describe the nutritional facts and calories of {pred_label}"
88
- inputs = blip_processor(image, text=prompt, return_tensors="pt").to(device)
89
 
 
90
  with torch.no_grad():
91
- output = blip_model.generate(**inputs, max_new_tokens=100)
 
92
 
93
- caption = blip_processor.decode(output[0], skip_special_tokens=True)
94
  st.info(caption)
95
 
96
  except Exception as e:
@@ -98,4 +94,4 @@ if uploaded_file is not None:
98
 
99
  # Footer
100
  st.markdown("---")
101
- st.markdown("Built with ❤️ using Streamlit and Hugging Face by **shingguy1**")
 
5
  import torchvision.transforms as transforms
6
  from transformers import (
7
  ConvNextForImageClassification,
8
+ AutoTokenizer,
9
+ AutoModelForCausalLM
10
  )
11
 
12
+ # Set Streamlit UI
13
  st.set_page_config(page_title="🍽️ Food Nutrition Estimator", page_icon="🥗", layout="centered")
14
+ st.title("🍽️ Food Nutrition Estimator")
15
+ st.markdown("Upload a food image and get nutritional information generated by AI!")
16
 
17
+ # Environment & cache setup
18
  hf_token = os.getenv("HF_TOKEN")
 
 
19
  cache_dir = "/tmp/cache"
20
  os.makedirs(cache_dir, exist_ok=True)
21
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
22
 
23
+ # Transform for ConvNeXt
24
  manual_transform = transforms.Compose([
25
  transforms.Resize(224),
26
+ transforms.CenterCrop(196),
27
  transforms.ToTensor(),
28
  transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
29
  transforms.ConvertImageDtype(torch.float32)
30
  ])
31
 
32
+ # Sidebar info
33
  st.sidebar.header("Models Used")
34
  st.sidebar.markdown("""
35
+ - 🖼️ **Image Classifier**: `shingguy1/food-calorie-convnext`
36
+ - 💬 **Text Generator**: `TinyLlama/TinyLlama-1.1B-Chat-v1.0`
 
37
  """)
38
 
39
  # Load models
 
41
  def load_models():
42
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
 
44
+ # ConvNeXt for classification
45
  model_convnext = ConvNextForImageClassification.from_pretrained(
46
+ "shingguy1/food-calorie-convnext",
47
+ cache_dir=cache_dir,
48
+ token=hf_token
49
  ).to(device)
50
 
51
+ # TinyLlama for nutritional facts
52
+ tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", cache_dir=cache_dir)
53
+ model_llm = AutoModelForCausalLM.from_pretrained(
54
+ "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
 
 
55
  cache_dir=cache_dir,
56
+ torch_dtype=torch.float32,
57
+ device_map="auto"
58
  )
59
 
60
+ return model_convnext, tokenizer, model_llm, device
 
 
61
 
62
+ model_convnext, tokenizer, model_llm, device = load_models()
 
 
63
 
64
+ # Upload image
65
  uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])
66
 
67
  if uploaded_file is not None:
 
69
  image = Image.open(uploaded_file).convert("RGB")
70
  st.image(image, caption="Uploaded Image", use_column_width=True)
71
 
72
+ # Predict with ConvNeXt
73
  input_tensor = manual_transform(image).unsqueeze(0).to(device)
74
  with torch.no_grad():
75
  outputs = model_convnext(pixel_values=input_tensor)
 
77
  pred_label = model_convnext.config.id2label[pred_idx]
78
  st.success(f"🍴 Predicted Food: **{pred_label}**")
79
 
80
+ # Generate nutrition caption using TinyLlama
81
+ prompt = f"Give the calories, macros, and nutritional facts of a {pred_label}."
82
+ st.subheader("🧾 Nutrition Information")
83
+ st.write(f"🤖 Prompt: `{prompt}`")
84
 
85
+ input_ids = tokenizer(prompt, return_tensors="pt").to(model_llm.device)
86
  with torch.no_grad():
87
+ output = model_llm.generate(**input_ids, max_new_tokens=100)
88
+ caption = tokenizer.decode(output[0], skip_special_tokens=True)
89
 
 
90
  st.info(caption)
91
 
92
  except Exception as e:
 
94
 
95
  # Footer
96
  st.markdown("---")
97
+ st.markdown("Built with ❤️ using Streamlit and Hugging Face by **shingguy1**")