Spaces:
Build error
Build error
Add app
Browse files- .gitignore +1 -0
- README.md +0 -13
- app.py +107 -0
- fromage/__init__.py +0 -0
- fromage/models.py +658 -0
- fromage/utils.py +250 -0
.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
.DS_Store
|
README.md
DELETED
|
@@ -1,13 +0,0 @@
|
|
| 1 |
-
---
|
| 2 |
-
title: Fromage
|
| 3 |
-
emoji: 📚
|
| 4 |
-
colorFrom: purple
|
| 5 |
-
colorTo: yellow
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version: 3.18.0
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
license: apache-2.0
|
| 11 |
-
---
|
| 12 |
-
|
| 13 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
|
| 6 |
+
from fromage import models
|
| 7 |
+
from fromage import utils
|
| 8 |
+
import gradio as gr
|
| 9 |
+
import huggingface_hub
|
| 10 |
+
import tempfile
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class FromageChatBot:
    """Gradio chat front-end around a FROMAGe model.

    The instance keeps the running text transcript that is fed to the model
    (``chat_history``) and an optional pending image (``input_image``) that is
    consumed by the next text prompt.
    """

    def __init__(self):
        """Download the model files from the HF Hub and load the model."""
        # Function-scope import: this file does not import ``os`` at the top.
        import os

        # hf_hub_download stores files in the Hub cache and returns their local
        # path; it does not place them in the working directory. Load from the
        # directory the files were actually downloaded to.
        # NOTE(review): assumes models.load_fromage takes a single directory
        # that contains all three files -- confirm against its definition.
        ckpt_path = huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='pretrained_ckpt.pth.tar')
        huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='model_args.json')
        huggingface_hub.hf_hub_download(repo_id='jykoh/fromage', filename='cc3m_embeddings.pkl')
        self.model = models.load_fromage(os.path.dirname(ckpt_path))
        self.chat_history = ''
        self.input_image = None

    def reset(self):
        """Clear the transcript and any pending image.

        Returns:
            Two empty lists: the new Gradio state and the new chatbot contents.
        """
        self.chat_history = ""
        self.input_image = None
        return [], []

    def upload_image(self, state, image_input):
        """Remember an uploaded image so the next prompt is conditioned on it.

        Args:
            state: List of (user, bot) message tuples; extended in place.
            image_input: Uploaded file object; its ``name`` is opened as an image.

        Returns:
            The updated state twice (for the Gradio state and the chatbot).
        """
        # NOTE(review): the user-side message is an empty string here; it may
        # have been meant to display the uploaded image -- confirm.
        state += [(f"", ":)")]
        self.input_image = Image.open(image_input.name).resize((224, 224)).convert('RGB')
        return state, state

    def save_image_to_local(self, image: 'Image.Image'):
        """Save ``image`` as a PNG with a random name in the working directory.

        Returns:
            The file name the image was written to.
        """
        # TODO(jykoh): Update so the url path is used, to prevent repeat saving.
        filename = next(tempfile._get_candidate_names()) + '.png'
        image.save(filename)
        return filename

    def generate_for_prompt(self, input_text, state, ret_scale_factor, num_ims, num_words, temp):
        """Run the model on the user's prompt and append the reply to the chat.

        Args:
            input_text: The user's message.
            state: List of (user, bot) message tuples; extended in place.
            ret_scale_factor: Forwarded to the model as ``ret_scale_factor``.
            num_ims: Maximum number of images to return (cast to int).
            num_words: Maximum number of words to generate (cast to int).
            temp: Sampling temperature forwarded to the model.

        Returns:
            The updated state twice (for the Gradio state and the chatbot).
        """
        input_prompt = 'Q: ' + input_text + '\nA:'
        self.chat_history += input_prompt

        # If an image was uploaded, prepend it to the model inputs.
        if self.input_image is not None:
            model_inputs = [self.input_image, self.chat_history]
        else:
            model_inputs = [self.chat_history]

        # The gr.Number widgets are created with precision=1 and therefore
        # deliver floats; the model needs integer counts.
        model_outputs = self.model.generate_for_images_and_texts(
            model_inputs,
            max_num_rets=int(num_ims),
            num_words=int(num_words),
            ret_scale_factor=ret_scale_factor,
            temperature=temp,
        )

        response = ''
        text_outputs = []
        for output in model_outputs:
            if isinstance(output, str):
                text_outputs.append(output)
                response += output
            elif isinstance(output, list):
                for image in output:
                    filename = self.save_image_to_local(image)
                    response += f'<img src="/file={filename}">'
            elif isinstance(output, Image.Image):
                filename = self.save_image_to_local(output)
                response += f'<img src="/file={filename}">'

        # Only the generated text (not the images) is kept in the transcript.
        self.chat_history += ' '.join(text_outputs)
        if self.chat_history[-1] != '\n':
            self.chat_history += '\n'
        # The pending image is used for a single prompt only.
        self.input_image = None

        state.append((input_text, response))
        return state, state

    def launch(self):
        """Build the Gradio UI, wire up the callbacks and start the server."""
        with gr.Blocks(css="#fromage-space {height:600px; overflow-y:auto;}") as demo:
            chatbot = gr.Chatbot(elem_id="fromage-space")
            gr_state = gr.State([])

            with gr.Row():
                with gr.Column(scale=0.85):
                    text_input = gr.Textbox(
                        show_label=False,
                        placeholder="Upload an image [optional]. Then enter a text prompt, and press enter!",
                    ).style(container=False)
                with gr.Column(scale=0.15, min_width=0):
                    image_btn = gr.UploadButton("Image", file_types=["image"])

            with gr.Row():
                with gr.Column(scale=0.20, min_width=0):
                    clear_btn = gr.Button("Clear")
                ret_scale_factor = gr.Slider(
                    minimum=0.0, maximum=3.0, value=1.0, step=0.1, interactive=True,
                    label="Multiplier for returning images (higher means more frequent)")
                max_ret_images = gr.Number(
                    minimum=0, maximum=3, value=1, precision=1, interactive=True,
                    label="Max images to return")
                gr_max_len = gr.Number(value=32, precision=1, label="Max # of words returned", interactive=True)
                gr_temperature = gr.Number(value=0.0, label="Temperature", interactive=True)

            text_input.submit(
                self.generate_for_prompt,
                [text_input, gr_state, ret_scale_factor, max_ret_images, gr_max_len, gr_temperature],
                [gr_state, chatbot])
            image_btn.upload(self.upload_image, [gr_state, image_btn], [gr_state, chatbot])
            clear_btn.click(self.reset, [], [gr_state, chatbot])

        demo.launch(share=False, server_name="0.0.0.0")
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
# Script entry: constructing the bot downloads and loads the model; launch()
# then builds the UI and starts the Gradio server.
chatbot = FromageChatBot()
chatbot.launch()
|
fromage/__init__.py
ADDED
|
File without changes
|
fromage/models.py
ADDED
|
@@ -0,0 +1,658 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, List, Optional, Tuple, Union
|
| 2 |
+
from collections import namedtuple
|
| 3 |
+
import json
|
| 4 |
+
import glob
|
| 5 |
+
import math
|
| 6 |
+
import numpy as np
|
| 7 |
+
import os
|
| 8 |
+
import torch
|
| 9 |
+
from torch import Tensor
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
from functools import partial
|
| 14 |
+
import pickle as pkl
|
| 15 |
+
from PIL import Image, UnidentifiedImageError
|
| 16 |
+
|
| 17 |
+
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
| 18 |
+
from transformers import OPTForCausalLM, GPT2Tokenizer
|
| 19 |
+
from transformers import CLIPVisionModel, CLIPVisionConfig
|
| 20 |
+
|
| 21 |
+
from fromage import utils
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class FrozenArgs:
    """Default configuration values read by FromageModel.

    The values are plain class attributes, so ``FrozenArgs()`` yields an object
    exposing these defaults; callers may override them per instance.
    """
    # Whether the language model / visual model parameters are frozen.
    freeze_lm: bool = True
    freeze_vm: bool = True
    # Pretrained checkpoints passed to ``from_pretrained``.
    opt_version: str = 'facebook/opt-6.7b'
    visual_encoder: str = 'openai/clip-vit-large-patch14'
    # Number of LM input embeddings produced per image in captioning mode.
    n_visual_tokens: int = 1
    # Dropout probabilities applied to image and text embeddings.
    image_embed_dropout_prob: float = 0.0
    # Read by FromageModel.__init__ when building the text projection layers;
    # it was missing here, so default-constructed args raised AttributeError.
    text_embed_dropout_prob: float = 0.0
    task: str = 'captioning'
    # Output size of the text/visual projection layers; None disables projection sizing.
    shared_emb_dim: Optional[int] = 256
    # Indices into the LM's hidden states used for text embeddings.
    # NOTE: this list is shared by all instances that do not override it.
    text_emb_layers: List[int] = [-1]
    # Token id of the retrieval token.
    retrieval_token_idx: int = 0
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class FromageModel(nn.Module):
    """Language model plus visual encoder with linear layers bridging the two.

    Images are encoded by the visual model and linearly mapped either into the
    LM's input-embedding space ('captioning' mode) or into the embedding space
    that is compared with text embeddings ('retrieval' mode).
    """

    def __init__(self, tokenizer, args: 'Optional[FrozenArgs]' = None):
        """Build the LM, the visual encoder and the projection layers.

        Args:
            tokenizer: Tokenizer for the LM; its length sizes the LM embeddings.
            args: Configuration object; a fresh ``FrozenArgs()`` is used when
                omitted or None.

        Raises:
            ValueError: If ``args.text_emb_layers`` has duplicates, requests a
                layer the LM does not have, or mixes layers of different width
                without ``shared_emb_dim``.
            NotImplementedError: For unsupported LM or visual encoder names.
        """
        super().__init__()
        # Resolve the default per instance instead of sharing one object that
        # was created at import time; this also tolerates an explicit None.
        if args is None:
            args = FrozenArgs()

        self.tokenizer = tokenizer
        self.feature_extractor = utils.get_feature_extractor_for_model(args.visual_encoder, train=False)
        self.image_token = self.tokenizer.cls_token_id
        # The original check compared a list with a set and could never fail.
        if len(args.text_emb_layers) != len(set(args.text_emb_layers)):
            raise ValueError('text_emb_layers not unique')
        self.args = args

        opt_version = args.opt_version
        visual_encoder = args.visual_encoder
        n_visual_tokens = args.n_visual_tokens
        print(f"Using {opt_version} for the language model.")
        print(f"Using {visual_encoder} for the visual model with {n_visual_tokens} visual tokens.")

        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

        if 'facebook/opt' in opt_version:
            self.lm = OPTForCausalLM.from_pretrained(opt_version)
        else:
            raise NotImplementedError

        self.opt_version = opt_version

        if self.args.freeze_lm:
            self.lm.eval()
            print("Freezing the LM.")
            for param in self.lm.parameters():
                param.requires_grad = False
        else:
            self.lm.train()

        self.retrieval_token_idx = args.retrieval_token_idx
        print(f'Initializing embedding for the retrieval token [RET] (id = {self.retrieval_token_idx}).')
        self.lm.resize_token_embeddings(len(tokenizer))

        self.input_embeddings = self.lm.get_input_embeddings()

        print("Restoring pretrained weights for the visual model.")
        if 'clip' in visual_encoder:
            self.visual_model = CLIPVisionModel.from_pretrained(visual_encoder)
        else:
            self.visual_model = AutoModel.from_pretrained(visual_encoder)

        if 'clip' in visual_encoder:
            hidden_size = self.visual_model.config.hidden_size
        else:
            raise NotImplementedError

        if self.args.freeze_vm:
            print("Freezing the VM.")
            self.visual_model.eval()
            for param in self.visual_model.parameters():
                param.requires_grad = False
        else:
            self.visual_model.train()

        self.visual_model_name = visual_encoder

        lm_config = self.lm.config
        num_layers = lm_config.num_hidden_layers
        text_emb_layers = self.args.text_emb_layers
        embedding_dim = self.input_embeddings.embedding_dim * self.args.n_visual_tokens
        self.text_hidden_fcs = nn.ModuleList([])

        if self.args.shared_emb_dim is not None:
            out_dim = self.args.shared_emb_dim
        elif len(text_emb_layers) == 1:
            if (text_emb_layers[0] in [-1, num_layers]) and ('bert' not in opt_version):
                out_dim = lm_config.word_embed_proj_dim
            else:
                out_dim = lm_config.hidden_size
        else:
            # Parenthesised so the width mismatch is required for the error;
            # previously `and` bound tighter than `or`, so merely listing -1
            # raised even when both widths were equal.
            uses_last_layer = (-1 in text_emb_layers) or (num_layers in text_emb_layers)
            if uses_last_layer and (lm_config.word_embed_proj_dim != lm_config.hidden_size):
                raise ValueError('No projection dim specified but model uses last output layer and an intermediate one (which have different dims).')
            out_dim = lm_config.hidden_size

        # Configuration objects without this field fall back to no dropout.
        text_dropout_prob = getattr(self.args, 'text_embed_dropout_prob', 0.0)
        for layer_idx in text_emb_layers:
            if (layer_idx == -1 or layer_idx == num_layers) and ('bert' not in opt_version):
                in_dim = lm_config.word_embed_proj_dim
            elif layer_idx < num_layers:
                in_dim = lm_config.hidden_size
            else:
                raise ValueError(f'Embedding of layer {layer_idx} was requested but model only has {num_layers} layers.')
            self.text_hidden_fcs.append(
                nn.Sequential(nn.Linear(in_dim, out_dim), nn.Dropout(text_dropout_prob)))

        self.visual_embeddings = nn.Linear(hidden_size, embedding_dim)
        self.visual_fc = nn.Linear(hidden_size, out_dim)

        self.image_dropout = nn.Dropout(self.args.image_embed_dropout_prob)

    def get_visual_embs(self, pixel_values: torch.FloatTensor, mode: str = 'captioning'):
        """Encode images and project them for the requested mode.

        Args:
            pixel_values: Image batch passed to the visual model.
            mode: 'captioning' (output reshaped to n_visual_tokens tokens per
                image) or 'retrieval' (one token per image).

        Returns:
            Tensor of shape (batch, tokens, dim) after image dropout.

        Raises:
            ValueError: If ``mode`` is not one of the two supported values.
        """
        if mode not in ['captioning', 'retrieval']:
            raise ValueError(f'mode should be one of ["captioning", "retrieval"], got {mode} instead.')

        # Extract visual embeddings from the vision encoder.
        if 'clip' in self.visual_model_name:
            outputs = self.visual_model(pixel_values)
            encoder_outputs = outputs.pooler_output
        else:
            raise NotImplementedError

        # Use the correct fc based on function argument.
        if mode == 'captioning':
            visual_embs = self.visual_embeddings(encoder_outputs)
            visual_embs = torch.reshape(visual_embs, (visual_embs.shape[0], self.args.n_visual_tokens, -1))
        else:  # 'retrieval'; other values were rejected above.
            visual_embs = self.visual_fc(encoder_outputs)
            visual_embs = torch.reshape(visual_embs, (visual_embs.shape[0], 1, -1))

        return self.image_dropout(visual_embs)

    def train(self, mode=True):
        """Set train/eval mode while keeping frozen sub-models in eval mode.

        Returns:
            ``self``, matching ``nn.Module.train`` so that ``model.train()`` and
            ``model.eval()`` can be chained or re-assigned.
        """
        super(FromageModel, self).train(mode=mode)
        # Overwrite train() to ensure Frozen models remain frozen.
        if self.args.freeze_lm:
            self.lm.eval()
        if self.args.freeze_vm:
            self.visual_model.eval()
        return self

    @staticmethod
    def _mask_labels_from_first(full_labels, stop_token_ids):
        """Set each row of ``full_labels`` to -100 from its first stop token on (in place).

        Returns a list with, per row, the index of the first stop token, or the
        row length when the row contains none.
        """
        cut_positions = []
        for label in full_labels:
            for k, token in enumerate(label):
                if token in stop_token_ids:
                    label[k:] = -100
                    cut_positions.append(k)
                    break
                if k == len(label) - 1:  # No stop token found.
                    cut_positions.append(k + 1)
        return cut_positions

    def _sum_text_hidden_states(self, output):
        """Sum the selected LM hidden states, projected when shared_emb_dim is set."""
        if self.args.shared_emb_dim is not None:
            states = [fc_layer(output.hidden_states[idx])
                      for idx, fc_layer in zip(self.args.text_emb_layers, self.text_hidden_fcs)]
        else:
            states = [output.hidden_states[idx] for idx in self.args.text_emb_layers]
        # Add hidden states together.
        return torch.stack(states, dim=-1).sum(dim=-1)

    def forward(
        self,
        pixel_values: torch.FloatTensor,
        labels: torch.LongTensor,
        caption_len: torch.LongTensor,
        mode: str = 'captioning',
        concat_captions: bool = False,
        input_prefix: Optional[str] = None,
        inference: bool = False,
    ):
        """Run the LM on image/text inputs for captioning or retrieval.

        Args:
            pixel_values: Image batch for the visual encoder.
            labels: Token ids of the captions, (N, T).
            caption_len: Per-example caption length; only read in 'retrieval'
                mode, where ``caption_len - 1`` indexes the embedding token.
            mode: 'captioning' or 'retrieval'.
            concat_captions: In 'captioning' mode, join consecutive pairs of
                examples into one sequence (batch size must be even).
            input_prefix: Optional text inserted before the caption tokens.
            inference: Accepted for interface compatibility; not read here.

        Returns:
            Tuple ``(output, full_labels, last_embedding, last_output_logit,
            visual_embs)``. ``last_embedding`` and ``last_output_logit`` are None
            in 'captioning' mode; in 'retrieval' mode ``visual_embs`` is
            normalised and multiplied by the logit scale.

        Raises:
            ValueError: For an unknown ``mode``.
            NotImplementedError: For 'retrieval' mode with ``concat_captions``.
        """
        visual_embs = self.get_visual_embs(pixel_values, mode)

        batch_size, vis_seq_len, _ = visual_embs.shape  # vis_seq_len = n_visual_tokens
        if labels is not None:
            assert labels.shape[0] == batch_size, (visual_embs.shape, labels.shape)

        input_embs = self.input_embeddings(labels)  # (N, T, D)

        prompt_embs = None
        if input_prefix is not None:
            prompt_ids = self.tokenizer(input_prefix, add_special_tokens=False, return_tensors="pt").input_ids
            prompt_ids = prompt_ids.to(visual_embs.device)
            prompt_embs = self.input_embeddings(prompt_ids)
            prompt_embs = prompt_embs.repeat(batch_size, 1, 1)
            assert prompt_embs.shape[0] == batch_size, prompt_embs.shape
            assert prompt_embs.shape[2] == input_embs.shape[2], prompt_embs.shape
            assert len(prompt_embs.shape) == 3, prompt_embs.shape

        last_embedding = None
        last_output_logit = None

        if mode == 'captioning':
            # The conditioning prefix is the visual tokens, optionally followed
            # by the prompt tokens.
            if prompt_embs is None:
                prefix_embs = visual_embs
            else:
                prefix_embs = torch.cat([visual_embs, prompt_embs], dim=1)
            input_embs = torch.cat([prefix_embs, input_embs], dim=1)

            # Mask out embedding tokens in the labels.
            prefix_labels = torch.zeros(prefix_embs.shape[:2], dtype=torch.int64).to(visual_embs.device) - 100
            full_labels = torch.cat([prefix_labels, labels], dim=1)

            # Mask out padding, and the retrieval token if it exists.
            pad_idx = self._mask_labels_from_first(
                full_labels, [self.tokenizer.pad_token_id, self.retrieval_token_idx])
            assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)

            bs, seq_len, embs_dim = input_embs.shape
            if concat_captions:
                assert len(input_embs.shape) == 3, input_embs
                assert len(full_labels.shape) == 2, full_labels
                assert batch_size % 2 == 0
                all_concat_input_embs = []
                all_concat_labels = []

                # Rearrange embeddings and labels (and their padding) to concatenate captions:
                # content of both examples first, then the padding of both.
                for first_idx in range(0, batch_size, 2):
                    second_idx = first_idx + 1
                    first_cut = pad_idx[first_idx]
                    second_cut = pad_idx[second_idx]

                    first_labels_padding = full_labels[first_idx, first_cut:]
                    second_labels_padding = full_labels[second_idx, second_cut:]
                    assert torch.all(first_labels_padding == -100), first_labels_padding
                    assert torch.all(second_labels_padding == -100), second_labels_padding

                    concat_input_embs = torch.cat([
                        input_embs[first_idx, :first_cut, :],
                        input_embs[second_idx, :second_cut, :],
                        input_embs[first_idx, first_cut:, :],
                        input_embs[second_idx, second_cut:, :],
                    ], dim=0)  # (T*2, D)
                    concat_labels = torch.cat([
                        full_labels[first_idx, :first_cut],
                        full_labels[second_idx, :second_cut],
                        first_labels_padding,
                        second_labels_padding,
                    ], dim=0)  # (T*2,)
                    all_concat_input_embs.append(concat_input_embs)
                    all_concat_labels.append(concat_labels)

                input_embs = torch.stack(all_concat_input_embs, dim=0)  # (N/2, T*2, D)
                full_labels = torch.stack(all_concat_labels, dim=0)  # (N/2, T*2)
                assert input_embs.shape == (bs // 2, seq_len * 2, embs_dim), input_embs.shape
                assert full_labels.shape == (bs // 2, seq_len * 2), full_labels.shape

            output = self.lm(inputs_embeds=input_embs,
                             labels=full_labels,
                             output_hidden_states=True)
        elif mode == 'retrieval':
            if concat_captions:
                # Inputs are never concatenated in this mode, and the previous
                # code path referenced an undefined variable. Fail before the
                # (expensive) LM call with a clear message instead.
                raise NotImplementedError('concat_captions is not supported in retrieval mode.')

            last_embedding_idx = caption_len - 1  # -1 to retrieve the token before the eos token
            full_labels = torch.clone(labels)
            if prompt_embs is not None:
                print(f'Adding prefix "{input_prefix}" to retrieval.')
                # Add prompt embeddings.
                input_embs = torch.cat([prompt_embs, input_embs], dim=1)
                last_embedding_idx = last_embedding_idx + prompt_embs.shape[1]
                full_labels = torch.cat([
                    torch.zeros(prompt_embs.shape[:2], dtype=torch.int64).to(labels.device) - 100,
                    full_labels
                ], dim=1)

            pad_idx = self._mask_labels_from_first(full_labels, [self.tokenizer.pad_token_id])
            assert len(pad_idx) == batch_size, (len(pad_idx), batch_size)

            output = self.lm(inputs_embeds=input_embs,
                             labels=full_labels,
                             output_hidden_states=True)

            last_hidden_state = self._sum_text_hidden_states(output)
            last_embedding = torch.stack(
                [last_hidden_state[i, last_embedding_idx[i], :] for i in range(batch_size)], dim=0)  # (N, D)
            last_output_logit = torch.stack(
                [output.logits[i, last_embedding_idx[i] - 1, :] for i in range(batch_size)], dim=0)  # (N, vocab)

            # Normalise both sides; the visual side carries the logit scale.
            assert visual_embs.shape[1] == 1, visual_embs.shape
            visual_embs = visual_embs[:, 0, :]
            visual_embs = visual_embs / visual_embs.norm(dim=1, keepdim=True)
            last_embedding = last_embedding / last_embedding.norm(dim=1, keepdim=True)

            # cosine similarity as logits
            logit_scale = self.logit_scale.exp()
            visual_embs = logit_scale * visual_embs
        else:
            raise NotImplementedError

        return output, full_labels, last_embedding, last_output_logit, visual_embs

    def generate(self, embeddings: torch.FloatTensor, max_len: int = 32,
                 temperature: float = 0.0, top_p: float = 1.0, min_word_tokens: int = 0,
                 ret_scale_factor: float = 1.0, filter_value: float = -float('Inf')):
        """Autoregressively decode tokens conditioned on ``embeddings``.

        Greedy decoding is used when ``temperature`` is 0; otherwise tokens are
        sampled, optionally with top-p filtering.

        Args:
            embeddings: Input condition that the model uses for autoregressive
                generation; must be 3-dimensional (batch first).
            max_len: Maximum number of tokens to generate.
            temperature: Used to modulate logit distribution.
            top_p: If set to < 1, the smallest set of tokens with highest
                probabilities that add up to top_p or higher are kept.
            min_word_tokens: Number of initial steps during which the [RET]
                token cannot be generated.
            ret_scale_factor: Factor the [RET] token logit is multiplied by.
            filter_value: Value to assign to tokens that should never be generated.

        Returns:
            Tuple ``(out, output_embeddings, output_logits)``: ``out`` holds the
            generated token ids (N, steps), or None if no step ran; the two
            lists hold one entry per step -- the normalised summed hidden states
            of the whole sequence, and the last-position logits.

        Raises:
            ValueError: If ``top_p`` is set while ``temperature`` is 0.
        """
        self.lm.eval()

        with torch.no_grad():  # no tracking history
            _batch_size, _seq_len, _ = embeddings.shape  # enforces a 3-D input
            out = None
            past_key_values = None
            output_embeddings = []
            output_logits = []

            for i in range(max_len):
                if 'opt' in self.opt_version:
                    # The full, growing embedding sequence is re-fed every step.
                    output = self.lm(inputs_embeds=embeddings, use_cache=False, output_hidden_states=True)
                elif i == 0:
                    output = self.lm(inputs_embeds=embeddings, use_cache=True, past_key_values=None, output_hidden_states=True)
                else:
                    output = self.lm(input_ids=out[:, -1:], use_cache=True, past_key_values=past_key_values, output_hidden_states=True)

                # Collect and sum the hidden states.
                last_hidden_state = self._sum_text_hidden_states(output)
                last_embedding = last_hidden_state / last_hidden_state.norm(dim=-1, keepdim=True)
                output_embeddings.append(last_embedding)

                logits = output.logits[:, -1, :]  # (N, vocab_size)
                if top_p == 1.0:
                    logits = logits.cpu()
                # NOTE: the [RET] adjustment below writes into this same tensor,
                # so the stored logits include it.
                output_logits.append(logits)

                if self.retrieval_token_idx != -1 and self.retrieval_token_idx is not None:
                    if i < min_word_tokens:
                        # Eliminate probability of generating [RET] if this is earlier than min_word_tokens.
                        logits[:, self.retrieval_token_idx] = filter_value
                    else:
                        # Multiply by scaling factor.
                        logits[:, self.retrieval_token_idx] = logits[:, self.retrieval_token_idx] * ret_scale_factor

                past_key_values = output.past_key_values

                if temperature == 0.0:
                    if top_p != 1.0:
                        raise ValueError('top_p cannot be set if temperature is 0 (greedy decoding).')
                    next_token = torch.argmax(logits, keepdim=True, dim=-1)  # (N, 1)
                else:
                    logits = logits / temperature

                    # Apply top-p filtering.
                    if top_p < 1.0:
                        assert top_p > 0, f'top_p should be above 0, got {top_p} instead.'
                        sorted_logits, sorted_indices = torch.sort(logits, descending=True)  # (N, V) and (N, V)
                        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)  # (N, V)

                        # Remove tokens with cumulative probability above the threshold
                        sorted_indices_to_remove = cumulative_probs > top_p
                        # Shift the indices to the right to keep also the first token above the threshold
                        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                        sorted_indices_to_remove[..., 0] = 0

                        for j in range(sorted_indices.shape[0]):
                            indices_to_remove = sorted_indices[j, sorted_indices_to_remove[j, :]]
                            logits[j, indices_to_remove] = filter_value

                    # softmax yields the same sampling distribution as exp()
                    # but does not overflow for large logits / low temperature.
                    token_weights = F.softmax(logits, dim=-1)  # (N, vocab_size)
                    next_token = torch.multinomial(token_weights, 1)  # (N, 1)

                next_token = next_token.long().to(embeddings.device)
                if out is not None:
                    out = torch.cat([out, next_token], dim=-1)
                else:
                    out = next_token

                if 'opt' in self.opt_version:
                    next_embedding = self.input_embeddings(next_token)
                    embeddings = torch.cat([embeddings, next_embedding], dim=1)
                # NOTE(review): because this is an `elif`, the EOS check never
                # runs for OPT models, which always decode max_len tokens --
                # confirm that callers rely on truncating afterwards.
                elif (self.tokenizer.eos_token_id and (next_token == self.tokenizer.eos_token_id).all()):
                    # End of generation.
                    break

        return out, output_embeddings, output_logits
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
class Fromage(nn.Module):
    """Inference wrapper that pairs a ``FromageModel`` with an image retrieval index.

    ``path_array`` and ``emb_matrix`` are only read by
    ``generate_for_images_and_texts``: row ``i`` of ``emb_matrix`` is scored
    against [RET] embeddings and ``path_array[i]`` is the location handed to
    ``utils.get_image_from_url`` for the winning rows.
    """

    def __init__(self, tokenizer, model_args: Optional[FrozenArgs] = None,
                 path_array: Optional[List[str]] = None, emb_matrix: Optional[Tensor] = None):
        super().__init__()
        self.model = FromageModel(tokenizer, model_args)
        self.path_array = path_array
        self.emb_matrix = emb_matrix

    def __call__(self, images: Tensor, tgt_tokens: Optional[Tensor] = None, caption_len: Optional[Tensor] = None,
                 generate: bool = False, num_words: int = 32, temperature: float = 1.0, top_p: float = 1.0,
                 ret_scale_factor: float = 1.0, min_word_tokens: int = 0,
                 mode: str = 'captioning', concat_captions: bool = False,
                 input_prefix: Optional[str] = None, inference: bool = False) -> Tensor:
        """Dispatch to ``self.model.generate`` or to the wrapped model's forward pass.

        With ``generate=True`` only ``images``, ``num_words``, ``temperature``,
        ``top_p``, ``min_word_tokens`` and ``ret_scale_factor`` are used and the
        result of ``self.model.generate`` is returned as-is.  Otherwise the
        remaining arguments are forwarded to ``self.model(...)`` by keyword.
        """
        if generate:
            return self.model.generate(images, num_words, temperature=temperature, top_p=top_p,
                                       min_word_tokens=min_word_tokens, ret_scale_factor=ret_scale_factor)

        output = self.model(
            pixel_values=images,
            labels=tgt_tokens,
            caption_len=caption_len,
            mode=mode,
            concat_captions=concat_captions,
            input_prefix=input_prefix,
            inference=inference)
        return output

    def generate_for_images_and_texts(
            self, prompts: List, num_words: int = 0, ret_scale_factor: float = 1.0, top_p: float = 1.0,
            temperature: float = 0.0, max_num_rets: int = 1):
        """Run the model on an interleaved image/text prompt and return interleaved outputs.

        Args:
            prompts: List of interleaved PIL.Image.Image and strings representing input to the model.
            num_words: Maximum number of words to generate for. If num_words = 0, the model will run
                its forward pass and return the outputs.
            ret_scale_factor: Proportion to scale [RET] token logits by. A higher value may increase
                the probability of the model generating [RET] outputs.
            top_p: If set to < 1, the smallest set of tokens with highest probabilities that add up
                to top_p or higher are kept for generation.
            temperature: Used to modulate logit distribution.
            max_num_rets: Caps both the number of [RET] tokens that are processed and the number of
                images collected for each of them.

        Returns:
            return_outputs: List consisting of either str or List[PIL.Image.Image] objects,
                representing image-text interleaved model outputs.

        Raises:
            ValueError: If a prompt element is neither an image nor a string, or if
                ``num_words`` is negative.
        """
        device = self.model.logit_scale.device
        input_embs = []
        input_ids = []
        add_bos = True

        for p in prompts:
            # isinstance (rather than an exact type match) so that PIL subclasses, such as the
            # file-backed images returned by Image.open, are accepted as images.
            if isinstance(p, Image.Image):
                # Encode as image.
                pixel_values = utils.get_pixel_values_for_model(self.model.feature_extractor, p)
                pixel_values = pixel_values.to(device=device, dtype=self.model.logit_scale.dtype)
                pixel_values = pixel_values[None, ...]

                visual_embs = self.model.get_visual_embs(pixel_values, mode='captioning')  # (1, n_visual_tokens, D)
                input_embs.append(visual_embs)
            elif isinstance(p, str):
                text_ids = self.model.tokenizer(p, add_special_tokens=True, return_tensors="pt").input_ids.to(device)
                if add_bos:
                    # Only add <bos> once.
                    add_bos = False
                else:
                    # Remove <bos> tag.
                    text_ids = text_ids[:, 1:]

                text_embs = self.model.input_embeddings(text_ids)  # (1, T, D)
                input_embs.append(text_embs)
                input_ids.append(text_ids)
            else:
                raise ValueError(f'Input prompts should be either PIL.Image.Image or str types, got {type(p)} instead.')
        input_embs = torch.cat(input_embs, dim=1)

        if num_words == 0:
            # The token ids are only needed on this path; concatenating them lazily keeps
            # image-only prompts usable when generating (num_words > 0).
            generated_ids = torch.cat(input_ids, dim=1)
            outputs = self.model.lm(inputs_embeds=input_embs, use_cache=False, output_hidden_states=True)
            # Map outputs to embeddings, so we can retrieve embeddings from the [RET] tokens.
            projected = [
                fc(outputs.hidden_states[x])
                for x, fc in zip(self.model.args.text_emb_layers, self.model.text_hidden_fcs)
            ]
            embeddings = torch.stack(projected, dim=-1).sum(dim=-1)
            embeddings = embeddings / embeddings.norm(dim=-1, keepdim=True)  # (N, T, 256)
        elif num_words > 0:
            generated_ids, generated_embeddings, _ = self.model.generate(
                input_embs, num_words,
                temperature=temperature, top_p=top_p, ret_scale_factor=ret_scale_factor)
            # Drop the positions that belong to the prompt.
            embeddings = generated_embeddings[-1][:, input_embs.shape[1]:]

            # Truncate to newline.
            newline_token_id = self.model.tokenizer('\n', add_special_tokens=False).input_ids[0]
            trunc_idx = 0
            for j in range(generated_ids.shape[1]):
                if generated_ids[0, j] == newline_token_id:
                    trunc_idx = j
                    break
            # A newline at position 0 leaves the output untruncated.
            if trunc_idx > 0:
                generated_ids = generated_ids[:, :trunc_idx]
                embeddings = embeddings[:, :trunc_idx]
        else:
            raise ValueError(f'num_words should be >= 0, got {num_words} instead.')

        # Save outputs as an interleaved list.
        return_outputs = []
        # Find up to max_num_rets [RET] tokens, and their corresponding scores.
        all_ret_idx = [
            i for i, x in enumerate(generated_ids[0, :] == self.model.retrieval_token_idx) if x
        ][:max_num_rets]
        seen_image_idx = []  # Avoid showing the same image multiple times.

        last_ret_idx = 0
        if not all_ret_idx:
            # No [RET] tokens.
            caption = self.model.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
            return_outputs.append(utils.truncate_caption(caption))
        else:
            for ret_idx in all_ret_idx:
                ret_emb = embeddings[:, ret_idx, :]
                scores = self.emb_matrix @ ret_emb.T

                # Downweight seen images.
                for seen_idx in seen_image_idx:
                    scores[seen_idx, :] -= 1000

                # Consider the 3 best-scoring candidates for this [RET] token.
                _, top_image_idx = scores.squeeze().topk(3)
                image_outputs = []
                for img_idx in top_image_idx:
                    # Find the first image that does not error out.
                    try:
                        seen_image_idx.append(img_idx)
                        img = utils.get_image_from_url(self.path_array[img_idx])
                        image_outputs.append(img)
                        if len(image_outputs) == max_num_rets:
                            break
                    except UnidentifiedImageError:
                        # Deliberate best-effort: skip candidates that do not decode as images.
                        pass

                caption = self.model.tokenizer.batch_decode(
                    generated_ids[:, last_ret_idx:ret_idx], skip_special_tokens=True)[0]
                last_ret_idx = ret_idx + 1
                return_outputs.append(utils.truncate_caption(caption) + ' [RET]')
                return_outputs.append(image_outputs)

        return return_outputs
|
| 595 |
+
|
| 596 |
+
|
| 597 |
+
def load_fromage(model_dir: str) -> Fromage:
    """Build a ``Fromage`` model for inference from the files in ``model_dir``.

    Expects ``model_args.json``, ``pretrained_ckpt.pth.tar`` and at least one
    ``cc3m_embeddings*.pkl`` file in ``model_dir``.  The model is put in eval
    mode, cast to bfloat16 and moved to CUDA.

    Args:
        model_dir: Directory holding the files listed above.

    Returns:
        The initialised model, with ``emb_matrix`` L2-normalised per row and
        multiplied by ``exp(logit_scale)``.

    Raises:
        ValueError: If any of the required files is missing.
    """
    model_args_path = os.path.join(model_dir, 'model_args.json')
    model_ckpt_path = os.path.join(model_dir, 'pretrained_ckpt.pth.tar')
    # Sorted so that the row order of the retrieval index does not depend on
    # the order in which the filesystem lists the shards.
    embs_paths = sorted(glob.glob(os.path.join(model_dir, 'cc3m_embeddings*.pkl')))

    if not os.path.exists(model_args_path):
        raise ValueError(f'model_args.json does not exist in {model_dir}.')
    if not os.path.exists(model_ckpt_path):
        raise ValueError(f'pretrained_ckpt.pth.tar does not exist in {model_dir}.')
    if len(embs_paths) == 0:
        raise ValueError(f'cc3m_embeddings_*.pkl files do not exist in {model_dir}.')

    # Load embeddings.
    # Construct embedding matrix for nearest neighbor lookup.
    path_array = []
    emb_matrix = []

    # These were precomputed for all CC3M images with `model.get_visual_embs(image, mode='retrieval')`.
    for p in embs_paths:
        with open(p, 'rb') as wf:
            # NOTE(security): unpickling executes code embedded in the file;
            # only point model_dir at trusted embedding files.
            train_embs_data = pkl.load(wf)
        path_array.extend(train_embs_data['paths'])
        emb_matrix.append(train_embs_data['embeddings'])
    emb_matrix = np.concatenate(emb_matrix, axis=0)

    # Number of paths should be equal to number of embeddings.
    assert len(path_array) == emb_matrix.shape[0], (len(path_array), emb_matrix.shape[0])

    with open(model_args_path, 'r') as f:
        model_kwargs = json.load(f)

    # Initialize tokenizer.
    tokenizer = GPT2Tokenizer.from_pretrained(model_kwargs['opt_version'])
    tokenizer.pad_token = tokenizer.eos_token
    # Add special tokens to the model to enable [RET].
    tokenizer.add_special_tokens({"cls_token": "<|image|>"})
    tokenizer.add_tokens('[RET]')
    ret_token_idx = tokenizer('[RET]', add_special_tokens=False).input_ids
    assert len(ret_token_idx) == 1, ret_token_idx
    model_kwargs['retrieval_token_idx'] = ret_token_idx[0]
    args = namedtuple('args', model_kwargs)(**model_kwargs)

    # Initialize model for inference.
    model = Fromage(tokenizer, args, path_array=path_array, emb_matrix=emb_matrix)
    model = model.eval()
    model = model.bfloat16()
    model = model.cuda()

    # Load pretrained linear mappings and [RET] embeddings.
    # map_location='cpu' keeps deserialisation independent of the device the
    # checkpoint was saved from; load_state_dict / copy_ below move the values
    # onto the model's own parameters.
    checkpoint = torch.load(model_ckpt_path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'], strict=False)
    with torch.no_grad():
        ret_weight = checkpoint['state_dict']['ret_input_embeddings.weight'].cpu().detach()
        model.model.input_embeddings.weight[model.model.retrieval_token_idx, :].copy_(ret_weight)

    logit_scale = model.model.logit_scale.exp()
    emb_matrix = torch.tensor(emb_matrix, dtype=logit_scale.dtype).to(logit_scale.device)
    emb_matrix = emb_matrix / emb_matrix.norm(dim=1, keepdim=True)
    emb_matrix = logit_scale * emb_matrix
    model.emb_matrix = emb_matrix

    return model
|
| 658 |
+
|
fromage/utils.py
ADDED
|
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
import subprocess
|
| 3 |
+
import sys
|
| 4 |
+
import shutil
|
| 5 |
+
import torch
|
| 6 |
+
import torch.distributed as dist
|
| 7 |
+
from torchvision.transforms import functional as F
|
| 8 |
+
from torchvision import transforms as T
|
| 9 |
+
from transformers import AutoFeatureExtractor
|
| 10 |
+
from PIL import Image, ImageDraw, ImageFont, ImageOps
|
| 11 |
+
import requests
|
| 12 |
+
from io import BytesIO
|
| 13 |
+
|
| 14 |
+
import random
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def dump_git_status(out_file=sys.stdout, exclude_file_patterns=['*.ipynb', '*.th', '*.sh', '*.txt', '*.json']):
    """Write the current commit hash, a blank line and the working-tree diff to ``out_file``.

    Args:
        out_file: Passed to ``subprocess.call`` as ``stdout`` for each command.
        exclude_file_patterns: Accepted but not read by this function; the
            exclusion string appended to the diff command is always empty.
    """
    exclude_string = ''
    commands = (
        'git rev-parse HEAD',
        'echo',
        'git --no-pager diff -- . {}'.format(exclude_string),
    )
    for command in commands:
        subprocess.call(command, shell=True, stdout=out_file)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_image_from_url(url: str):
    """Fetch ``url`` with ``requests.get`` and return its body as a 224x224 RGB PIL image."""
    payload = BytesIO(requests.get(url).content)
    return Image.open(payload).resize((224, 224)).convert('RGB')
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def truncate_caption(caption: str) -> str:
    """Truncate a caption just after its first newline, or else its first period.

    A newline takes precedence over a period even when the period comes first.
    The terminating character is kept.  A caption containing neither is
    returned unchanged (rather than being cut down to an empty string).
    """
    trunc_index = caption.find('\n') + 1
    if trunc_index <= 0:
        trunc_index = caption.find('.') + 1
    if trunc_index <= 0:
        # str.find returned -1 both times: there is nothing to truncate at.
        return caption
    return caption[:trunc_index]
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def pad_to_size(x, size=256):
    """Expand ``x`` with ``ImageOps.expand`` so each side gains ``size`` minus its current length.

    The extra pixels are split between the two edges of each axis, with the
    right/bottom edge receiving the odd pixel.
    """
    extra_w = size - x.size[0]
    extra_h = size - x.size[1]
    left, top = extra_w // 2, extra_h // 2
    border = (left, top, extra_w - left, extra_h - top)
    return ImageOps.expand(x, border)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class RandCropResize(object):
    """Randomly crop, randomly resize, then randomly crop an image again.

    Mirroring the augmentations from https://arxiv.org/abs/2102.12092
    """

    def __init__(self, target_size):
        # Side length of the final random crop.
        self.target_size = target_size

    def __call__(self, img):
        """Apply pad -> square random crop -> random resize -> final random crop to ``img``."""
        padded = pad_to_size(img, self.target_size)
        short_side = min(padded.size)
        square = T.RandomCrop(size=short_side)(padded)

        # Resize target is drawn between 9/8 and 12/8 of target_size, capped by the crop size.
        lower = min(short_side, round(9 / 8 * self.target_size))
        upper = min(short_side, round(12 / 8 * self.target_size))
        new_size = random.randint(lower, upper + 1)
        resized = T.Resize(new_size)(square)

        if min(resized.size) < 256:
            resized = T.Resize(256)(resized)
        return T.RandomCrop(size=self.target_size)(resized)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class SquarePad(object):
    """Pads image to square.

    From https://discuss.pytorch.org/t/how-to-resize-and-pad-in-a-torchvision-transforms-compose/71850/9
    """

    def __call__(self, image):
        """Zero-pad ``image`` so both sides equal its longer side, splitting the padding across edges."""
        width, height = image.size
        side = max(image.size)
        p_left = (side - width) // 2
        p_top = (side - height) // 2
        # The right/bottom edges absorb whatever is left after the floor division.
        p_right = side - (width + p_left)
        p_bottom = side - (height + p_top)
        return F.pad(image, (p_left, p_top, p_right, p_bottom), 0, 'constant')
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def create_image_of_text(text: str, width: int = 224, nrows: int = 2, color=(255, 255, 255), font=None) -> torch.Tensor:
    """Render ``text`` on a black background and wrap it into ``nrows`` rows.

    The text is drawn on a single strip that is ``nrows`` rows wide, the strip
    is cut into ``nrows`` equal pieces, and the pieces are stacked vertically.

    Returns:
        cap_img: float32 tensor of shape (3, 12 * nrows + 5, width): each row is
            12 pixels high, and 5 pixels of zero padding are added on the left,
            right and bottom.
    """
    height = 12
    padding = 5
    effective_width = width - 2 * padding

    # Create a black image to draw text on.
    canvas = Image.new('RGB', (effective_width * nrows, height), color = (0, 0, 0))
    ImageDraw.Draw(canvas).text((0, 0), text, color, font=font or ImageFont.load_default())

    strip = F.convert_image_dtype(F.pil_to_tensor(canvas), torch.float32)  # (3, height, W * nrows)
    rows = torch.split(strip, effective_width, dim=-1)  # nrows pieces of shape (3, height, W)
    stacked = torch.cat(rows, dim=1)  # (3, height * nrows, W)

    # Add zero padding (left, right, top, bottom).
    return torch.nn.functional.pad(stacked, [padding, padding, 0, padding])
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def get_feature_extractor_for_model(model_name: str, image_size: int = 224, train: bool = True):
    """Return ``AutoFeatureExtractor.from_pretrained(model_name)`` after printing a notice.

    ``image_size`` and ``train`` are accepted but not read by this function.
    """
    print(f'Using HuggingFace AutoFeatureExtractor for {model_name}.')
    return AutoFeatureExtractor.from_pretrained(model_name)
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def get_pixel_values_for_model(feature_extractor, img):
    """Convert ``img`` to RGB, run ``feature_extractor`` on it and return the first ``pixel_values`` entry."""
    rgb_img = img.convert('RGB')
    encoded = feature_extractor(rgb_img, return_tensors="pt")
    return encoded.pixel_values[0, ...]  # (3, H, W)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def save_checkpoint(state, is_best, filename='checkpoint'):
    """Save ``state`` to ``<filename>.pth.tar``; when ``is_best``, also copy it to ``<filename>_best.pth.tar``."""
    ckpt_path = filename + '.pth.tar'
    torch.save(state, ckpt_path)
    if not is_best:
        return
    shutil.copyfile(ckpt_path, filename + '_best.pth.tar')
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def accuracy(output, target, padding, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k.

    Positions where ``target == padding`` are ignored.

    Args:
        output: Scores whose last dimension is ranked with ``topk``.
        target: Target indices; one fewer dimension than ``output``.
        padding: Target value marking positions to exclude.
        topk: Values of k to report.

    Returns:
        A list with one single-element tensor per k, holding the percentage of
        non-padding positions whose target is among the top-k predictions.
        When every position is padding the percentages are 0 (not NaN).
    """
    with torch.no_grad():
        maxk = max(topk)
        if output.shape[-1] < maxk:
            print(f"[WARNING] Less than {maxk} predictions available. Using {output.shape[-1]} for topk.")
        maxk = min(maxk, output.shape[-1])

        # Take topk along the last dimension.
        _, pred = output.topk(maxk, -1, True, True)  # (N, T, topk)

        mask = (target != padding).type(target.dtype)
        correct = pred.eq(target[..., None].expand_as(pred))
        correct = correct * mask[..., None].expand_as(correct)

        # Clamp so that an all-padding batch divides by 1 instead of 0
        # (which would turn the 0 correct count into NaN).
        scale = 100.0 / mask.sum().clamp(min=1)

        res = []
        for k in topk:
            correct_k = correct[..., :k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(scale))
        return res
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def get_params_count(model, max_name_len: int = 60):
    """Summarise the parameters yielded by ``model.named_parameters()``.

    Returns:
        A tuple ``(params, total_trainable_params, total_nontrainable_params)``
        where ``params`` holds one
        ``(name truncated to max_name_len, element count, shape as str, requires_grad)``
        tuple per parameter, and the totals sum the element counts by
        ``requires_grad``.
    """
    params = []
    total_trainable_params = 0
    total_nontrainable_params = 0
    for name, p in model.named_parameters():
        n_elements = p.numel()
        params.append((name[:max_name_len], n_elements, str(tuple(p.shape)), p.requires_grad))
        if p.requires_grad:
            total_trainable_params += n_elements
        else:
            total_nontrainable_params += n_elements
    return params, total_trainable_params, total_nontrainable_params
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
def get_params_count_str(model, max_name_len: int = 60):
    """Format the result of ``get_params_count`` as a text table.

    The table has a header row, one row per parameter, and two footer rows
    with the trainable and non-trainable totals.
    """
    padding = 70  # Hardcoded depending on desired amount of padding and separators.
    params, total_trainable_params, total_nontrainable_params = get_params_count(model, max_name_len)

    heavy_rule = '=' * (max_name_len + padding) + '\n'
    light_rule = '-' * (max_name_len + padding) + '\n'

    pieces = [
        heavy_rule,
        f'| {"Module":<{max_name_len}} | {"Trainable":<10} | {"Shape":>15} | {"Param Count":>12} |\n',
        light_rule,
    ]
    pieces.extend(
        f'| {name:<{max_name_len}} | {"True" if trainable else "False":<10} | {shape:>15} | {param_count:>12,} |\n'
        for name, param_count, shape, trainable in params
    )
    pieces.append(light_rule)
    pieces.append(f'| {"Total trainable params":<{max_name_len}} | {"":<10} | {"":<15} | {total_trainable_params:>12,} |\n')
    pieces.append(f'| {"Total non-trainable params":<{max_name_len}} | {"":<10} | {"":<15} | {total_nontrainable_params:>12,} |\n')
    pieces.append(heavy_rule)
    return ''.join(pieces)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
class Summary(Enum):
    """Selects which statistic ``AverageMeter.summary`` reports."""

    NONE = 0     # report nothing
    AVERAGE = 1  # report the running average
    SUM = 2      # report the running sum
    COUNT = 3    # report the running count
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
class ProgressMeter(object):
    """Prints a ``[batch/num_batches]`` counter followed by a collection of meters."""

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, the batch counter and ``str(meter)`` for every meter, tab-separated."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([header, *map(str, self.meters)]))

    def display_summary(self):
        """Print `` *`` followed by every meter's ``summary()``, space-separated."""
        print(' '.join([" *"] + [meter.summary() for meter in self.meters]))

    def _get_batch_fmtstr(self, num_batches):
        """Build a template such as ``[{:3d}/100]`` whose slot is as wide as ``num_batches``."""
        width = len(str(num_batches // 1))
        slot = '{:' + str(width) + 'd}'
        return '[' + slot + '/' + slot.format(num_batches) + ']'
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
class AverageMeter(object):
    """Computes and stores the average and current value"""

    def __init__(self, name, fmt=':f', summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt  # format spec suffix applied to val and avg in __str__
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Zero the current value, average, sum and count."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counted ``n`` times, and refresh the average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def all_reduce(self):
        """Sum ``sum`` and ``count`` across processes with ``dist.all_reduce`` and recompute ``avg``."""
        device = "cuda" if torch.cuda.is_available() else "cpu"
        packed = torch.tensor([self.sum, self.count], dtype=torch.float32, device=device)
        dist.all_reduce(packed, dist.ReduceOp.SUM, async_op=False)
        # tolist() yields Python floats, so count is a float from here on.
        self.sum, self.count = packed.tolist()
        self.avg = self.sum / self.count

    def __str__(self):
        template = ''.join(['{name} {val', self.fmt, '} ({avg', self.fmt, '})'])
        return template.format(**self.__dict__)

    def summary(self):
        """Return the statistic selected by ``summary_type`` as text.

        Raises:
            ValueError: If ``summary_type`` is not one of the ``Summary`` members.
        """
        templates = (
            (Summary.NONE, ''),
            (Summary.AVERAGE, '{name} {avg:.3f}'),
            (Summary.SUM, '{name} {sum:.3f}'),
            (Summary.COUNT, '{name} {count:.3f}'),
        )
        for kind, template in templates:
            if self.summary_type is kind:
                return template.format(**self.__dict__)
        raise ValueError('invalid summary type %r' % self.summary_type)
|