# artspeak / app.py
# coztomate's picture
# update app.py
# e709466 verified
#import libraries
import streamlit as st
from PIL import Image
import io
from openai import OpenAI
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from text_generation import Client
from huggingface_hub import InferenceClient
import config_llm
# Initialize session state variables
# Default value for every session-state slot the app reads later on.
_SESSION_DEFAULTS = {
    'user_input': "",
    'simplified_text': '',
    'new_caption': None,
    'model_clip': None,
    'transform_clip': None,
    'openai_api_key': '',
    'huggingface_key': '',
    'message_content_from_caption': '',
    'message_content_from_simplified_text': '',
    'mixtral_from_caption': '',
    'mixtral_from_simplified': '',
    'image_from_caption': None,
    'image_from_simplified_text': None,
    'image_from_press_text': None,
    'image_from_press_text_from_caption': None,
}

# Only fill in slots that are missing; values already present are kept.
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default
# Load the tokenizer and simplifier model.
# Streamlit re-executes this script on every widget interaction; wrapping the
# load in st.cache_resource builds the two objects once and reuses them on
# later reruns instead of re-instantiating the model each time.
SIMPLIFIER_CHECKPOINT = "mrm8488/t5-small-finetuned-text-simplification"


@st.cache_resource
def _load_simplifier(checkpoint=SIMPLIFIER_CHECKPOINT):
    """Return the ``(tokenizer, model)`` pair loaded from *checkpoint*.

    Args:
        checkpoint: Hugging Face model id passed to ``from_pretrained``.

    Returns:
        A 2-tuple of the ``AutoTokenizer`` and ``AutoModelForSeq2SeqLM``
        instances for that checkpoint.
    """
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )


# Module-level names used by simplify_text().
tokenizer, model = _load_simplifier()
# Function to simplify text
def simplify_text(input_text):
    """Return a simplified version of *input_text*, cut at the last full sentence.

    The text is prefixed with ``"simplify: "``, encoded with the module-level
    ``tokenizer`` and passed to the module-level ``model`` with sampling
    enabled (``min_length=5``, ``max_length=80``).  The decoded output is then
    trimmed so that it ends on its last ``.``, ``?`` or ``!``; if it contains
    none of these, it is returned untrimmed.

    Args:
        input_text: The text to simplify.

    Returns:
        The decoded (and possibly trimmed) model output as a string.
    """
    prompt_ids = tokenizer.encode("simplify: " + input_text, return_tensors="pt")
    generated = model.generate(prompt_ids, min_length=5, max_length=80, do_sample=True)
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)

    # Index of the right-most sentence terminator, or -1 when there is none.
    cut = max(decoded.rfind(mark) for mark in ".?!")
    if cut == -1:
        return decoded
    return decoded[:cut + 1]
# Define the path to example text
# Relative path of the bundled example text.
example_text_path = "example_text.txt"


def load_example_text():
    """Return the full contents of ``example_text_path``, decoded as UTF-8."""
    with open(example_text_path, mode="r", encoding="utf-8") as text_file:
        contents = text_file.read()
    return contents
# Define the path to your example image
# Relative path of the bundled example image.
example_image_path = "example.jpg"


def load_image(image_path):
    """Open the file at *image_path* and return it as a loaded PIL image.

    ``load()`` is called while the file is still open so the image data is
    read into memory before the file handle is closed.

    Args:
        image_path: Path of the image file to read.

    Returns:
        The ``PIL.Image.Image`` produced by ``Image.open``.
    """
    with open(image_path, "rb") as image_file:
        picture = Image.open(image_file)
        picture.load()
    return picture
#get huggingface key
# Read the Hugging Face token from the Streamlit secrets, keep a copy in
# session state, and build the shared inference client with it.
_hf_token = st.secrets["hf_key"]
st.session_state['huggingface_key'] = _hf_token
client = InferenceClient(token=_hf_token)
########################################################################
# Create a Streamlit app
# Page title followed by the first section: providing the text and the image.
st.title("ARTSPEAK > s i m p l i f i e r")
st.markdown("---")

