shingguy1 committed on
Commit
404eb80
·
verified ·
1 Parent(s): 2ce966d

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +23 -22
src/streamlit_app.py CHANGED
@@ -1,24 +1,24 @@
1
  import streamlit as st
2
  import torch
3
  import os
 
 
4
  from transformers import (
5
  ConvNextForImageClassification,
6
- BlipProcessor,
7
- BlipForConditionalGeneration
8
  )
9
- from PIL import Image
10
- import torchvision.transforms as transforms
11
 
12
- # Page setup
13
  st.set_page_config(page_title="🍽️ Food Nutrition Estimator", page_icon="🥗", layout="centered")
14
 
15
- # Environment & cache
16
  hf_token = os.getenv("HF_TOKEN")
17
  cache_dir = "/tmp/cache"
18
  os.makedirs(cache_dir, exist_ok=True)
19
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
20
 
21
- # Image preprocessing based on ConvNeXt preprocessor config
22
  manual_transform = transforms.Compose([
23
  transforms.Resize(224),
24
  transforms.CenterCrop(196),
@@ -30,32 +30,33 @@ manual_transform = transforms.Compose([
30
  # Sidebar Info
31
  st.sidebar.header("Model Info")
32
  st.sidebar.markdown("""
33
- - 🔍 **Classifier**: ConvNeXt (`shingguy1/food-calorie-convnext`)
34
- - 🧠 **Captioner**: BLIP (`Salesforce/blip-image-captioning-base`)
35
- - 🧾 **Caption Output**: Calories, macros, and nutritional description
36
  """)
37
 
38
  # Load models
39
  @st.cache_resource
40
  def load_models():
41
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
42
 
43
  model_convnext = ConvNextForImageClassification.from_pretrained(
44
  "shingguy1/food-calorie-convnext", cache_dir=cache_dir, token=hf_token
45
  ).to(device)
46
 
47
- blip_processor = BlipProcessor.from_pretrained(
48
- "Salesforce/blip-image-captioning-base", cache_dir=cache_dir
 
 
 
 
49
  )
50
- blip_model = BlipForConditionalGeneration.from_pretrained(
51
- "Salesforce/blip-image-captioning-base", cache_dir=cache_dir
52
- ).to(device)
53
 
54
  return model_convnext, blip_processor, blip_model, device
55
 
56
  model_convnext, blip_processor, blip_model, device = load_models()
57
 
58
- # Image upload
59
  uploaded_file = st.file_uploader("Upload a food image (jpg/png)...", type=["jpg", "jpeg", "png"])
60
 
61
  if uploaded_file is not None:
@@ -63,7 +64,7 @@ if uploaded_file is not None:
63
  image = Image.open(uploaded_file).convert("RGB")
64
  st.image(image, caption="Uploaded Image", use_column_width=True)
65
 
66
- # Preprocess and predict
67
  input_tensor = manual_transform(image).unsqueeze(0).to(device)
68
  with torch.no_grad():
69
  outputs = model_convnext(pixel_values=input_tensor)
@@ -71,13 +72,13 @@ if uploaded_file is not None:
71
  pred_label = model_convnext.config.id2label[pred_idx]
72
  st.success(f"🍴 Predicted Food: **{pred_label}**")
73
 
74
- # Generate nutrition caption with BLIP
75
- st.subheader("🧾 Nutritional Facts (via BLIP)")
76
  prompt = f"Describe the nutritional facts and calories of {pred_label}"
77
- inputs = blip_processor(image, prompt, return_tensors="pt").to(device)
78
 
79
  with torch.no_grad():
80
- output = blip_model.generate(**inputs, max_length=128)
81
 
82
  caption = blip_processor.decode(output[0], skip_special_tokens=True)
83
  st.info(caption)
 
1
  import streamlit as st
2
  import torch
3
  import os
4
+ from PIL import Image
5
+ import torchvision.transforms as transforms
6
  from transformers import (
7
  ConvNextForImageClassification,
8
+ Blip2Processor,
9
+ Blip2ForConditionalGeneration
10
  )
 
 
11
 
12
# Streamlit setup: tab title/icon and a centered single-column layout.
st.set_page_config(page_title="🍽️ Food Nutrition Estimator", page_icon="🥗", layout="centered")

# Environment setup: optional HF token (forwarded to the classifier's
# from_pretrained call below) and a writable cache dir for model weights.
# NOTE(review): /tmp/cache is ephemeral — weights re-download after a restart.
hf_token = os.getenv("HF_TOKEN")
cache_dir = "/tmp/cache"
os.makedirs(cache_dir, exist_ok=True)
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
20
 
21
+ # Manual transform for ConvNeXt
22
  manual_transform = transforms.Compose([
23
  transforms.Resize(224),
24
  transforms.CenterCrop(196),
 
30
  # Sidebar Info
31
  st.sidebar.header("Model Info")
32
  st.sidebar.markdown("""
33
+ - 🤖 **Classifier**: ConvNeXt (`shingguy1/food-calorie-convnext`)
34
+ - 🧠 **Captioner**: BLIP-2 (`Salesforce/blip2-opt-2.7b`)
35
+ - 📋 **Output**: Nutrition facts and calorie descriptions
36
  """)
37
 
38
# Model loading, memoized by Streamlit so the heavy downloads happen once
# per server process instead of on every script rerun.
@st.cache_resource
def load_models():
    """Load the ConvNeXt food classifier and the BLIP-2 captioner.

    Returns the 4-tuple ``(classifier, processor, captioner, device)``.
    """
    cuda_ok = torch.cuda.is_available()
    device = torch.device("cuda" if cuda_ok else "cpu")

    # Food classifier; the HF token is forwarded for the custom repo.
    classifier = ConvNextForImageClassification.from_pretrained(
        "shingguy1/food-calorie-convnext",
        cache_dir=cache_dir,
        token=hf_token,
    )
    classifier = classifier.to(device)

    # BLIP-2 captioner: half precision with automatic device placement on
    # GPU; full precision (and default CPU placement) otherwise.
    processor = Blip2Processor.from_pretrained(
        "Salesforce/blip2-opt-2.7b", cache_dir=cache_dir
    )
    captioner = Blip2ForConditionalGeneration.from_pretrained(
        "Salesforce/blip2-opt-2.7b",
        cache_dir=cache_dir,
        torch_dtype=torch.float16 if cuda_ok else torch.float32,
        device_map="auto" if cuda_ok else None,
    )

    return classifier, processor, captioner, device
56
 
57
# Materialize the cached models once per session/rerun.
model_convnext, blip_processor, blip_model, device = load_models()

# Upload widget; evaluates to None until the user supplies a file.
uploaded_file = st.file_uploader("Upload a food image (jpg/png)...", type=["jpg", "jpeg", "png"])
61
 
62
  if uploaded_file is not None:
 
64
  image = Image.open(uploaded_file).convert("RGB")
65
  st.image(image, caption="Uploaded Image", use_column_width=True)
66
 
67
+ # ConvNeXt classification
68
  input_tensor = manual_transform(image).unsqueeze(0).to(device)
69
  with torch.no_grad():
70
  outputs = model_convnext(pixel_values=input_tensor)
 
72
  pred_label = model_convnext.config.id2label[pred_idx]
73
  st.success(f"🍴 Predicted Food: **{pred_label}**")
74
 
75
+ # BLIP-2 generation
76
+ st.subheader("🧾 Nutritional Facts (via BLIP-2)")
77
  prompt = f"Describe the nutritional facts and calories of {pred_label}"
78
+ inputs = blip_processor(image, text=prompt, return_tensors="pt").to(device)
79
 
80
  with torch.no_grad():
81
+ output = blip_model.generate(**inputs, max_new_tokens=100)
82
 
83
  caption = blip_processor.decode(output[0], skip_special_tokens=True)
84
  st.info(caption)