daniloedu commited on
Commit
20a8ad7
·
verified ·
1 Parent(s): 5996fc9

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +26 -36
src/streamlit_app.py CHANGED
@@ -1,8 +1,12 @@
1
  import streamlit as st
2
- from transformers import AutoProcessor, AutoModelForImageTextToText
3
  from PIL import Image
4
  import torch
5
- import io
 
 
 
 
6
 
7
  # Set page config
8
  st.set_page_config(
@@ -13,21 +17,22 @@ st.set_page_config(
13
 
14
  @st.cache_resource
15
  def load_model():
16
- """Load the model and processor with caching"""
17
  try:
18
- processor = AutoProcessor.from_pretrained("google/gemma-3n-E4B-it")
19
- model = AutoModelForImageTextToText.from_pretrained(
20
- "google/gemma-3n-E4B-it",
 
21
  torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
22
  device_map="auto" if torch.cuda.is_available() else "cpu"
23
  )
24
- return processor, model
25
  except Exception as e:
26
  st.error(f"Error loading model: {str(e)}")
27
  st.error("Make sure you have access to the model and are logged in to HuggingFace.")
28
- return None, None
29
 
30
- def generate_response(processor, model, image, text_prompt, max_tokens=100):
31
  """Generate response from the model"""
32
  try:
33
  # Prepare messages in the expected format
@@ -41,32 +46,17 @@ def generate_response(processor, model, image, text_prompt, max_tokens=100):
41
  }
42
  ]
43
 
44
- # Process inputs
45
- inputs = processor.apply_chat_template(
46
- messages,
47
- add_generation_prompt=True,
48
- tokenize=True,
49
- return_dict=True,
50
- return_tensors="pt",
51
- ).to(model.device)
52
-
53
- # Generate response
54
- with torch.no_grad():
55
- outputs = model.generate(
56
- **inputs,
57
- max_new_tokens=max_tokens,
58
- do_sample=True,
59
- temperature=0.7,
60
- pad_token_id=processor.tokenizer.eos_token_id
61
- )
62
 
63
- # Decode response
64
- response = processor.decode(
65
- outputs[0][inputs["input_ids"].shape[-1]:],
66
- skip_special_tokens=True
67
- )
 
68
 
69
- return response
70
 
71
  except Exception as e:
72
  return f"Error generating response: {str(e)}"
@@ -91,9 +81,9 @@ def main():
91
 
92
  # Load model
93
  with st.spinner("Loading model... This may take a few minutes on first run."):
94
- processor, model = load_model()
95
 
96
- if processor is None or model is None:
97
  st.error("Failed to load model. Please check your setup and try again.")
98
  return
99
 
@@ -146,7 +136,7 @@ def main():
146
  else:
147
  with st.spinner("Generating response..."):
148
  response = generate_response(
149
- processor, model, image, text_prompt, max_tokens
150
  )
151
 
152
  st.subheader("🤖 Model Response:")
 
1
  import streamlit as st
2
+ from transformers import pipeline
3
  from PIL import Image
4
  import torch
5
+ import os
6
+
7
+ # Set cache directory to avoid permission issues
8
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/transformers_cache"
9
+ os.environ["HF_HOME"] = "/tmp/hf_home"
10
 
11
  # Set page config
12
  st.set_page_config(
 
17
 
18
@st.cache_resource
def load_model():
    """Build and cache the Gemma image-text-to-text pipeline.

    Returns the ready-to-use ``transformers`` pipeline, or ``None`` when
    construction fails (the error is surfaced to the Streamlit UI instead
    of raising, so the app can show a friendly message and keep running).
    """
    # Pick precision/placement based on hardware: half precision on GPU,
    # full float32 on CPU where fp16 is poorly supported.
    has_cuda = torch.cuda.is_available()
    dtype = torch.float16 if has_cuda else torch.float32
    placement = "auto" if has_cuda else "cpu"

    try:
        # The high-level pipeline API avoids wiring processor + model by hand.
        return pipeline(
            "image-text-to-text",
            model="google/gemma-3n-E4B-it",
            torch_dtype=dtype,
            device_map=placement,
        )
    except Exception as e:
        # Boundary-level handler: report to the UI rather than crash the app.
        st.error(f"Error loading model: {str(e)}")
        st.error("Make sure you have access to the model and are logged in to HuggingFace.")
        return None
34
 
35
+ def generate_response(pipe, image, text_prompt, max_tokens=100):
36
  """Generate response from the model"""
37
  try:
38
  # Prepare messages in the expected format
 
46
  }
47
  ]
48
 
49
+ # Generate response using pipeline
50
+ response = pipe(messages, max_new_tokens=max_tokens)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51
 
52
+ # Extract text from response
53
+ if isinstance(response, list) and len(response) > 0:
54
+ if isinstance(response[0], dict) and 'generated_text' in response[0]:
55
+ return response[0]['generated_text']
56
+ elif isinstance(response[0], str):
57
+ return response[0]
58
 
59
+ return str(response)
60
 
61
  except Exception as e:
62
  return f"Error generating response: {str(e)}"
 
81
 
82
  # Load model
83
  with st.spinner("Loading model... This may take a few minutes on first run."):
84
+ pipe = load_model()
85
 
86
+ if pipe is None:
87
  st.error("Failed to load model. Please check your setup and try again.")
88
  return
89
 
 
136
  else:
137
  with st.spinner("Generating response..."):
138
  response = generate_response(
139
+ pipe, image, text_prompt, max_tokens
140
  )
141
 
142
  st.subheader("🤖 Model Response:")