# Create a sub-section for uploading the files
with st.expander("Upload Files"):
    st.markdown("## Upload Text and Image")

    ##### Text input
    st.write("Paste your text here or upload example:")
    # Clicking the button stores the example text in session state ...
    if st.button('Load Example Text'):
        st.session_state['user_input'] = load_example_text()
    # ... and the text area takes its initial value from that same slot.
    user_input = st.text_area("Enter text here", value=st.session_state['user_input'])
    st.markdown("---")

    ##### Image input
    # Store the example image in session state when requested.
    if st.button("Load Example Image"):
        st.session_state['example_image'] = load_image(example_image_path)
    # Render it on every run while it is stored, not only on the run in which
    # the button was clicked: the stored image is still used for captioning
    # when nothing is uploaded, so the user should keep seeing it.
    # 'example_image' has no default in session state, hence the `in` check.
    if 'example_image' in st.session_state:
        st.image(st.session_state['example_image'], caption="Example Image")

    # Displaying the file uploader
    uploaded_image = st.file_uploader("Upload an image (jpg or png)", type=["jpg", "png"])
st.markdown("---")
#### Simplifier and Image Caption
with st.expander("Simplify Text and Image"):
    st.markdown("## 'Simplify' Text and Image")

    ## Text simplifier
    if st.button("Simplify the Input Text"):
        if user_input:
            st.session_state['simplified_text'] = simplify_text(user_input)
        else:
            # Name the button that was actually pressed (there is no 'Save' button).
            st.warning("Please enter text in the input field before clicking 'Simplify the Input Text'")
    # Shown from session state so the result stays visible on later reruns.
    if st.session_state['simplified_text']:
        st.write(st.session_state['simplified_text'])

    ## Get new caption
    if st.button("Get New Caption for Image"):
        image_data = None
        # An uploaded image takes precedence over the example image.
        if uploaded_image is not None:
            image_data = uploaded_image.getvalue()
        elif 'example_image' in st.session_state:
            # Serialise the stored PIL image to PNG bytes for the API call.
            buffer = io.BytesIO()
            st.session_state['example_image'].save(buffer, format="PNG")
            image_data = buffer.getvalue()

        if image_data is not None:
            try:
                caption = client.image_to_text(image_data)
                # NOTE(review): depending on the installed huggingface_hub
                # version the result is either a plain string or an object
                # exposing `generated_text`; store plain text in both cases so
                # the later prompts are built from the caption itself.
                st.session_state['new_caption'] = getattr(caption, "generated_text", caption)
            except Exception as e:
                st.error(f"An error occurred: {e}")
        else:
            st.warning("Please upload an image or load the example image before clicking 'Get New Caption for Image'")
    # Shown from session state (like the simplified text above) so the caption
    # does not vanish when another widget triggers a rerun.
    if st.session_state['new_caption']:
        st.write(st.session_state['new_caption'])
