shingguy1 commited on
Commit
45a1fc8
·
verified ·
1 Parent(s): 1101f7d

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +89 -10
src/streamlit_app.py CHANGED
@@ -1,18 +1,97 @@
1
  import streamlit as st
2
- import io
 
3
  from PIL import Image
 
 
 
 
 
 
4
 
5
- st.set_page_config(page_title="🧪 Upload Debug Test", layout="centered")
6
- st.title("🧪 Upload Debug Test")
 
 
7
 
8
- uploaded_file = st.file_uploader("Upload an image...", type=["jpg", "jpeg", "png"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  if uploaded_file is not None:
11
- st.write("Upload received!")
12
  try:
13
- file_bytes = uploaded_file.read()
14
- image = Image.open(io.BytesIO(file_bytes)).convert("RGB")
15
- st.image(image, caption="Image Loaded", use_column_width=True)
16
- st.success("✅ Image loaded successfully.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  except Exception as e:
18
- st.error(f" Failed to open image: {e}")
 
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
+ import os
4
  from PIL import Image
5
+ import torchvision.transforms as transforms
6
+ from transformers import (
7
+ ConvNextForImageClassification,
8
+ AutoTokenizer,
9
+ AutoModelForCausalLM
10
+ )
11
 
12
+ # Set Streamlit UI
13
+ st.set_page_config(page_title="🍽️ Food Nutrition Estimator", page_icon="🥗", layout="centered")
14
+ st.title("🍽️ Food Nutrition Estimator")
15
+ st.markdown("Upload a food image and get nutritional information generated by AI!")
16
 
17
+ # Environment & cache setup
18
+ hf_token = os.getenv("HF_TOKEN")
19
+ cache_dir = "/tmp/cache"
20
+ os.makedirs(cache_dir, exist_ok=True)
21
+ os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
22
+
23
+ # Transform for ConvNeXt
24
+ manual_transform = transforms.Compose([
25
+ transforms.Resize(224),
26
+ transforms.CenterCrop(196),
27
+ transforms.ToTensor(),
28
+ transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
29
+ transforms.ConvertImageDtype(torch.float32)
30
+ ])
31
+
32
+ # Sidebar info
33
+ st.sidebar.header("Models Used")
34
+ st.sidebar.markdown("""
35
+ - 🖼️ **Image Classifier**: `shingguy1/food-calorie-convnext`
36
+ - 💬 **Text Generator**: `TinyLlama/TinyLlama-1.1B-Chat-v1.0`
37
+ """)
38
+
39
+ # Load models
40
+ @st.cache_resource
41
+ def load_models():
42
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
+
44
+ # ConvNeXt for classification
45
+ model_convnext = ConvNextForImageClassification.from_pretrained(
46
+ "shingguy1/food-calorie-convnext",
47
+ cache_dir=cache_dir,
48
+ token=hf_token
49
+ ).to(device)
50
+
51
+ # TinyLlama for nutritional facts
52
+ tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", cache_dir=cache_dir)
53
+ model_llm = AutoModelForCausalLM.from_pretrained(
54
+ "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
55
+ cache_dir=cache_dir,
56
+ torch_dtype=torch.float32,
57
+ device_map="auto"
58
+ )
59
+
60
+ return model_convnext, tokenizer, model_llm, device
61
+
62
+ model_convnext, tokenizer, model_llm, device = load_models()
63
+
64
+ # Upload image
65
+ uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])
66
 
67
  if uploaded_file is not None:
 
68
  try:
69
+ image = Image.open(uploaded_file).convert("RGB")
70
+ st.image(image, caption="Uploaded Image", use_column_width=True)
71
+
72
+ # Predict with ConvNeXt
73
+ input_tensor = manual_transform(image).unsqueeze(0).to(device)
74
+ with torch.no_grad():
75
+ outputs = model_convnext(pixel_values=input_tensor)
76
+ pred_idx = outputs.logits.argmax(-1).item()
77
+ pred_label = model_convnext.config.id2label[pred_idx]
78
+ st.success(f"🍴 Predicted Food: **{pred_label}**")
79
+
80
+ # Generate nutrition caption using TinyLlama
81
+ prompt = f"Give the calories, macros, and nutritional facts of a {pred_label}."
82
+ st.subheader("🧾 Nutrition Information")
83
+ st.write(f"🤖 Prompt: `{prompt}`")
84
+
85
+ input_ids = tokenizer(prompt, return_tensors="pt").to(model_llm.device)
86
+ with torch.no_grad():
87
+ output = model_llm.generate(**input_ids, max_new_tokens=100)
88
+ caption = tokenizer.decode(output[0], skip_special_tokens=True)
89
+
90
+ st.info(caption)
91
+
92
  except Exception as e:
93
+ st.error(f"Something went wrong: {e}")
94
+
95
+ # Footer
96
+ st.markdown("---")
97
+ st.markdown("Built with ❤️ using Streamlit and Hugging Face by **shingguy1**")