shingguy1 commited on
Commit
cdfccf9
·
verified ·
1 Parent(s): 898542f

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +26 -23
src/streamlit_app.py CHANGED
@@ -4,7 +4,7 @@ import os
4
  from PIL import Image
5
  import torchvision.transforms as transforms
6
  from transformers import (
7
- ConvNextForImageClassification,
8
  AutoTokenizer,
9
  AutoModelForCausalLM
10
  )
@@ -24,10 +24,10 @@ cache_dir = "/tmp/cache"
24
  os.makedirs(cache_dir, exist_ok=True)
25
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
26
 
27
- # Transform for ConvNeXt
28
  manual_transform = transforms.Compose([
29
- transforms.Resize(224),
30
- transforms.CenterCrop(196),
31
  transforms.ToTensor(),
32
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
33
  std=[0.229, 0.224, 0.225]),
@@ -37,7 +37,7 @@ manual_transform = transforms.Compose([
37
  # Sidebar info
38
  st.sidebar.header("Models Used")
39
  st.sidebar.markdown("""
40
- - 🖼️ **Image Classifier**: `shingguy1/fine_tuned_convnext`
41
  - 💬 **Text Generator**: `TinyLlama/TinyLlama-1.1B-Chat-v1.0`
42
  """)
43
 
@@ -46,9 +46,9 @@ st.sidebar.markdown("""
46
  def load_models():
47
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48
 
49
- # ConvNeXt for classification
50
- model_convnext = ConvNextForImageClassification.from_pretrained(
51
- "shingguy1/fine_tuned_convnext",
52
  cache_dir=cache_dir,
53
  token=hf_token
54
  ).to(device)
@@ -65,9 +65,9 @@ def load_models():
65
  device_map="auto"
66
  )
67
 
68
- return model_convnext, tokenizer, model_llm, device
69
 
70
- model_convnext, tokenizer, model_llm, device = load_models()
71
 
72
  # Image uploader
73
  uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])
@@ -81,12 +81,12 @@ if uploaded_file is not None:
81
  # Predict food label
82
  input_tensor = manual_transform(image).unsqueeze(0).to(device)
83
  with torch.no_grad():
84
- outputs = model_convnext(pixel_values=input_tensor)
85
  pred_idx = outputs.logits.argmax(-1).item()
86
- pred_label = model_convnext.config.id2label[pred_idx]
87
  st.success(f"🍴 Predicted Food: **{pred_label}**")
88
 
89
- # Generate nutrition description with LLM
90
  prompt = (
91
  f"Please provide a concise nutritional overview for a {pred_label}. "
92
  "Include typical serving size, approximate calories, macronutrient breakdown "
@@ -95,16 +95,19 @@ if uploaded_file is not None:
95
  st.subheader("🧾 Nutrition Information")
96
  st.write(f"🤖 Prompt to LLM:\n\n{prompt}")
97
 
98
- input_ids = tokenizer(prompt, return_tensors="pt").to(model_llm.device)
99
- with torch.no_grad():
100
- output = model_llm.generate(
101
- **input_ids,
102
- max_new_tokens=300,
103
- temperature=0.8,
104
- top_p=0.9,
105
- do_sample=True
106
- )
107
- caption = tokenizer.decode(output[0], skip_special_tokens=True).strip()
 
 
 
108
  st.info(caption)
109
 
110
  except Exception as e:
 
4
  from PIL import Image
5
  import torchvision.transforms as transforms
6
  from transformers import (
7
+ ViTForImageClassification,
8
  AutoTokenizer,
9
  AutoModelForCausalLM
10
  )
 
24
  os.makedirs(cache_dir, exist_ok=True)
25
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
26
 
27
+ # Transform for ViT
28
  manual_transform = transforms.Compose([
29
+ transforms.Resize(256),
30
+ transforms.CenterCrop(224),
31
  transforms.ToTensor(),
32
  transforms.Normalize(mean=[0.485, 0.456, 0.406],
33
  std=[0.229, 0.224, 0.225]),
 
37
  # Sidebar info
38
  st.sidebar.header("Models Used")
39
  st.sidebar.markdown("""
40
+ - 🖼️ **Image Classifier**: `shingguy1/fine_tuned_model`
41
  - 💬 **Text Generator**: `TinyLlama/TinyLlama-1.1B-Chat-v1.0`
42
  """)
43
 
 
46
  def load_models():
47
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
48
 
49
+ # ViT for classification
50
+ model_vit = ViTForImageClassification.from_pretrained(
51
+ "shingguy1/fine_tuned_model",
52
  cache_dir=cache_dir,
53
  token=hf_token
54
  ).to(device)
 
65
  device_map="auto"
66
  )
67
 
68
+ return model_vit, tokenizer, model_llm, device
69
 
70
+ model_vit, tokenizer, model_llm, device = load_models()
71
 
72
  # Image uploader
73
  uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])
 
81
  # Predict food label
82
  input_tensor = manual_transform(image).unsqueeze(0).to(device)
83
  with torch.no_grad():
84
+ outputs = model_vit(pixel_values=input_tensor)
85
  pred_idx = outputs.logits.argmax(-1).item()
86
+ pred_label = model_vit.config.id2label[pred_idx]
87
  st.success(f"🍴 Predicted Food: **{pred_label}**")
88
 
89
+ # Generate nutrition description with LLM (no echo)
90
  prompt = (
91
  f"Please provide a concise nutritional overview for a {pred_label}. "
92
  "Include typical serving size, approximate calories, macronutrient breakdown "
 
95
  st.subheader("🧾 Nutrition Information")
96
  st.write(f"🤖 Prompt to LLM:\n\n{prompt}")
97
 
98
+ inputs = tokenizer(prompt, return_tensors="pt").to(model_llm.device)
99
+ input_len = inputs.input_ids.shape[1]
100
+ output_ids = model_llm.generate(
101
+ **inputs,
102
+ max_new_tokens=200,
103
+ temperature=0.8,
104
+ top_p=0.9,
105
+ do_sample=True,
106
+ eos_token_id=tokenizer.eos_token_id,
107
+ pad_token_id=tokenizer.eos_token_id
108
+ )[0]
109
+ generated_ids = output_ids[input_len:]
110
+ caption = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
111
  st.info(caption)
112
 
113
  except Exception as e: