shingguy1 committed on
Commit
bbac15f
·
verified ·
1 Parent(s): 9043c37

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +55 -44
src/streamlit_app.py CHANGED
@@ -17,12 +17,14 @@ st.markdown("Upload a food image and get nutritional information generated by AI
17
 
18
  # Environment & cache setup
19
  hf_token = os.getenv("HF_TOKEN")
 
 
20
  cache_dir = "/tmp/cache"
21
  os.makedirs(cache_dir, exist_ok=True)
22
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
23
 
24
- # Use ConvNeXt's official image processor
25
- image_processor = ConvNextImageProcessor.from_pretrained("facebook/convnext-base-224-22k")
26
 
27
  # Sidebar info
28
  st.sidebar.header("Models Used")
@@ -35,66 +37,75 @@ st.sidebar.markdown("""
35
  @st.cache_resource
36
  def load_models():
37
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
38
- model_convnext = ConvNextForImageClassification.from_pretrained(
39
- "shingguy1/food-calorie-convnext",
40
- cache_dir=cache_dir,
41
- token=hf_token
42
- ).to(device)
43
- tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", cache_dir=cache_dir)
44
- model_llm = AutoModelForCausalLM.from_pretrained(
45
- "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
46
- cache_dir=cache_dir,
47
- torch_dtype=torch.float32,
48
- device_map="auto"
49
- )
50
- return model_convnext, tokenizer, model_llm, device
51
 
52
- model_convnext, tokenizer, model_llm, device = load_models()
 
 
 
 
 
 
 
 
 
53
 
54
- # Initialize session state for uploaded file
55
- if "uploaded_file" not in st.session_state:
56
- st.session_state.uploaded_file = None
 
 
 
 
 
 
 
 
 
57
 
58
- # Upload image
59
- uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"], key="file_uploader")
60
 
61
- # Update session state when a file is uploaded
62
- if uploaded_file is not None and st.session_state.uploaded_file != uploaded_file:
63
- st.session_state.uploaded_file = uploaded_file
64
- st.write(f"File uploaded: {uploaded_file.name}, Size: {uploaded_file.size} bytes")
65
- else:
66
- st.warning("No file uploaded yet. Please upload a .jpg, .jpeg, or .png image.")
67
 
68
- # Process uploaded file
69
- if st.session_state.uploaded_file is not None:
70
  try:
71
- image = Image.open(st.session_state.uploaded_file).convert("RGB")
72
  st.image(image, caption="Uploaded Image", use_column_width=True)
73
 
74
  # Predict with ConvNeXt
75
- inputs = image_processor(image, return_tensors="pt").to(device)
76
- with torch.no_grad():
77
- outputs = model_convnext(**inputs)
78
- pred_idx = outputs.logits.argmax(-1).item()
79
- pred_label = model_convnext.config.id2label[pred_idx]
80
- st.success(f"🍴 Predicted Food: **{pred_label}**")
 
81
 
82
  # Generate nutrition caption using TinyLlama
83
- prompt = f"Give the calories, macros, and nutritional facts of a {pred_label}."
84
  st.subheader("🧾 Nutrition Information")
85
  st.write(f"🤖 Prompt: `{prompt}`")
86
 
87
- input_ids = tokenizer(prompt, return_tensors="pt").to(model_llm.device)
88
- with torch.no_grad():
89
- output = model_llm.generate(**input_ids, max_new_tokens=100)
90
- caption = tokenizer.decode(output[0], skip_special_tokens=True)
 
 
 
 
 
 
 
91
 
92
- st.info(caption)
93
 
94
  except Exception as e:
95
  st.error(f"Something went wrong: {e}")
96
- else:
97
- st.info("Please upload an image to get started.")
98
 
99
  # Footer
100
  st.markdown("---")
 
17
 
18
  # Environment & cache setup
19
  hf_token = os.getenv("HF_TOKEN")
20
+ if not hf_token:
21
+ st.warning("HF_TOKEN not set. Please set the environment variable HF_TOKEN to access private models.")
22
  cache_dir = "/tmp/cache"
23
  os.makedirs(cache_dir, exist_ok=True)
24
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
25
 
26
+ # Use ConvNeXt's official image processor for a compatible model
27
+ image_processor = ConvNextImageProcessor.from_pretrained("facebook/convnext-base-224-22k-1k")
28
 
29
  # Sidebar info
30
  st.sidebar.header("Models Used")
 
37
  @st.cache_resource
38
  def load_models():
39
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
+ st.info(f"Using device: {device}")
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
+ try:
43
+ # ConvNeXt for classification
44
+ model_convnext = ConvNextForImageClassification.from_pretrained(
45
+ "shingguy1/food-calorie-convnext",
46
+ cache_dir=cache_dir,
47
+ token=hf_token
48
+ ).to(device)
49
+ except Exception as e:
50
+ st.error(f"Failed to load ConvNeXt model: {e}")
51
+ st.stop()
52
 
53
+ try:
54
+ # TinyLlama for nutritional facts
55
+ tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", cache_dir=cache_dir)
56
+ model_llm = AutoModelForCausalLM.from_pretrained(
57
+ "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
58
+ cache_dir=cache_dir,
59
+ torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
60
+ device_map="auto"
61
+ )
62
+ except Exception as e:
63
+ st.error(f"Failed to load TinyLlama model: {e}")
64
+ st.stop()
65
 
66
+ return model_convnext, tokenizer, model_llm, device
 
67
 
68
+ with st.spinner("Loading models..."):
69
+ model_convnext, tokenizer, model_llm, device = load_models()
70
+
71
+ # Upload image
72
+ uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])
 
73
 
74
+ if uploaded_file is not None:
 
75
  try:
76
+ image = Image.open(uploaded_file).convert("RGB")
77
  st.image(image, caption="Uploaded Image", use_column_width=True)
78
 
79
  # Predict with ConvNeXt
80
+ with st.spinner("Classifying food..."):
81
+ inputs = image_processor(image, return_tensors="pt").to(device)
82
+ with torch.no_grad():
83
+ outputs = model_convnext(**inputs)
84
+ pred_idx = outputs.logits.argmax(-1).item()
85
+ pred_label = model_convnext.config.id2label[pred_idx]
86
+ st.success(f"🍴 Predicted Food: **{pred_label}**")
87
 
88
  # Generate nutrition caption using TinyLlama
89
+ prompt = f"Provide the calories, protein, fat, and carbs for a typical serving of {pred_label}. Format the response as: 'Calories: X kcal, Protein: Y g, Fat: Z g, Carbs: W g'."
90
  st.subheader("🧾 Nutrition Information")
91
  st.write(f"🤖 Prompt: `{prompt}`")
92
 
93
+ with st.spinner("Generating nutritional facts..."):
94
+ input_ids = tokenizer(prompt, return_tensors="pt").to(device)
95
+ with torch.no_grad():
96
+ output = model_llm.generate(
97
+ **input_ids,
98
+ max_new_tokens=150,
99
+ temperature=0.7,
100
+ top_p=0.9
101
+ )
102
+ caption = tokenizer.decode(output[0], skip_special_tokens=True)
103
+ caption = caption.replace(prompt, "").strip() # Remove prompt if echoed
104
 
105
+ st.info(caption if caption else "No nutritional information generated.")
106
 
107
  except Exception as e:
108
  st.error(f"Something went wrong: {e}")
 
 
109
 
110
  # Footer
111
  st.markdown("---")