Spaces:
Runtime error
Runtime error
ydshieh
commited on
Commit
·
d28411b
1
Parent(s):
3568832
use real predict method
Browse files
app.py
CHANGED
|
@@ -12,7 +12,7 @@ st.write('\n')
|
|
| 12 |
#show = st.image(image, use_column_width=True)
|
| 13 |
#show.image(image, 'Preloaded Image', use_column_width=True)
|
| 14 |
|
| 15 |
-
with st.spinner('Loading ViT-GPT2 model ...'):
|
| 16 |
|
| 17 |
from model import *
|
| 18 |
st.sidebar.write(f'Vit-GPT2 model loaded :)')
|
|
@@ -29,14 +29,14 @@ sample_path = os.path.join(sample_dir, sample_name)
|
|
| 29 |
|
| 30 |
image = Image.open(sample_path)
|
| 31 |
show = st.image(image, use_column_width=True)
|
| 32 |
-
show.image(image, '
|
| 33 |
|
| 34 |
# For newline
|
| 35 |
st.sidebar.write('\n')
|
| 36 |
|
| 37 |
with st.spinner('Generating image caption ...'):
|
| 38 |
|
| 39 |
-
caption =
|
| 40 |
image.close()
|
| 41 |
st.success(f'caption: {caption}')
|
| 42 |
|
|
|
|
| 12 |
#show = st.image(image, use_column_width=True)
|
| 13 |
#show.image(image, 'Preloaded Image', use_column_width=True)
|
| 14 |
|
| 15 |
+
with st.spinner('Loading and compiling ViT-GPT2 model ...'):
|
| 16 |
|
| 17 |
from model import *
|
| 18 |
st.sidebar.write(f'Vit-GPT2 model loaded :)')
|
|
|
|
| 29 |
|
| 30 |
image = Image.open(sample_path)
|
| 31 |
show = st.image(image, use_column_width=True)
|
| 32 |
+
show.image(image, '\nSelected Image', use_column_width=True)
|
| 33 |
|
| 34 |
# For newline
|
| 35 |
st.sidebar.write('\n')
|
| 36 |
|
| 37 |
with st.spinner('Generating image caption ...'):
|
| 38 |
|
| 39 |
+
caption = predict(image)
|
| 40 |
image.close()
|
| 41 |
st.success(f'caption: {caption}')
|
| 42 |
|
model.py
CHANGED
|
@@ -52,6 +52,7 @@ def predict(image):
|
|
| 52 |
|
| 53 |
token_ids = np.array(generation.sequences)[0]
|
| 54 |
caption = tokenizer.decode(token_ids)
|
|
|
|
| 55 |
|
| 56 |
return caption
|
| 57 |
|
|
|
|
| 52 |
|
| 53 |
token_ids = np.array(generation.sequences)[0]
|
| 54 |
caption = tokenizer.decode(token_ids)
|
| 55 |
+
caption = caption.replace('<s>', '').replace('</s>', '').replace('<pad>', '')
|
| 56 |
|
| 57 |
return caption
|
| 58 |
|