st.markdown("---")
########################################################################
with st.expander("Press Text Generation"):
    st.markdown("## Generate New Presstext for an Exhibition")

    # The selected model decides which of the two branches below is rendered.
    option = st.radio(
        "Choose a Language Model:",
        ('Mixtral 8x7B', 'GPT-3.5 Turbo'))

    if option == 'Mixtral 8x7B':
        st.header("Mixtral 8x7B")

        # text_generation client for the endpoint configured in config_llm,
        # authorised with the Hugging Face token held in session state.
        headers = {"Authorization": f"Bearer {st.session_state['huggingface_key']}"}
        client_mixtral = Client(
            config_llm.API_URL,
            headers=headers,
        )

        def run_single_input(
            message: str,
            system_prompt: str = config_llm.DEFAULT_SYSTEM_PROMPT,
            max_new_tokens: int = config_llm.MAX_NEW_TOKENS,
            temperature: float = config_llm.TEMPERATURE,
            top_p: float = config_llm.TOP_P
        ) -> str:
            """
            Generate one completion for *message* and return it as a string.

            The prompt is the system prompt, a blank line and the stripped
            message prefixed with "User: ".  Tokens are read from the stream
            until one contains config_llm.EOS_STRING or config_llm.EOT_STRING;
            that token is dropped and the stripped text collected so far is
            returned.
            """
            prompt = f"{system_prompt}\n\nUser: {message.strip()}\n"
            token_stream = client_mixtral.generate_stream(
                prompt,
                max_new_tokens=max_new_tokens,
                do_sample=True,
                top_p=top_p,
                temperature=temperature,
            )
            pieces = []
            for chunk in token_stream:
                end_markers = (config_llm.EOS_STRING, config_llm.EOT_STRING)
                piece = chunk.token.text
                if any(marker in piece for marker in end_markers):
                    break  # Stop at the first end token
                pieces.append(piece)
            return "".join(pieces).strip()

        # Press text from the image caption
        if st.button("Generate Press Text from New Image Caption with Mixtral"):
            caption_text = st.session_state['new_caption']
            if not caption_text:
                st.warning("Please ensure a caption is generated.")
            else:
                try:
                    st.session_state['mixtral_from_caption'] = run_single_input(
                        caption_text, config_llm.DEFAULT_SYSTEM_PROMPT)
                except Exception as e:
                    st.error(f"An error occurred: {e}")
        mixtral_caption_result = st.session_state['mixtral_from_caption']
        if mixtral_caption_result:
            st.write("Generated Press Text from New Caption of Artwork:")
            st.write(mixtral_caption_result)

        # Press text from the simplified text
        if st.button("Generate Press Text from Simplified Text with Mixtral"):
            simplified_source = st.session_state['simplified_text']
            if not simplified_source:
                st.warning("Please ensure simplified text is available.")
            else:
                try:
                    st.session_state['mixtral_from_simplified'] = run_single_input(
                        simplified_source, config_llm.DEFAULT_SYSTEM_PROMPT)
                except Exception as e:
                    st.error(f"An error occurred: {e}")
        mixtral_simplified_result = st.session_state['mixtral_from_simplified']
        if mixtral_simplified_result:
            st.write("Generated Press Text from Simplified Text:")
            st.write(mixtral_simplified_result)

    elif option == 'GPT-3.5 Turbo':
        st.header("GPT-3.5")

        # The key typed here is only used once it has been saved to session
        # state with the button below.
        api_key_input = st.text_input("Enter your OpenAI API key to continue", type="password")
        if st.button('Save API Key'):
            st.session_state['openai_api_key'] = api_key_input
            st.success("API Key saved temporarily for this session.")
        st.write("- - -")

        def get_openai_completion(api_key, prompt_message):
            """Return the gpt-3.5-turbo reply to *prompt_message*.

            A client is created with *api_key*; the request sends
            config_llm.DEFAULT_SYSTEM_PROMPT as the system message and
            *prompt_message* as the user message, using the token limit,
            temperature and top_p values from config_llm.  The content of the
            first returned choice is the result.
            """
            openai_client = OpenAI(api_key=api_key)
            chat_messages = [
                {"role": "system", "content": config_llm.DEFAULT_SYSTEM_PROMPT},
                {"role": "user", "content": prompt_message},
            ]
            response = openai_client.chat.completions.create(
                model="gpt-3.5-turbo",
                max_tokens=config_llm.MAX_NEW_TOKENS,
                temperature=config_llm.TEMPERATURE,
                top_p=config_llm.TOP_P,
                messages=chat_messages,
            )
            return response.choices[0].message.content

        # Press text from the image caption
        if st.button("Generate Press Text from New Image Caption with GPT"):
            if not st.session_state['new_caption'] or not st.session_state['openai_api_key']:
                st.warning("Please ensure a caption is generated and an API key is entered.")
            else:
                try:
                    st.session_state['message_content_from_caption'] = get_openai_completion(
                        st.session_state['openai_api_key'], st.session_state['new_caption'])
                except Exception as e:
                    st.error(f"An error occurred: {e}")
        gpt_caption_result = st.session_state['message_content_from_caption']
        if gpt_caption_result:
            st.write("Generated Press Text from New Caption of Artwork:")
            st.write(gpt_caption_result)

        # Press text from the simplified text
        if st.button("Generate Press Text from Simplified Text with GPT"):
            if not st.session_state['simplified_text'] or not st.session_state['openai_api_key']:
                st.warning("Please ensure simplified text is available and an API key is entered.")
            else:
                try:
                    st.session_state['message_content_from_simplified_text'] = get_openai_completion(
                        st.session_state['openai_api_key'], st.session_state['simplified_text'])
                except Exception as e:
                    st.error(f"An error occurred: {e}")
        gpt_simplified_result = st.session_state['message_content_from_simplified_text']
        if gpt_simplified_result:
            st.write("Generated Press Text from Simplified Text:")
            st.write(gpt_simplified_result)
