File size: 6,061 Bytes
977a1e2
20a8ad7
9ffd0bd
 
20a8ad7
 
 
216e067
 
 
977a1e2
216e067
 
 
 
cfcd287
9ffd0bd
 
 
 
 
 
977a1e2
9ffd0bd
 
20a8ad7
9ffd0bd
216e067
 
 
 
 
20a8ad7
 
 
 
9ffd0bd
216e067
 
9ffd0bd
20a8ad7
9ffd0bd
 
216e067
20a8ad7
977a1e2
20a8ad7
9ffd0bd
 
 
 
 
 
 
 
 
 
 
 
 
20a8ad7
 
9ffd0bd
20a8ad7
 
 
 
 
 
9ffd0bd
20a8ad7
9ffd0bd
 
 
977a1e2
9ffd0bd
 
 
 
216e067
 
 
cfcd287
 
 
 
 
 
 
 
 
216e067
 
cfcd287
9ffd0bd
216e067
 
 
 
cfcd287
216e067
 
 
9ffd0bd
 
 
 
20a8ad7
9ffd0bd
20a8ad7
9ffd0bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20a8ad7
9ffd0bd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
977a1e2
9ffd0bd
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
import streamlit as st
from transformers import pipeline
from PIL import Image
import torch
import os

# Set cache directory to avoid permission issues
# NOTE(review): these are assigned *after* `transformers` is imported above;
# some transformers versions read TRANSFORMERS_CACHE at import time, so the
# custom cache path may be ignored there — confirm against the pinned version.
os.environ["TRANSFORMERS_CACHE"] = "/app/cache/transformers"
os.environ["HF_HOME"] = "/app/cache/hf"
os.environ["HF_HUB_CACHE"] = "/app/cache/hf"

# Set HF token from environment
# Mirror HF_TOKEN into HUGGINGFACE_HUB_TOKEN, the variable name older
# huggingface_hub clients look for; `hf_token` is also read directly below.
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    os.environ["HUGGINGFACE_HUB_TOKEN"] = hf_token

# Set page config
# Must be the first Streamlit command executed in the script run.
st.set_page_config(
    page_title="Gemma-3n E4B Vision-Language Model",
    page_icon="πŸ€–",
    layout="wide"
)

@st.cache_resource
def load_model():
    """Load and cache the Gemma-3n image-text-to-text pipeline.

    Returns:
        The transformers pipeline on success, or None after reporting an
        error in the UI (missing token, no model access, load failure).
    """
    # Guard clause: without a token the gated model cannot be downloaded.
    if not hf_token:
        st.error("HF_TOKEN not found in environment variables")
        return None

    use_cuda = torch.cuda.is_available()
    try:
        # Pipeline API is the most compatible way to drive this model.
        return pipeline(
            "image-text-to-text",
            model="google/gemma-3n-E4B-it",
            torch_dtype=torch.float16 if use_cuda else torch.float32,
            device_map="auto" if use_cuda else "cpu",
            token=hf_token,  # authenticate against the gated repository
        )
    except Exception as exc:
        st.error(f"Error loading model: {str(exc)}")
        st.error("Make sure you have access to the model and your token is valid.")
        return None

def generate_response(pipe, image, text_prompt, max_tokens=100):
    """Generate a text answer about *image* from the vision-language pipeline.

    Args:
        pipe: An image-text-to-text pipeline (or any callable accepting
            ``(messages, max_new_tokens=...)``).
        image: Image object the pipeline accepts (e.g. a PIL image).
        text_prompt: The user's question about the image.
        max_tokens: Maximum number of new tokens to generate.

    Returns:
        The generated text as a string, or a human-readable error message
        if generation fails (the caller displays whatever is returned).
    """
    try:
        # Chat-style message format expected by image-text-to-text pipelines.
        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": text_prompt},
                ],
            }
        ]

        response = pipe(messages, max_new_tokens=max_tokens)

        # Typical output shape: [{"generated_text": ...}]. For chat inputs,
        # transformers returns "generated_text" as the full message list;
        # in that case extract the final (assistant) message's content
        # instead of returning the raw structure.
        if isinstance(response, list) and response:
            first = response[0]
            if isinstance(first, dict) and "generated_text" in first:
                generated = first["generated_text"]
                if isinstance(generated, list) and generated:
                    last = generated[-1]
                    if isinstance(last, dict):
                        return last.get("content", str(last))
                    return str(last)
                return generated
            if isinstance(first, str):
                return first

        # Unknown shape: fall back to a string rendering rather than crash.
        return str(response)

    except Exception as e:
        # Surface the failure to the UI instead of raising into Streamlit.
        return f"Error generating response: {str(e)}"

