shingguy1 committed on
Commit
33a13f2
·
verified ·
1 Parent(s): bbac15f

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +37 -52
src/streamlit_app.py CHANGED
@@ -5,7 +5,6 @@ from PIL import Image
5
  import torchvision.transforms as transforms
6
  from transformers import (
7
  ConvNextForImageClassification,
8
- ConvNextImageProcessor,
9
  AutoTokenizer,
10
  AutoModelForCausalLM
11
  )
@@ -17,14 +16,18 @@ st.markdown("Upload a food image and get nutritional information generated by AI
17
 
18
  # Environment & cache setup
19
  hf_token = os.getenv("HF_TOKEN")
20
- if not hf_token:
21
- st.warning("HF_TOKEN not set. Please set the environment variable HF_TOKEN to access private models.")
22
  cache_dir = "/tmp/cache"
23
  os.makedirs(cache_dir, exist_ok=True)
24
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
25
 
26
- # Use ConvNeXt's official image processor for a compatible model
27
- image_processor = ConvNextImageProcessor.from_pretrained("facebook/convnext-base-224-22k-1k")
 
 
 
 
 
 
28
 
29
  # Sidebar info
30
  st.sidebar.header("Models Used")
@@ -37,36 +40,26 @@ st.sidebar.markdown("""
37
  @st.cache_resource
38
  def load_models():
39
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
- st.info(f"Using device: {device}")
41
 
42
- try:
43
- # ConvNeXt for classification
44
- model_convnext = ConvNextForImageClassification.from_pretrained(
45
- "shingguy1/food-calorie-convnext",
46
- cache_dir=cache_dir,
47
- token=hf_token
48
- ).to(device)
49
- except Exception as e:
50
- st.error(f"Failed to load ConvNeXt model: {e}")
51
- st.stop()
52
-
53
- try:
54
- # TinyLlama for nutritional facts
55
- tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0", cache_dir=cache_dir)
56
- model_llm = AutoModelForCausalLM.from_pretrained(
57
- "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
58
- cache_dir=cache_dir,
59
- torch_dtype=torch.float16 if device.type == "cuda" else torch.float32,
60
- device_map="auto"
61
- )
62
- except Exception as e:
63
- st.error(f"Failed to load TinyLlama model: {e}")
64
- st.stop()
65
 
66
  return model_convnext, tokenizer, model_llm, device
67
 
68
- with st.spinner("Loading models..."):
69
- model_convnext, tokenizer, model_llm, device = load_models()
70
 
71
  # Upload image
72
  uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])
@@ -77,32 +70,24 @@ if uploaded_file is not None:
77
  st.image(image, caption="Uploaded Image", use_column_width=True)
78
 
79
  # Predict with ConvNeXt
80
- with st.spinner("Classifying food..."):
81
- inputs = image_processor(image, return_tensors="pt").to(device)
82
- with torch.no_grad():
83
- outputs = model_convnext(**inputs)
84
- pred_idx = outputs.logits.argmax(-1).item()
85
- pred_label = model_convnext.config.id2label[pred_idx]
86
- st.success(f"🍴 Predicted Food: **{pred_label}**")
87
 
88
  # Generate nutrition caption using TinyLlama
89
- prompt = f"Provide the calories, protein, fat, and carbs for a typical serving of {pred_label}. Format the response as: 'Calories: X kcal, Protein: Y g, Fat: Z g, Carbs: W g'."
90
  st.subheader("🧾 Nutrition Information")
91
  st.write(f"🤖 Prompt: `{prompt}`")
92
 
93
- with st.spinner("Generating nutritional facts..."):
94
- input_ids = tokenizer(prompt, return_tensors="pt").to(device)
95
- with torch.no_grad():
96
- output = model_llm.generate(
97
- **input_ids,
98
- max_new_tokens=150,
99
- temperature=0.7,
100
- top_p=0.9
101
- )
102
- caption = tokenizer.decode(output[0], skip_special_tokens=True)
103
- caption = caption.replace(prompt, "").strip() # Remove prompt if echoed
104
-
105
- st.info(caption if caption else "No nutritional information generated.")
106
 
107
  except Exception as e:
108
  st.error(f"Something went wrong: {e}")
 
5
  import torchvision.transforms as transforms
6
  from transformers import (
7
  ConvNextForImageClassification,
 
8
  AutoTokenizer,
9
  AutoModelForCausalLM
10
  )
 
# Environment & cache setup.
# HF_TOKEN is required to pull the private classifier checkpoint below;
# the hub cache is redirected to /tmp so the app works on read-only hosts.
hf_token = os.getenv("HF_TOKEN")
cache_dir = "/tmp/cache"
os.makedirs(cache_dir, exist_ok=True)
os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir

# Manual preprocessing pipeline for ConvNeXt (replaces the hub image processor).
# ToTensor already produces a float32 tensor in [0, 1], so the former trailing
# ConvertImageDtype(torch.float32) was a no-op and has been removed.
manual_transform = transforms.Compose([
    transforms.Resize(224),
    # NOTE(review): 196 differs from the usual 224x224 ConvNeXt input —
    # confirm this matches the crop size used when fine-tuning the model.
    transforms.CenterCrop(196),
    transforms.ToTensor(),
    # Standard ImageNet normalization statistics.
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
31
 
32
  # Sidebar info
33
  st.sidebar.header("Models Used")
 
@st.cache_resource
def load_models():
    """Load and cache the ConvNeXt classifier and the TinyLlama generator.

    Cached via st.cache_resource so the (slow) downloads happen once per
    process, not once per Streamlit rerun.

    Returns:
        tuple: (model_convnext, tokenizer, model_llm, device) where device
        is the torch.device used for the ConvNeXt classifier.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    try:
        # ConvNeXt fine-tuned for food classification; private repo, so the
        # HF token is required. Fail loudly with a readable message instead
        # of surfacing a raw traceback to the end user.
        model_convnext = ConvNextForImageClassification.from_pretrained(
            "shingguy1/food-calorie-convnext",
            cache_dir=cache_dir,
            token=hf_token,
        ).to(device)
    except Exception as e:
        st.error(f"Failed to load ConvNeXt model: {e}")
        st.stop()

    try:
        # TinyLlama chat model used to generate the nutrition text.
        tokenizer = AutoTokenizer.from_pretrained(
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            cache_dir=cache_dir,
        )
        model_llm = AutoModelForCausalLM.from_pretrained(
            "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
            cache_dir=cache_dir,
            torch_dtype=torch.float32,
            # Let accelerate decide placement; downstream code must use
            # model_llm.device (not `device`) when moving generation inputs.
            device_map="auto",
        )
    except Exception as e:
        st.error(f"Failed to load TinyLlama model: {e}")
        st.stop()

    return model_convnext, tokenizer, model_llm, device

model_convnext, tokenizer, model_llm, device = load_models()
 
63
 
64
  # Upload image
65
  uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])
 
70
  st.image(image, caption="Uploaded Image", use_column_width=True)
71
 
72
  # Predict with ConvNeXt
73
+ input_tensor = manual_transform(image).unsqueeze(0).to(device)
74
+ with torch.no_grad():
75
+ outputs = model_convnext(pixel_values=input_tensor)
76
+ pred_idx = outputs.logits.argmax(-1).item()
77
+ pred_label = model_convnext.config.id2label[pred_idx]
78
+ st.success(f"🍴 Predicted Food: **{pred_label}**")
 
79
 
80
  # Generate nutrition caption using TinyLlama
81
+ prompt = f"Give the calories, macros, and nutritional facts of a {pred_label}."
82
  st.subheader("🧾 Nutrition Information")
83
  st.write(f"🤖 Prompt: `{prompt}`")
84
 
85
+ input_ids = tokenizer(prompt, return_tensors="pt").to(model_llm.device)
86
+ with torch.no_grad():
87
+ output = model_llm.generate(**input_ids, max_new_tokens=100)
88
+ caption = tokenizer.decode(output[0], skip_special_tokens=True)
89
+
90
+ st.info(caption)
 
 
 
 
 
 
 
91
 
92
  except Exception as e:
93
  st.error(f"Something went wrong: {e}")