st.markdown("---")
########################################################################
## Image Generation Interface
# Model used for every text-to-image request in this section.
IMAGE_MODEL = "prompthero/openjourney-v4"
# Longest source text (in characters) placed into an image prompt; the value
# is the cut-off the press-text prompts already used (512 - 3).
MAX_IMAGE_TEXT_CHARS = 509


def _generate_image(text, state_key):
    """Create an image for *text* and store it in session state under *state_key*.

    The text is cut to MAX_IMAGE_TEXT_CHARS characters and prefixed with
    "contemporary art of ".  A failing request is reported with ``st.error``
    rather than aborting the script run; in that case whatever was stored
    under *state_key* before is left untouched.
    """
    prompt = f"contemporary art of {str(text)[:MAX_IMAGE_TEXT_CHARS]}"
    try:
        st.session_state[state_key] = client.text_to_image(prompt, model=IMAGE_MODEL)
    except Exception as e:
        st.error("Failed to generate image: " + str(e))


def _show_image(state_key, caption):
    """Display the image stored under *state_key*, if there is one."""
    image = st.session_state.get(state_key)
    if image is not None:
        st.image(image, caption=caption, use_column_width=True)


with st.expander("Image Generation"):
    st.markdown("## Generate new Images from Texts")

    # Image from the new caption
    if st.button("Generate Image from New Caption of Artwork"):
        if st.session_state['new_caption']:
            _generate_image(st.session_state['new_caption'], 'image_from_caption')
        else:
            # Previously a click without a caption did nothing at all.
            st.warning("Please ensure a caption is generated.")
    _show_image('image_from_caption', "Image from New Caption")

    # Image from the simplified text
    if st.button("Generate Image from Simplified Text"):
        if st.session_state['simplified_text']:
            _generate_image(st.session_state['simplified_text'], 'image_from_simplified_text')
        else:
            st.warning("Please ensure simplified text is available.")
    _show_image('image_from_simplified_text', "Image from Simplified Text")

    # Image from the press text that was generated from the simplified text
    if st.button("Generate Image from new Press Text from Simplified Text"):
        # Prefer the Mixtral press text; fall back to the GPT one.
        text_to_use_simp = (st.session_state.get('mixtral_from_simplified')
                            or st.session_state.get('message_content_from_simplified_text'))
        if text_to_use_simp:
            _generate_image(text_to_use_simp, 'image_from_press_text')
        else:
            st.error("First generate a press text from summary.")
    _show_image('image_from_press_text', "Image from Press Text from simplified Text")

    # Image from the press text that was generated from the caption
    if st.button("Generate Image from new Press Text from new Caption"):
        # Prefer the Mixtral press text; fall back to the GPT one.
        text_to_use_cap = (st.session_state.get('mixtral_from_caption')
                           or st.session_state.get('message_content_from_caption'))
        if text_to_use_cap:
            _generate_image(text_to_use_cap, 'image_from_press_text_from_caption')
        else:
            # This branch needs the caption-based press text, not the summary one.
            st.error("First generate a press text from caption.")
    _show_image('image_from_press_text_from_caption', "Image from Press Text from new Caption")
st.markdown("---")