shingguy1 commited on
Commit
aaf3765
·
verified ·
1 Parent(s): 749ea77

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +26 -27
src/streamlit_app.py CHANGED
@@ -16,13 +16,13 @@ from transformers import (
16
  )
17
 
18
  def main():
19
- # 2. Environment & cache
20
  hf_token = os.getenv("HF_TOKEN", None)
21
  cache_dir = "/tmp/cache"
22
  os.makedirs(cache_dir, exist_ok=True)
23
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
24
 
25
- # 3. Image transform for ViT
26
  manual_transform = transforms.Compose([
27
  transforms.Resize(256),
28
  transforms.CenterCrop(224),
@@ -33,26 +33,26 @@ def main():
33
  transforms.ConvertImageDtype(torch.float32)
34
  ])
35
 
36
- # 4. Sidebar info
37
  st.sidebar.header("Models Used")
38
  st.sidebar.markdown("""
39
  - 🖼️ **Image Classifier**: `shingguy1/fine_tuned_vit`
40
  - 💬 **Text Generator**: `tiiuae/falcon-7b-instruct`
41
  """)
42
 
43
- # 5. Load models
44
  @st.cache_resource
45
  def load_models():
46
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47
 
48
- # ViT classifier
49
  model_vit = ViTForImageClassification.from_pretrained(
50
  "shingguy1/fine_tuned_vit",
51
  cache_dir=cache_dir,
52
  use_auth_token=hf_token
53
  ).to(device)
54
 
55
- # Falcon-7B Instruct LLM
56
  tokenizer_llm = AutoTokenizer.from_pretrained(
57
  "tiiuae/falcon-7b-instruct",
58
  cache_dir=cache_dir,
@@ -61,33 +61,32 @@ def main():
61
  model_llm = AutoModelForCausalLM.from_pretrained(
62
  "tiiuae/falcon-7b-instruct",
63
  cache_dir=cache_dir,
64
- use_auth_token=hf_token,
 
65
  torch_dtype=torch.float16,
66
- device_map="auto"
67
  )
68
 
69
  return model_vit, tokenizer_llm, model_llm, device
70
 
71
  model_vit, tokenizer_llm, model_llm, device = load_models()
72
 
73
- # 6. Image uploader
74
  uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])
75
-
76
  if uploaded_file is not None:
77
  try:
78
- # Display image
79
  image = Image.open(uploaded_file)
80
  st.image(image, caption="Uploaded Image", use_column_width=True)
81
 
82
- # Classify with ViT
83
- input_tensor = manual_transform(image).unsqueeze(0).to(device)
84
  with torch.no_grad():
85
- outputs = model_vit(pixel_values=input_tensor)
86
- pred_idx = outputs.logits.argmax(-1).item()
87
- pred_label = model_vit.config.id2label[pred_idx]
88
  st.success(f"🍴 Predicted Food: **{pred_label}**")
89
 
90
- # Build prompt
91
  prompt = (
92
  "### Instruction\n"
93
  f"Provide a concise nutritional overview for a {pred_label}, including:\n"
@@ -104,12 +103,12 @@ def main():
104
 
105
  # Tokenize & generate
106
  inputs = tokenizer_llm(prompt, return_tensors="pt")
107
- inputs = {k: v.to(device) for k, v in inputs.items()}
108
- input_len = inputs["input_ids"].shape[1]
109
 
110
- outputs = model_llm.generate(
111
  **inputs,
112
- max_length=input_len + 150,
113
  temperature=0.7,
114
  top_p=0.9,
115
  do_sample=True,
@@ -117,14 +116,14 @@ def main():
117
  early_stopping=True,
118
  pad_token_id=tokenizer_llm.eos_token_id,
119
  eos_token_id=tokenizer_llm.eos_token_id
120
- )
121
 
122
- # Decode and strip prompt
123
- full = tokenizer_llm.decode(outputs[0], skip_special_tokens=True).strip()
124
- if "### Response" in full:
125
- caption = full.split("### Response", 1)[1].strip()
126
  else:
127
- caption = full[input_len:].strip()
128
 
129
  if caption:
130
  st.info(caption)
 
16
  )
