| import gradio as gr |
| from clip_gpt2 import CLIPGPT2, CLIPGPT2Config, CLIPGPT2Processor |
| import os |
| import torch |
|
|
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'


# Build the model from its config alone (both *_from_pretrained flags are
# False) and then fill in the weights from the local checkpoint file.
config = CLIPGPT2Config(image_from_pretrained=False, text_from_pretrained=False)
model = CLIPGPT2(config)
# map_location only decides where the checkpoint tensors are deserialized;
# load_state_dict copies them into the model's existing parameters, so the
# model itself is not moved to `device` here.
model.load_state_dict(torch.load("pytorch_model.bin", map_location=device))
# A freshly constructed nn.Module starts in training mode. Switch to eval mode
# so dropout and other train-only behaviour is disabled while serving requests.
model.eval()
processor = CLIPGPT2Processor(config)


# Heading passed to the Gradio interface.
title = "Generate Image Captions With CLIP And GPT2"
|
|
|
|
def generate_image_captions(image, text):
    """Generate text for an image and/or a text prompt.

    Args:
        image: Value delivered by the ``gr.Image`` input; forwarded to the
            processor as ``images``.
        text: Prompt from the textbox; forwarded to the processor as
            ``texts``.

    Returns:
        The first generated sequence, decoded to a string with special
        tokens skipped.
    """
    inputs = processor(images=image, texts=text, return_tensors="pt")

    # Pure inference: disable autograd so no computation graph is recorded
    # while generating (saves memory; the produced tokens are unaffected).
    with torch.no_grad():
        prediction = model.generate(
            pixel_values=inputs.get("pixel_values"),
            input_ids=inputs.get("input_ids"),
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=50,
        )

    # These settings are applied after the processor call and generation, so
    # they cannot affect how the current request was tokenized; they persist
    # on the shared processor for later calls.
    # NOTE(review): confirm whether they were meant to run before
    # `processor(...)` instead.
    processor.tokenizer.padding_side = 'left'
    processor.tokenizer.pad_token_id = processor.tokenizer.eos_token_id

    return processor.decode(prediction[0], skip_special_tokens=True)
|
|
# Markdown body shown under the title, kept one source line per rendered line.
# The empty first and last entries reproduce the leading and trailing newline
# of the text block.
_DESCRIPTION_LINES = (
    "",
    "### Expand GPT2's language capabilities to vision with CLIP!",
    "### Tips:",
    "- Only English is supported.",
    "- When no image is provided, the model degrades to a vanilla GPT2-Large!",
    "- When no description is provided, the model automatically generates a caption for the provided image.",
    "- Try appending 'Answer:' after your question, the model is more likely to give desired outputs this way.",
    "",
)
description = "\n".join(_DESCRIPTION_LINES)

# Footer text for the interface, linking to the paper behind the demo.
article = "This demo is originated from this paper: [original paper](https://arxiv.org/abs/2209.15162)"
# (image file, prompt) pairs offered as clickable examples in the UI.
_example_specs = [
    ('two_bear.png', ""),
    ('three_women.png', "What is the woman in the middle's dress's color? Answer:"),
    ('cat_with_food.png', "Describe the picture:"),
    ('dog_with_frisbee.png', "What is the color of the frisbee in the photo? Answer:"),
    ('stop_sign.png', "What does the sign in the picture say? Answer:"),
]

# Example images are resolved against the current working directory.
_examples_dir = os.getcwd()
_example_rows = [
    [os.path.join(_examples_dir, file_name), prompt]
    for file_name, prompt in _example_specs
]

# Two inputs (image + free-text prompt) mapped to a single text output.
_image_input = gr.Image()
_prompt_input = gr.Textbox(placeholder="A picture of", lines=3)

demo = gr.Interface(
    fn=generate_image_captions,
    inputs=[_image_input, _prompt_input],
    outputs="text",
    examples=_example_rows,
    title=title,
    description=description,
    article=article,
    cache_examples=False,
)


# Start the Gradio server as soon as the script is executed.
demo.launch()