Spaces:
Runtime error
Runtime error
| import io | |
| from PIL import Image | |
| from torchvision.transforms import CenterCrop, ConvertImageDtype, Normalize, Resize, ToTensor, Compose | |
| from torchvision.transforms.functional import InterpolationMode | |
| import torch | |
| import numpy as np | |
| from transformers import MarianTokenizer | |
| from flax_clip_vision_marian.modeling_clip_vision_marian import FlaxCLIPVisionMarianForConditionalGeneration | |
| import logging | |
| import streamlit as st | |
| from mtranslate import translate | |
class CaptionGenerator:
    """Generate Indonesian image captions with a CLIP-Vision + Marian model.

    The constructor only records model identifiers; call :meth:`load` once
    before :meth:`process_image`, :meth:`generate_step` or :meth:`get_caption`.
    """

    def __init__(self):
        self.tokenizer = None
        # Loaded model; populated by `load`. `clip_marian_model` is kept as an
        # alias of the same object so existing readers of that name still work.
        self.model = None
        self.clip_marian_model = None
        self.marian_model_name = 'Helsinki-NLP/opus-mt-en-id'
        self.clip_marian_model_name = 'flax-community/Image-captioning-Indonesia'
        self.config = None
        self.image_size = None
        self.custom_transforms = None

    def load(self):
        """Load tokenizer and model via `from_pretrained` and build the image transforms."""
        logging.info("Loading tokenizer...")
        self.tokenizer = MarianTokenizer.from_pretrained(self.marian_model_name)
        logging.info("Tokenizer loaded.")
        logging.info("Loading model...")
        self.model = FlaxCLIPVisionMarianForConditionalGeneration.from_pretrained(
            self.clip_marian_model_name
        )
        self.clip_marian_model = self.model
        logging.info("Model loaded.")
        self.config = self.model.config
        self.image_size = self.config.clip_vision_config.image_size
        # Resize the short side, center-crop to a square of the model's input
        # size, then normalize with the mean/std used by CLIP preprocessing.
        self.custom_transforms = torch.nn.Sequential(
            Resize([self.image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(self.image_size),
            ConvertImageDtype(torch.float),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        )

    def process_image(self, file):
        """Decode an image file-like object into model-ready pixel values.

        Args:
            file: readable binary file-like object holding an encoded image.

        Returns:
            A numpy float32 array of shape (1, image_size, image_size, 3)
            (batch, height, width, channels) — channels-last for the Flax model.
        """
        logging.info("Loading image...")
        image_data = file.read()
        input_image = Image.open(io.BytesIO(image_data)).convert("RGB")
        image = ToTensor()(input_image)  # (C, H, W) float tensor
        image = self.custom_transforms(image)
        # Add a batch axis and move channels last: (1, C, H, W) -> (1, H, W, C).
        pixel_values = image.unsqueeze(0).permute(0, 2, 3, 1).numpy()
        logging.info("Image loaded.")
        return pixel_values

    def generate_step(self, pixel_values, max_len, num_beams):
        """Run beam-search generation and return the decoded caption text.

        Special tokens (padding / end-of-sequence) are stripped from the output.
        """
        gen_kwargs = {"max_length": max_len, "num_beams": num_beams}
        logging.info("Generating caption...")
        output_ids = self.model.generate(pixel_values, **gen_kwargs)
        token_ids = np.array(output_ids.sequences)[0]
        # Without skip_special_tokens the text contains e.g. "<pad>" and "</s>".
        caption = self.tokenizer.decode(token_ids, skip_special_tokens=True)
        logging.info("Caption generated.")
        return caption

    def get_caption(self, file, max_len, num_beams):
        """Caption the image in `file`; returns the caption string."""
        pixel_values = self.process_image(file)
        caption = self.generate_step(pixel_values, max_len, num_beams)
        return caption
def _cache_model(func):
    """Cache `func`'s result across Streamlit reruns.

    Streamlit re-executes the script on every interaction, so an uncached
    loader would reload the tokenizer and model each time. Prefers
    `st.cache_resource`; falls back to the legacy
    `st.cache(allow_output_mutation=True)` on Streamlit versions without it.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    return st.cache(allow_output_mutation=True)(func)


@_cache_model
def load_caption_generator():
    """Create a `CaptionGenerator`, load its tokenizer/model, and return it."""
    generator = CaptionGenerator()
    generator.load()
    return generator
def main():
    """Render the Streamlit captioning UI and caption the uploaded image on demand."""
    st.set_page_config(page_title="Indonesian Image Captioning Demo", page_icon="🖼️")
    generator = load_caption_generator()
    st.title("Indonesian Image Captioning Demo")
    st.markdown(
        """Indonesian image captioning demo, trained on [CLIP](https://huggingface.co/transformers/model_doc/clip.html) and [Marian](https://huggingface.co/transformers/model_doc/marian.html). Part of the [Huggingface JAX/Flax event](https://discuss.huggingface.co/t/open-to-the-community-community-week-using-jax-flax-for-nlp-cv/).
"""
    )
    st.sidebar.subheader("Configurable parameters")
    # min_value=1: zero or negative lengths/beam counts make generation fail.
    max_len = st.sidebar.number_input(
        "Maximum length",
        min_value=1,
        value=8,
        help="The maximum length of the sequence (caption) to be generated."
    )
    num_beams = st.sidebar.number_input(
        "Number of beams",
        min_value=1,
        value=4,
        help="Number of beams for beam search. 1 means no beam search."
    )
    input_image = st.file_uploader("Insert image")
    if not st.button("Run"):
        return
    if input_image is None:
        st.write("Please upload an image.")
        return
    with st.spinner(text="Getting results..."):
        caption = generator.get_caption(file=input_image, max_len=max_len, num_beams=num_beams)
        # Clean the caption once, so both the displayed text and the text sent
        # for translation are free of padding tokens.
        caption = caption.replace("<pad>", "").strip()
        try:
            # mtranslate: translate(text, to_language, from_language) -> id to en.
            translation = translate(caption, "en", "id")
        except Exception:
            # The translation service is remote; a failure there should not
            # hide the caption that was already generated.
            logging.exception("Translation failed")
            translation = None
    st.subheader("Result")
    st.write(caption)
    st.text("English translation")
    if translation is None:
        st.warning("Translation is currently unavailable.")
    else:
        st.write(translation)
# Entry point when this file is executed as a script (e.g. via `streamlit run`).
if __name__ == "__main__":
    main()