coztomate commited on
Commit
1302286
·
1 Parent(s): 68232c6

20231203_KD: first app

Browse files
Files changed (2) hide show
  1. artspeak_app_smaller.py +116 -0
  2. requirements.txt +9 -0
artspeak_app_smaller.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #import libraries
2
+ import streamlit as st
3
+ from PIL import Image
4
+ import io
5
+ from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
6
+ import torch
7
+ from torchvision import transforms
8
+ import open_clip
9
+
10
+ # Initialize session state variables
11
+ if 'simplified_text' not in st.session_state:
12
+ st.session_state['simplified_text'] = ''
13
+ if 'new_caption' not in st.session_state:
14
+ st.session_state['new_caption'] = ''
15
+ if 'model_clip' not in st.session_state:
16
+ st.session_state['model_clip'] = None
17
+ if 'transform_clip' not in st.session_state:
18
+ st.session_state['transform_clip'] = None
19
+
20
+ # Define model and tokenizer names for the text simplification model
21
+ model_name = "mrm8488/t5-small-finetuned-text-simplification"
22
+ tokenizer_name = "mrm8488/t5-small-finetuned-text-simplification"
23
+
24
+ # Load models only once in session state
25
+ if 'model' not in st.session_state or 'tokenizer' not in st.session_state:
26
+ st.session_state['model'] = AutoModelForSeq2SeqLM.from_pretrained(model_name)
27
+ st.session_state['tokenizer'] = AutoTokenizer.from_pretrained(tokenizer_name)
28
+ st.session_state['simplifier'] = pipeline("text2text-generation", model=st.session_state['model'], tokenizer=st.session_state['tokenizer'])
29
+
30
+ # Use the model from session state
31
+ simplifier = st.session_state['simplifier']
32
+
33
+ # Function to load CLIP model
34
+ def load_clip_model():
35
+ model_clip, _, transform_clip = open_clip.create_model_and_transforms(
36
+ model_name="coca_ViT-L-14",
37
+ pretrained="mscoco_finetuned_laion2B-s13B-b90k"
38
+ )
39
+ return model_clip, transform_clip
40
+
41
+ # Function to generate a caption for the uploaded image
42
+ def generate_caption(image_path):
43
+ # Load the CLIP model if it hasn't been loaded yet
44
+ if st.session_state['model_clip'] is None or st.session_state['transform_clip'] is None:
45
+ st.session_state['model_clip'], st.session_state['transform_clip'] = load_clip_model()
46
+
47
+ # Load and preprocess the uploaded image
48
+ im = Image.open(image_path).convert("RGB")
49
+ im = st.session_state['transform_clip'](im).unsqueeze(0)
50
+
51
+ # Generate a caption for the image
52
+ with torch.no_grad(), torch.cuda.amp.autocast():
53
+ generated = st.session_state['model_clip'].generate(im)
54
+
55
+ new_caption = open_clip.decode(generated[0]).split("<end_of_text>")[0].replace("<start_of_text>", "")[:-2]
56
+ return new_caption
57
+
58
+
59
+
60
+ # Create a Streamlit app
61
+ st.title("ARTSPEAK")
62
+
63
+ ##### Upload of files
64
+ # Add a text input field for user input
65
+ user_input = st.text_area("Enter text here")
66
+
67
+ # Add an upload field to the app for image files (jpg or png)
68
+ uploaded_image = st.file_uploader("Upload an image (jpg or png)", type=["jpg", "png"])
69
+
70
+ #### Display of files
71
+ # Create a sub-section
72
+ with st.expander("Display of Uploaded Files"):
73
+ st.write("These are you uploaded files:")
74
+ # Check if a file was uploaded
75
+ if user_input is not None:
76
+ # Display file information
77
+ st.write("Original Text:")
78
+ st.write(user_input)
79
+
80
+ # Check if an image was uploaded
81
+ if uploaded_image is not None:
82
+ # Display the uploaded image
83
+ st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
84
+
85
+
86
+ ####Get summary
87
+ if st.button("Simplify"):
88
+ if user_input:
89
+ simplified_text = simplifier(user_input, min_length=20, max_length=50, do_sample=True)
90
+ # Update the session state
91
+ st.session_state['simplified_text'] = simplified_text[0]['generated_text']
92
+ else:
93
+ st.warning("Please enter text in the input field before clicking 'Save'")
94
+
95
+ # Display the simplified text from session state
96
+ if st.session_state['simplified_text']:
97
+ st.write("Simplified Text:")
98
+ st.write(st.session_state['simplified_text'])
99
+
100
+ ####Get new caption
101
+
102
+
103
+ # Modify the 'Get Caption' button section
104
+ if st.button("Get Caption"):
105
+ if uploaded_image is not None:
106
+ # Generate the caption
107
+ caption = generate_caption(uploaded_image)
108
+ # Update the session state
109
+ st.session_state['new_caption'] = caption
110
+ else:
111
+ st.warning("Please upload an image before clicking 'Get Caption'")
112
+
113
+ # Display the new caption from session state
114
+ if st.session_state['new_caption']:
115
+ st.write("New Caption for Artwork:")
116
+ st.write(st.session_state['new_caption'])
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ streamlit
2
+ numpy
3
+ pandas
4
+ openai
5
+ open_clip_torch
6
+ transformers
7
+ accelerate
8
+ openai
9
+ diffusers