# donut-labelling / app.py
# Last commit 6e2743a by jonathanjordan21: "Update app.py"
import json
import os
import zipfile
from io import BytesIO, StringIO

import streamlit as st
# Seed per-session state on the first run of a session. Streamlit re-executes
# the whole script on every interaction, so a key is only assigned while it is
# still missing; values set on earlier runs survive later reruns.
for _key, _make_default in (
    ("data", int),       # number of image/label entries shown in the form -> 0
    ("download", bool),  # whether the download buttons are shown -> False
    ("labels", list),    # not read elsewhere in this file -> fresh []
    ("images", list),    # not read elsewhere in this file -> fresh []
):
    if _key not in st.session_state:
        st.session_state[_key] = _make_default()
# Side-by-side controls that grow or shrink the number of entries in the form.
add_column, remove_column = st.columns(2)

with add_column:
    if st.button("Add new data"):
        st.session_state.data += 1

with remove_column:
    # The button is evaluated first; the count is only decremented while it
    # is positive, so it never goes below zero.
    if st.button('Remove last data') and st.session_state.data > 0:
        st.session_state.data -= 1
def json_to_tokens(data):
    """Serialise a parsed JSON value into a flat tag-token string.

    * ``dict``  -> ``<key>value</key>`` for every item, in insertion order
    * ``list``  -> each item serialised recursively, joined by ``<sep/>``
    * anything else -> ``str(value)`` (so ``None`` -> ``"None"``,
      ``True`` -> ``"True"``, ``7`` -> ``"7"``)

    Args:
        data: A value as produced by ``json.loads``.

    Returns:
        The serialised token string (``""`` for an empty dict or list).

    >>> json_to_tokens({"menu": [{"nm": "tea"}, {"nm": "cake"}], "total": 7})
    '<menu><nm>tea</nm><sep/><nm>cake</nm></menu><total>7</total>'
    """
    if isinstance(data, dict):
        return "".join(
            f"<{key}>{json_to_tokens(value)}</{key}>" for key, value in data.items()
        )
    if isinstance(data, list):
        # join() places the separator only *between* items, so no trailing
        # "<sep/>" has to be sliced off afterwards.
        return "<sep/>".join(json_to_tokens(item) for item in data)
    return str(data)
# Input form: one row per entry, with the image uploader on the left and its
# JSON ground truth on the right. The values are stored in session state under
# the widget keys ``image_{x}`` / ``label_{x}`` and read back on submit.
with st.form("my_form"):
    for x in range(st.session_state.data):
        # Columns are created per row so each image stays aligned with its
        # label. Shared columns would stack all uploaders and all text areas
        # independently, and their differing heights make the rows drift apart.
        col1, col2 = st.columns(2)
        with col1:
            st.file_uploader("Upload Image", type=['png', 'jpg', 'jpeg'], key=f'image_{x}')
        with col2:
            st.text_area("Json Text", key=f'label_{x}')
    submitted = st.form_submit_button("Submit")
def _parse_entries(count):
    """Read the first ``count`` form entries back from session state.

    Returns:
        ``(labels, images, errors)``, where ``labels`` holds the parsed JSON
        values and ``images`` the uploaded files of the entries. ``errors``
        holds one human-readable message per problem: invalid or empty JSON,
        or a missing image. The caller should only use ``labels``/``images``
        when ``errors`` is empty.
    """
    labels, images, errors = [], [], []
    for x in range(count):
        number = x + 1  # 1-based numbering for messages shown to the user
        image = st.session_state.get(f'image_{x}')
        try:
            labels.append(json.loads(st.session_state.get(f'label_{x}') or ""))
        except json.JSONDecodeError as err:
            errors.append(f"Entry {number}: invalid JSON text ({err}).")
        if image is None:
            errors.append(f"Entry {number}: no image uploaded.")
        else:
            images.append(image)
    return labels, images, errors


def _build_zip(members):
    """Return the bytes of an in-memory zip archive of ``(name, data)`` pairs."""
    buffer = BytesIO()
    with zipfile.ZipFile(buffer, "w") as archive:
        for name, payload in members:
            archive.writestr(name, payload)
    return buffer.getvalue()


if submitted:
    labels, images, errors = _parse_entries(st.session_state.data)
    if errors:
        # Report every problem at once instead of crashing on the first one.
        for message in errors:
            st.error(message)
    elif not labels:
        st.warning("Add at least one entry before submitting.")
    else:
        # Wrap each serialised label in the <s_start> ... </s_start> tags.
        token_strings = [
            f"<s_start>{json_to_tokens(label)}</s_start>" for label in labels
        ]
        # The archives are kept per session in memory instead of fixed paths on
        # disk, so concurrent users cannot overwrite each other's output.
        # getvalue() returns the whole upload regardless of the stream
        # position, unlike read(), which yields b"" once the file was consumed.
        # NOTE(review): every image is named ``.jpg`` even when a PNG was
        # uploaded (unchanged from before). Confirm the downstream consumer
        # relies on this name.
        st.session_state.images_zip = _build_zip(
            (f"{i}_image.jpg", image.getvalue()) for i, image in enumerate(images)
        )
        st.session_state.labels_zip = _build_zip(
            (f"{i}_label.txt", tokens) for i, tokens in enumerate(token_strings)
        )
        st.session_state.download = True

if st.session_state.download and "images_zip" in st.session_state:
    c_1, c_2 = st.columns(2)
    with c_1:
        st.download_button(
            label="Download Images",
            data=st.session_state.images_zip,
            file_name="images.zip",
            mime="application/zip",
        )
    with c_2:
        st.download_button(
            label="Download Labels",
            data=st.session_state.labels_zip,
            file_name="labels.zip",
            mime="application/zip",
        )