| import streamlit as st |
| import replicate |
| import os |
| from transformers import AutoTokenizer |
|
|
| |
|
|
| |
# Sidebar: app title, Replicate credentials, and sampling controls.
with st.sidebar:
    st.title('Satellite VQA')

    # Copy the Replicate token from Streamlit secrets into the environment
    # (presumably where the `replicate` client looks for it -- confirm).
    # `api_token_available` is read later to enable/disable the chat input.
    api_token_available = False
    try:
        os.environ['REPLICATE_API_TOKEN'] = st.secrets["REPLICATE_API_TOKEN"]
    except (KeyError, FileNotFoundError):
        # KeyError: secrets exist but this key is missing.
        # FileNotFoundError: no secrets.toml was found at all; without
        # catching it the app would crash instead of showing the hint below.
        st.error("Replicate API token not found in secrets. Please add it to continue.")
        st.info("You can set this up in your Streamlit secrets.toml file.")
    else:
        api_token_available = True

    st.subheader("Models and parameters")
    # Model identifier handed to `replicate.stream` when generating replies.
    model = "anthropic/claude-3.5-sonnet"

    temperature = st.sidebar.slider(
        'temperature',
        min_value=0.01,
        max_value=5.0,
        value=0.7,
        step=0.01,
        help="Randomness of generated output",
    )
    # Nudge the user when the chosen temperature sits at either extreme.
    if temperature >= 1:
        st.warning('Values exceeding 1 produces more creative and random output as well as increased likelihood of hallucination.')
    if temperature < 0.1:
        st.warning('Values approaching 0 produces deterministic output. Recommended starting value is 0.7')

    top_p = st.sidebar.slider(
        'top_p',
        min_value=0.01,
        max_value=1.0,
        value=0.9,
        step=0.01,
        help="Top p percentage of most likely tokens for output generation",
    )
|
|
| |
# Let the user choose a satellite image; preview it once one is provided.
# `image` is also passed to the model request later in the script.
image = st.file_uploader(
    "Upload a satellite image",
    type=["png", "jpg", "jpeg"],
)
if image:
    st.image(
        image,
        caption="Satellite Image",
        use_column_width=True,
    )
|
|
| |
# Seed the conversation with the assistant's greeting when the session has
# no stored messages yet.
if "messages" not in st.session_state:
    st.session_state.messages = [
        {"role": "assistant", "content": "Ask me a question about this satellite image."},
    ]

# Render every stored message in its role's chat bubble.
for entry in st.session_state.messages:
    with st.chat_message(entry["role"]):
        st.write(entry["content"])
|
|
def clear_chat_history():
    """Reset the stored conversation to only the assistant's greeting."""
    greeting = {
        "role": "assistant",
        "content": "Ask me a question about this satellite image.",
    }
    st.session_state.messages = [greeting]


# Sidebar button that resets the conversation via the callback above.
st.sidebar.button('Clear chat history', on_click=clear_chat_history)
|
|
@st.cache_resource(show_spinner=False)
def get_tokenizer():
    """Load the tokenizer used to measure prompt length before sending.

    The result is wrapped in ``st.cache_resource`` with the spinner
    disabled. The original author's note says this tokenizer is meant to
    be replaced by ArcticTokenizer eventually.
    """
    checkpoint = "huggyllama/llama-7b"
    return AutoTokenizer.from_pretrained(checkpoint)
|
|
def get_num_tokens(prompt):
    """Return how many tokens ``prompt`` splits into under ``get_tokenizer()``."""
    return len(get_tokenizer().tokenize(prompt))
|
|
| |
def generate_response():
    """Stream the model's reply to the conversation in session state.

    Every stored message is wrapped in ``<|im_start|>role ... <|im_end|>``
    markers (any role other than "user" is tagged as assistant), followed
    by an open assistant turn. If the prompt reaches 3072 tokens, an
    error and a "Clear chat history" button are shown and the script run
    is stopped.

    Yields:
        str: Each streamed event converted with ``str``. If the request
        raises, a single "Error generating response: ..." string is
        yielded instead of propagating the exception.
    """
    turns = []
    for past in st.session_state.messages:
        speaker = "user" if past["role"] == "user" else "assistant"
        turns.append("<|im_start|>" + speaker + "\n" + past["content"] + "<|im_end|>")

    # Open the assistant turn; the trailing empty item makes the joined
    # prompt end with a newline.
    turns += ["<|im_start|>assistant", ""]
    prompt_str = "\n".join(turns)

    if get_num_tokens(prompt_str) >= 3072:
        st.error("Conversation length too long. Please keep it under 3072 tokens.")
        st.button('Clear chat history', on_click=clear_chat_history, key="clear_chat_history")
        st.stop()

    try:
        # Built inside the try so any failure while assembling the request
        # is reported through the same error path as the stream itself.
        request = {
            "image": image,
            "prompt": prompt_str,
            "prompt_template": r"{prompt}",
            "temperature": temperature,
            "top_p": top_p,
        }
        for event in replicate.stream(model, input=request):
            yield str(event)
    except Exception as exc:
        # Surface the failure as chat text rather than crashing the app.
        yield f"Error generating response: {exc!s}"
|
|
| |
# Record and echo a newly submitted question. The input box is disabled
# while no Replicate API token is available.
if prompt := st.chat_input(disabled=not api_token_available):
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.chat_message("user"):
        st.write(prompt)

# NOTE(review): this file's indentation was lost, so the original nesting
# of the trailing `else` is a reconstruction -- confirm against the
# intended layout. The upload hint is tied to the image being absent
# (not to who spoke last), and no model request is made without an image.
if not image:
    st.info("๐ Please upload a satellite image to start the conversation.")
elif st.session_state.messages[-1]["role"] != "assistant":
    # The last message is from the user: stream a reply and store it.
    with st.chat_message("assistant"):
        response = generate_response()
        full_response = st.write_stream(response)
    message = {"role": "assistant", "content": full_response}
    st.session_state.messages.append(message)