# Spaces: Runtime error
| import torch | |
| import re | |
| import gradio as gr | |
| from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel, ViTImageProcessor | |
# Load the trained captioning model.
# Tokenizer, image processor and model are all loaded from the same
# checkpoint repository, so they are guaranteed to match each other.
model_path = "Thibalte/captionning_project"

# Tokenizer: used by predict() to decode generated token ids back to text.
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Image processor: converts a PIL image into the `pixel_values` tensor the
# model's encoder consumes (see predict()).
feature_extractor = ViTImageProcessor.from_pretrained(model_path)

# Encoder-decoder model whose `generate` produces caption token ids.
model = VisionEncoderDecoderModel.from_pretrained(model_path)
def predict(image, max_length=24, num_beams=4):
    """Generate a caption for a single image.

    Args:
        image: PIL image from the ``gr.Image(type="pil")`` input, or ``None``
            when the user submits without selecting an image.
        max_length: Maximum length of the generated sequence, forwarded to
            ``model.generate``.
        num_beams: Beam-search width, forwarded to ``model.generate``.

    Returns:
        The decoded caption as a plain string, or ``""`` if no image was given.
    """
    if image is None:
        # Gradio passes None for an empty input; calling .convert on it
        # would raise AttributeError.
        return ""
    # Force RGB so RGBA / grayscale uploads are processed uniformly.
    image = image.convert("RGB")
    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    # Honour the caller's generation settings instead of hard-coded values.
    sequences = model.generate(
        pixel_values,
        num_beams=num_beams,
        max_length=max_length,
    )
    captions = tokenizer.batch_decode(sequences, skip_special_tokens=True)
    # One image in -> one caption out. Return the string itself so the single
    # Textbox output shows the caption rather than the repr of a list.
    return captions[0]
# Gradio Interface: one image (uploaded file or webcam snapshot) goes in,
# one caption Textbox comes out; `predict` does the work.
gradio_app = gr.Interface(
    title="Image Captioning with our model",
    fn=predict,
    inputs=gr.Image(
        type="pil",
        sources=["upload", "webcam"],
        label="Select image for captioning",
    ),
    outputs=[gr.Textbox(label="Image Caption")],
    # example1.jpg ... example6.jpg; presumably shipped next to this script
    # in the Space repository — NOTE(review): confirm the files exist.
    examples=[f"example{n}.jpg" for n in range(1, 7)],
)

if __name__ == "__main__":
    # Start the web server only when executed as a script.
    gradio_app.launch()