shingguy1 commited on
Commit
704ef19
·
verified ·
1 Parent(s): 33a13f2

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +10 -89
src/streamlit_app.py CHANGED
@@ -1,97 +1,18 @@
1
  import streamlit as st
2
- import torch
3
- import os
4
  from PIL import Image
5
- import torchvision.transforms as transforms
6
- from transformers import (
7
- ConvNextForImageClassification,
8
- AutoTokenizer,
9
- AutoModelForCausalLM
10
- )
11
 
12
- # Set Streamlit UI
13
- st.set_page_config(page_title="🍽️ Food Nutrition Estimator", page_icon="🥗", layout="centered")
14
- st.title("🍽️ Food Nutrition Estimator")
15
- st.markdown("Upload a food image and get nutritional information generated by AI!")
16
 
17
- # Environment & cache setup
18
- hf_token = os.getenv("HF_TOKEN")
19
- cache_dir = "/tmp/cache"
20
- os.makedirs(cache_dir, exist_ok=True)
21
- os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
22
-
23
- # Transform for ConvNeXt
24
- manual_transform = transforms.Compose([
25
- transforms.Resize(224),
26
- transforms.CenterCrop(196),
27
- transforms.ToTensor(),
28
- transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
29
- transforms.ConvertImageDtype(torch.float32)
30
- ])
31
-
32
- # Sidebar info
33
- st.sidebar.header("Models Used")
34
- st.sidebar.markdown("""
35
- - 🖼️ **Image Classifier**: `shingguy1/food-calorie-convnext`
36
- - 💬 **Text Generator**: `TinyLlama/TinyLlama-1.1B-Chat-v1.0`
37
- """)
38
-
39
- # Load models
40
- @st.cache_resource
41
- def load_models():
42
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
43
-
44
- # ConvNeXt for classification
45
- model_convnext = ConvNextForImageClassification.from_pretrained(
46
- "shingguy1/food-calorie-convnext",
47
- cache_dir=cache_dir,
48
- token=hf_token
49
- ).to(device)
50
-
51
- # TinyLlama for nutritional facts
52
- tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", cache_dir=cache_dir)
53
- model_llm = AutoModelForCausalLM.from_pretrained(
54
- "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
55
- cache_dir=cache_dir,
56
- torch_dtype=torch.float32,
57
- device_map="auto"
58
- )
59
-
60
- return model_convnext, tokenizer, model_llm, device
61
-
62
- model_convnext, tokenizer, model_llm, device = load_models()
63
-
64
- # Upload image
65
- uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])
66
 
67
  if uploaded_file is not None:
 
68
  try:
69
- image = Image.open(uploaded_file).convert("RGB")
70
- st.image(image, caption="Uploaded Image", use_column_width=True)
71
-
72
- # Predict with ConvNeXt
73
- input_tensor = manual_transform(image).unsqueeze(0).to(device)
74
- with torch.no_grad():
75
- outputs = model_convnext(pixel_values=input_tensor)
76
- pred_idx = outputs.logits.argmax(-1).item()
77
- pred_label = model_convnext.config.id2label[pred_idx]
78
- st.success(f"🍴 Predicted Food: **{pred_label}**")
79
-
80
- # Generate nutrition caption using TinyLlama
81
- prompt = f"Give the calories, macros, and nutritional facts of a {pred_label}."
82
- st.subheader("🧾 Nutrition Information")
83
- st.write(f"🤖 Prompt: `{prompt}`")
84
-
85
- input_ids = tokenizer(prompt, return_tensors="pt").to(model_llm.device)
86
- with torch.no_grad():
87
- output = model_llm.generate(**input_ids, max_new_tokens=100)
88
- caption = tokenizer.decode(output[0], skip_special_tokens=True)
89
-
90
- st.info(caption)
91
-
92
  except Exception as e:
93
- st.error(f"Something went wrong: {e}")
94
-
95
- # Footer
96
- st.markdown("---")
97
- st.markdown("Built with ❤️ using Streamlit and Hugging Face by **shingguy1**")
 
1
  import streamlit as st
2
+ import io
 
3
  from PIL import Image
 
 
 
 
 
 
4
 
5
+ st.set_page_config(page_title="🧪 Upload Debug Test", layout="centered")
6
+ st.title("🧪 Upload Debug Test")
 
 
7
 
8
+ uploaded_file = st.file_uploader("Upload an image...", type=["jpg", "jpeg", "png"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
 
10
  if uploaded_file is not None:
11
+ st.write("Upload received!")
12
  try:
13
+ file_bytes = uploaded_file.read()
14
+ image = Image.open(io.BytesIO(file_bytes)).convert("RGB")
15
+ st.image(image, caption="Image Loaded", use_column_width=True)
16
+ st.success("✅ Image loaded successfully.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  except Exception as e:
18
+ st.error(f" Failed to open image: {e}")