17
 
18
  def main():
19
+ # Environment & cache
20
  hf_token = os.getenv("HF_TOKEN", None)
21
  cache_dir = "/tmp/cache"
22
  os.makedirs(cache_dir, exist_ok=True)
23
  os.environ["HUGGINGFACE_HUB_CACHE"] = cache_dir
24
 
25
+ # Image transform for ViT
26
  manual_transform = transforms.Compose([
27
  transforms.Resize(256),
28
  transforms.CenterCrop(224),
 
33
  transforms.ConvertImageDtype(torch.float32)
34
  ])
35
 
36
+ # Sidebar info
37
  st.sidebar.header("Models Used")
38
  st.sidebar.markdown("""
39
  - 🖼️ **Image Classifier**: `shingguy1/fine_tuned_vit`
40
  - 💬 **Text Generator**: `tiiuae/falcon-7b-instruct`
41
  """)
42
 
43
+ # Load models (cached)
44
  @st.cache_resource
45
  def load_models():
46
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47
 
48
+ # ViT classifier → GPU/CPU
49
  model_vit = ViTForImageClassification.from_pretrained(
50
  "shingguy1/fine_tuned_vit",
51
  cache_dir=cache_dir,
52
  use_auth_token=hf_token
53
  ).to(device)
54
 
55
+ # Falcon-7B Instruct → 8-bit quant on GPU
56
  tokenizer_llm = AutoTokenizer.from_pretrained(
57
  "tiiuae/falcon-7b-instruct",
58
  cache_dir=cache_dir,
 
61
  model_llm = AutoModelForCausalLM.from_pretrained(
62
  "tiiuae/falcon-7b-instruct",
63
  cache_dir=cache_dir,
64
+ load_in_8bit=True,
65
+ device_map="auto",
66
  torch_dtype=torch.float16,
67
+ use_auth_token=hf_token
68
  )
69
 
70
  return model_vit, tokenizer_llm, model_llm, device
71
 
72
  model_vit, tokenizer_llm, model_llm, device = load_models()
73
 
74
+ # Image uploader
75
  uploaded_file = st.file_uploader("Upload a food image...", type=["jpg", "jpeg", "png"])
 
76
  if uploaded_file is not None:
77
  try:
 
78
  image = Image.open(uploaded_file)
79
  st.image(image, caption="Uploaded Image", use_column_width=True)
80
 
81
+ # Classify
82
+ inputs_v = manual_transform(image).unsqueeze(0).to(device)
83
  with torch.no_grad():
84
+ out = model_vit(pixel_values=inputs_v)
85
+ idx = out.logits.argmax(-1).item()
86
+ pred_label = model_vit.config.id2label[idx]
87
  st.success(f"🍴 Predicted Food: **{pred_label}**")
88
 
89
+ # Unified instruction prompt
90
  prompt = (
91
  "### Instruction\n"
92
  f"Provide a concise nutritional overview for a {pred_label}, including:\n"
 
103
 
104
  # Tokenize & generate
105
  inputs = tokenizer_llm(prompt, return_tensors="pt")
106
+ inputs = {k: v.to(model_llm.device) for k, v in inputs.items()}
107
+ inp_len = inputs["input_ids"].shape[1]
108
 
109
+ out_ids = model_llm.generate(
110
  **inputs,
111
+ max_length=inp_len + 150,
112
  temperature=0.7,
113
  top_p=0.9,
114
  do_sample=True,
 
116
  early_stopping=True,
117
  pad_token_id=tokenizer_llm.eos_token_id,
118
  eos_token_id=tokenizer_llm.eos_token_id
119
+ )[0]
120
 
121
+ # Decode & strip prompt
122
+ decoded = tokenizer_llm.decode(out_ids, skip_special_tokens=True).strip()
123
+ if "### Response" in decoded:
124
+ caption = decoded.split("### Response", 1)[1].strip()
125
  else:
126
+ caption = decoded[inp_len:].strip()
127
 
128
  if caption:
129
  st.info(caption)