def main():
    """Render the Streamlit UI: token checks, model load, image Q&A flow."""
    st.title("πŸ€– Gemma-3n E4B Vision-Language Model")
    st.markdown("Upload an image and ask questions about it!")

    # Check if token is available
    # Without a token the gated model cannot load, so stop early with
    # setup instructions instead of attempting the download.
    if not hf_token:
        st.error("❌ HuggingFace token not found in environment variables.")
        st.markdown("""
        **To fix this:**
        1. Go to your Space settings (βš™οΈ icon)
        2. Navigate to "Repository secrets"
        3. Add a secret with name: `HF_TOKEN`
        4. Value: Your HuggingFace token
        5. Restart the Space
        """)
        return
    else:
        st.success("βœ… HuggingFace token found!")

    # Check if user is authenticated
    # Sidebar checklist reminding the operator of the Space prerequisites.
    st.sidebar.markdown("### πŸ“‹ Setup Status")
    st.sidebar.markdown(f"""
    βœ… **Token**: Found in environment

    Make sure you have:
    1. βœ… Access to the gated model
    2. βœ… Added your HF token to Space secrets
    3. βœ… Token has proper permissions
    """)

    # Load model
    # Cached via @st.cache_resource, so only the first run is slow.
    with st.spinner("Loading model... This may take a few minutes on first run."):
        pipe = load_model()

    if pipe is None:
        st.error("Failed to load model. Please check your setup and try again.")
        return

    st.success("Model loaded successfully!")

    # Create two columns: inputs on the left, model output on the right.
    col1, col2 = st.columns([1, 1])

    with col1:
        st.subheader("πŸ“€ Input")

        # Image upload
        uploaded_file = st.file_uploader(
            "Choose an image...",
            type=['png', 'jpg', 'jpeg', 'gif', 'bmp'],
            help="Upload an image to analyze"
        )

        # Text input
        text_prompt = st.text_area(
            "Ask a question about the image:",
            placeholder="What do you see in this image?",
            height=100
        )

        # Generation parameters
        max_tokens = st.slider(
            "Max tokens to generate:",
            min_value=10,
            max_value=200,
            value=100,
            help="Maximum number of tokens to generate"
        )

        # Generate button
        generate_btn = st.button("πŸš€ Generate Response", type="primary")

    with col2:
        st.subheader("πŸ“€ Output")

        if uploaded_file is not None:
            # Display uploaded image
            image = Image.open(uploaded_file)
            st.image(image, caption="Uploaded image", use_column_width=True)

            # Generate response when button is clicked
            # Require a non-blank prompt before invoking the model.
            if generate_btn:
                if not text_prompt.strip():
                    st.warning("Please enter a question about the image.")
                else:
                    with st.spinner("Generating response..."):
                        response = generate_response(
                            pipe, image, text_prompt, max_tokens
                        )

                    st.subheader("πŸ€– Model Response:")
                    st.write(response)
        else:
            st.info("πŸ‘† Please upload an image to get started")

    # Example section
    st.markdown("---")
    st.subheader("πŸ’‘ Example Questions to Try:")
    st.markdown("""
    - What objects do you see in this image?
    - Describe the scene in detail
    - What colors are present in the image?
    - What is the main subject of this image?
    - Can you identify any text in this image?
    """)

    # Footer
    st.markdown("---")
    st.markdown(
        "Built with ❀️ using [Streamlit](https://streamlit.io) and "
        "[Hugging Face Transformers](https://huggingface.co/transformers/)"
    )

# Run the Streamlit app when executed as a script.
if __name__ == "__main__":
    main()