Spaces:
Runtime error
Runtime error
updates:
Browse files- Rename redCaps
- naming fix
- allow png
- app.py +25 -11
- model.py +3 -3
- requirements.txt +1 -0
app.py
CHANGED
|
@@ -22,10 +22,16 @@ def gen_show_caption(sub_prompt=None, cap_prompt = ""):
|
|
| 22 |
)
|
| 23 |
|
| 24 |
|
| 25 |
-
st.title("Image Captioning Demo from
|
| 26 |
st.sidebar.markdown(
|
| 27 |
"""
|
| 28 |
-
Image Captioning Model from VirTex trained on
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
"""
|
| 30 |
)
|
| 31 |
|
|
@@ -48,6 +54,15 @@ else:
|
|
| 48 |
|
| 49 |
sample_image = sample_images[0 if select_idx is None else select_idx]
|
| 50 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
# class OnChange():
|
| 52 |
# def __init__(self, idx):
|
| 53 |
# self.idx = idx
|
|
@@ -75,16 +90,12 @@ else:
|
|
| 75 |
value=""
|
| 76 |
)
|
| 77 |
|
| 78 |
-
|
| 79 |
-
uploaded_image = None
|
| 80 |
-
with st.sidebar.form("file-uploader-form", clear_on_submit=True):
|
| 81 |
-
uploaded_file = st.file_uploader("Choose a file")
|
| 82 |
-
submitted = st.form_submit_button("Submit")
|
| 83 |
-
if uploaded_file is not None and submitted:
|
| 84 |
-
uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
|
| 85 |
-
select_idx = None # set this to help rewrite the cache
|
| 86 |
-
|
| 87 |
_ = st.sidebar.button("Regenerate Caption")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
if uploaded_image is None and submitted:
|
| 90 |
st.write("Please select a file to upload")
|
|
@@ -100,8 +111,11 @@ else:
|
|
| 100 |
else:
|
| 101 |
image = Image.open(image_file)
|
| 102 |
|
|
|
|
|
|
|
| 103 |
st.session_state['image'] = image
|
| 104 |
|
|
|
|
| 105 |
image_dict = imageLoader.transform(image)
|
| 106 |
|
| 107 |
show_image = imageLoader.show_resize(image)
|
|
|
|
| 22 |
)
|
| 23 |
|
| 24 |
|
| 25 |
+
st.title("Image Captioning Demo from RedCaps")
|
| 26 |
st.sidebar.markdown(
|
| 27 |
"""
|
| 28 |
+
### Image Captioning Model from VirTex trained on RedCaps
|
| 29 |
+
|
| 30 |
+
Use this page to caption your own images or try out some of our samples.
|
| 31 |
+
You can also generate captions as if they are from specific subreddits,
|
| 32 |
+
as if they start with a particular prompt, or even both.
|
| 33 |
+
|
| 34 |
+
Feel free to share your results on twitter with #redcaps or with a friend.
|
| 35 |
"""
|
| 36 |
)
|
| 37 |
|
|
|
|
| 54 |
|
| 55 |
sample_image = sample_images[0 if select_idx is None else select_idx]
|
| 56 |
|
| 57 |
+
|
| 58 |
+
uploaded_image = None
|
| 59 |
+
with st.sidebar.form("file-uploader-form", clear_on_submit=True):
|
| 60 |
+
uploaded_file = st.file_uploader("Choose a file")
|
| 61 |
+
submitted = st.form_submit_button("Submit")
|
| 62 |
+
if uploaded_file is not None and submitted:
|
| 63 |
+
uploaded_image = Image.open(io.BytesIO(uploaded_file.getvalue()))
|
| 64 |
+
select_idx = None # set this to help rewrite the cache
|
| 65 |
+
|
| 66 |
# class OnChange():
|
| 67 |
# def __init__(self, idx):
|
| 68 |
# self.idx = idx
|
|
|
|
| 90 |
value=""
|
| 91 |
)
|
| 92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 93 |
_ = st.sidebar.button("Regenerate Caption")
|
| 94 |
+
|
| 95 |
+
# advanced = st.sidebar.checkbox("Advanced Options")
|
| 96 |
+
|
| 97 |
+
# if advanced:
|
| 98 |
+
# nuc_size = st.sidebar.slider("")
|
| 99 |
|
| 100 |
if uploaded_image is None and submitted:
|
| 101 |
st.write("Please select a file to upload")
|
|
|
|
| 111 |
else:
|
| 112 |
image = Image.open(image_file)
|
| 113 |
|
| 114 |
+
image = image.convert("RGB")
|
| 115 |
+
|
| 116 |
st.session_state['image'] = image
|
| 117 |
|
| 118 |
+
|
| 119 |
image_dict = imageLoader.transform(image)
|
| 120 |
|
| 121 |
show_image = imageLoader.show_resize(image)
|
model.py
CHANGED
|
@@ -22,7 +22,7 @@ SAMPLES_PATH = "./samples/*.jpg"
|
|
| 22 |
|
| 23 |
class ImageLoader():
|
| 24 |
def __init__(self):
|
| 25 |
-
self.
|
| 26 |
torchvision.transforms.ToTensor(),
|
| 27 |
torchvision.transforms.Resize(256),
|
| 28 |
torchvision.transforms.CenterCrop(224),
|
|
@@ -30,7 +30,7 @@ class ImageLoader():
|
|
| 30 |
self.show_size=500
|
| 31 |
|
| 32 |
def load(self, im_path):
|
| 33 |
-
im = torch.FloatTensor(self.
|
| 34 |
return {"image": im}
|
| 35 |
|
| 36 |
def raw_load(self, im_path):
|
|
@@ -38,7 +38,7 @@ class ImageLoader():
|
|
| 38 |
return {"image": im}
|
| 39 |
|
| 40 |
def transform(self, image):
|
| 41 |
-
im = torch.FloatTensor(self.
|
| 42 |
return {"image": im}
|
| 43 |
|
| 44 |
def text_transform(self, text):
|
|
|
|
| 22 |
|
| 23 |
class ImageLoader():
|
| 24 |
def __init__(self):
|
| 25 |
+
self.image_transform = torchvision.transforms.Compose([
|
| 26 |
torchvision.transforms.ToTensor(),
|
| 27 |
torchvision.transforms.Resize(256),
|
| 28 |
torchvision.transforms.CenterCrop(224),
|
|
|
|
| 30 |
self.show_size=500
|
| 31 |
|
| 32 |
def load(self, im_path):
|
| 33 |
+
im = torch.FloatTensor(self.image_transform(Image.open(im_path))).unsqueeze(0)
|
| 34 |
return {"image": im}
|
| 35 |
|
| 36 |
def raw_load(self, im_path):
|
|
|
|
| 38 |
return {"image": im}
|
| 39 |
|
| 40 |
def transform(self, image):
|
| 41 |
+
im = torch.FloatTensor(self.image_transform(image)).unsqueeze(0)
|
| 42 |
return {"image": im}
|
| 43 |
|
| 44 |
def text_transform(self, text):
|
requirements.txt
CHANGED
|
@@ -14,4 +14,5 @@ torch==1.7.0
|
|
| 14 |
torchvision==0.8
|
| 15 |
tqdm>=4.50.0
|
| 16 |
wordsegment==1.3.1
|
|
|
|
| 17 |
git+git://github.com/facebookresearch/fvcore.git#egg=fvcore
|
|
|
|
| 14 |
torchvision==0.8
|
| 15 |
tqdm>=4.50.0
|
| 16 |
wordsegment==1.3.1
|
| 17 |
+
whatimage==0.0.3
|
| 18 |
git+git://github.com/facebookresearch/fvcore.git#egg=fvcore
|