import argparse
import dataclasses
import json
from typing import Any, Generator, List

import numpy as np
import torch
from transformers import T5TokenizerFast

from pix2struct.modeling import Pix2StructModel
from pix2struct.processing import extract_patches


def ask_generator(tokenizer, question, max_length=256):
    """Greedy-decoding coroutine for a single question.

    The generator yields the token ids that should be fed to the decoder next
    (the whole prompt on the first step, then one token per step) and expects
    the caller to ``send`` back the logits for the next token.  It finishes
    (raising ``StopIteration`` whose ``value`` is the answer) when the end
    token is predicted or the length limit is hit.

    Args:
        tokenizer: Object providing ``convert_tokens_to_ids``, ``encode`` and
            ``decode`` (``main`` passes a ``T5TokenizerFast``).
        question: Question text to encode into the prompt.
        max_length: Maximum number of generated tokens to keep.

    Returns:
        Via ``StopIteration.value``: the decoded answer, or ``''`` when the
        end token was not produced within ``max_length`` tokens.
    """
    # NOTE(review): the token literals below are empty strings in this copy of
    # the file.  They look like angle-bracketed special tokens that were lost
    # in extraction -- confirm against the original source before relying on
    # this function.
    end_token_id = tokenizer.convert_tokens_to_ids([''])[0]
    input_ids = [
        *tokenizer.convert_tokens_to_ids(['']),
        *tokenizer.encode(question, add_special_tokens=False),
        *tokenizer.convert_tokens_to_ids(['▁']),
    ]
    generated_token_ids = []
    too_long = False
    while True:
        logits = yield input_ids
        # Greedy decoding: take the highest-scoring token.
        next_token_id = torch.argmax(logits).item()
        if next_token_id == end_token_id:
            break
        if len(generated_token_ids) >= max_length:
            too_long = True
            break
        generated_token_ids.append(next_token_id)
        # After the prompt, only the newly generated token is fed back.
        input_ids = [next_token_id]
    if too_long:
        return ''
    return tokenizer.decode(generated_token_ids)


@dataclasses.dataclass
class DocumentQuery:
    """One question asked about a document, plus its decoding state."""

    # Caller-defined label; ``main`` stores the question text and prints it.
    meta: Any
    # Decoding coroutine (see ``ask_generator``): yields token ids, receives logits.
    generator: Generator
    # Stays ``None`` until the generator finishes; then holds its return value.
    output: Any = None


@dataclasses.dataclass
class DocumentQueries:
    """A document's patches together with every query asked about it."""

    # Caller-defined label for the document; ``main`` prints it.
    meta: Any
    # Patches of the document; ``generate`` reads ``size(0)`` as the sequence
    # length and concatenates documents along dim 0.
    patches: torch.Tensor
    queries: List[DocumentQuery]


def debug(*x):
    """No-op debug hook; swap the body for ``print(*x)`` to trace generation."""
    pass


def _batch_decoder_inputs(input_ids, encoder_cache_batch_idx, encoder_cache_seqlens, device):
    """Right-pad the per-query token lists and move one step's inputs to ``device``.

    Returns a tuple ``(input_ids, input_ids_lens, encoder_cache_batch_idx,
    encoder_cache_seqlens)`` where ``input_ids_lens`` is the list of unpadded
    lengths and the other three entries are tensors on ``device``.
    """
    input_ids_lens = [len(ids) for ids in input_ids]
    input_ids_max_seq_len = max(input_ids_lens)
    # Pad with id 0 -- presumably the tokenizer's pad id; TODO confirm.
    padded = [ids + [0] * (input_ids_max_seq_len - len(ids)) for ids in input_ids]
    return (
        torch.tensor(padded, dtype=torch.long).to(device),
        input_ids_lens,
        torch.tensor(encoder_cache_batch_idx, dtype=torch.int32).to(device),
        torch.tensor(encoder_cache_seqlens, dtype=torch.int32).to(device),
    )


def generate(
    model: Pix2StructModel,
    documents: List[DocumentQueries],
    device: torch.device,
    init_cache_size: int = 512,
) -> List[DocumentQueries]:
    """Answer every unanswered query of every document in one batched decode.

    All documents are encoded once into an encoder KV cache.  Each pending
    query (``output is None``) then gets one row in the decoder batch; rows
    are dropped from the decoder cache as soon as their generator finishes.

    Args:
        model: Model exposing ``get_encoder_kv_cache`` and a ``decoder`` with
            ``get_decoder_kv_cache`` and ``predict``.
        documents: Documents with their queries.  Queries whose ``output`` is
            already set are left untouched.
        device: Device on which all tensors are placed.
        init_cache_size: Size passed to ``get_decoder_kv_cache`` when the
            decoder cache is allocated.

    Returns:
        The same ``documents`` list, with ``output`` filled in on every query
        that was pending.
    """
    # Queries still being decoded, in batch-row order.
    pending = []
    for doc_idx, document in enumerate(documents):
        for query in document.queries:
            if query.output is None:
                pending.append((doc_idx, document, query))
    if not pending:
        # Nothing to decode (also covers an empty ``documents`` list); avoids
        # calling ``torch.cat``/``max`` on empty sequences below.
        return documents

    # Encode all documents at once as one flattened sequence plus cumulative
    # sequence boundaries.
    documents_patches = [document.patches for document in documents]
    documents_patches_lens = [patches.size(0) for patches in documents_patches]
    documents_patches = torch.cat(documents_patches, dim=0).to(device)
    documents_patches_cu_seq_lens = torch.tensor(
        [0, *np.cumsum(documents_patches_lens)],
        dtype=torch.int32,
        device=device,
    )
    documents_patches_max_seq_len = max(documents_patches_lens)
    encoder_cache = model.get_encoder_kv_cache(
        flattened_patches=documents_patches,
        flattened_patches_cu_seq_lens=documents_patches_cu_seq_lens,
        flattened_patches_max_seq_len=documents_patches_max_seq_len,
    )

    # One decoder-cache row per *pending* query, so cache rows line up with
    # batch rows even when some queries already carry an output.
    total_queries = len(pending)
    decoder_k_cache, decoder_v_cache = model.decoder.get_decoder_kv_cache(
        device,
        total_queries,
        init_cache_size,
        dtype=torch.bfloat16,
    )
    decoder_cache_seqlens = torch.zeros((total_queries,), dtype=torch.int32, device=device)

    # Prime every generator to obtain its prompt tokens.
    input_ids = []
    encoder_cache_batch_idx = []
    encoder_cache_seqlens = []
    for doc_idx, _, query in pending:
        input_ids.append(next(query.generator))
        encoder_cache_batch_idx.append(doc_idx)
        encoder_cache_seqlens.append(encoder_cache['encoder_cache_seqlens'][doc_idx])
    (
        input_ids,
        input_ids_lens,
        encoder_cache_batch_idx,
        encoder_cache_seqlens,
    ) = _batch_decoder_inputs(input_ids, encoder_cache_batch_idx, encoder_cache_seqlens, device)

    while pending:
        debug('Generating')
        debug('input_ids', input_ids)
        debug('input_ids_lens', input_ids_lens)
        debug('decoder_k_cache', decoder_k_cache[0].size(), decoder_k_cache[0].dtype)
        debug('decoder_v_cache', decoder_v_cache[0].size(), decoder_v_cache[0].dtype)
        debug('decoder_cache_seqlens', decoder_cache_seqlens)
        debug('encoder_k_cache', encoder_cache['encoder_k_cache'][0].size(), encoder_cache['encoder_k_cache'][0].dtype)
        debug('encoder_v_cache', encoder_cache['encoder_v_cache'][0].size(), encoder_cache['encoder_v_cache'][0].dtype)
        debug('encoder_cache_seqlens', encoder_cache_seqlens)
        debug('encoder_cache_batch_idx', encoder_cache_batch_idx)

        logits = model.decoder.predict(
            input_ids=input_ids,
            decoder_k_cache=decoder_k_cache,
            decoder_v_cache=decoder_v_cache,
            decoder_cache_seqlens=decoder_cache_seqlens,
            encoder_k_cache=encoder_cache['encoder_k_cache'],
            encoder_v_cache=encoder_cache['encoder_v_cache'],
            encoder_cache_seqlens=encoder_cache_seqlens,
            encoder_cache_batch_idx=encoder_cache_batch_idx,
        )
        # Each row advanced by its own unpadded token count.
        decoder_cache_seqlens += torch.tensor(input_ids_lens, dtype=torch.int32).to(device)

        input_ids = []
        encoder_cache_batch_idx = []
        encoder_cache_seqlens = []
        remove_cache_batch_idx = []
        still_pending = []
        for batch_idx, (doc_idx, document, query) in enumerate(pending):
            # Logits at the last real (non-padded) position of this row.
            next_token_logits = logits[batch_idx, input_ids_lens[batch_idx] - 1, :]
            try:
                next_ids = query.generator.send(next_token_logits)
            except StopIteration as e:
                debug('Document', document.meta, 'Query', query.meta, 'Result', e.value)
                query.output = e.value
                remove_cache_batch_idx.append(batch_idx)
            else:
                input_ids.append(next_ids)
                encoder_cache_batch_idx.append(doc_idx)
                encoder_cache_seqlens.append(encoder_cache['encoder_cache_seqlens'][doc_idx])
                still_pending.append((doc_idx, document, query))
        pending = still_pending

        if not pending:
            break

        if remove_cache_batch_idx:
            # Drop the finished rows so cache rows keep matching batch rows.
            debug('Removing cache', remove_cache_batch_idx)
            cache_mask = torch.ones((decoder_cache_seqlens.size(0),), dtype=torch.bool, device=device)
            debug('cache_mask', cache_mask.size())
            cache_mask[remove_cache_batch_idx] = False
            decoder_k_cache = [k[cache_mask] for k in decoder_k_cache]
            decoder_v_cache = [v[cache_mask] for v in decoder_v_cache]
            decoder_cache_seqlens = decoder_cache_seqlens[cache_mask]

        (
            input_ids,
            input_ids_lens,
            encoder_cache_batch_idx,
            encoder_cache_seqlens,
        ) = _batch_decoder_inputs(input_ids, encoder_cache_batch_idx, encoder_cache_seqlens, device)

    return documents


def main():
    """Command-line entry point: load model and tokenizer, answer the queries, print results.

    ``--queries`` is a JSON string: a list of objects, each with a
    ``document`` entry (iterated as pages) and a ``questions`` list.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--tokenizer', type=str, required=True)
    parser.add_argument('--queries', type=str, required=True)
    args = parser.parse_args()

    from accelerate import Accelerator

    accelerator = Accelerator()
    model = Pix2StructModel.load(args.model)
    model = accelerator.prepare(model)
    model.eval()
    tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer)

    documents = []
    for query in json.loads(args.queries):
        document_pages = [np.array(page) for page in query['document']]
        document_queries = [
            DocumentQuery(
                meta=question,
                generator=ask_generator(tokenizer, question),
                output=None,
            )
            for question in query['questions']
        ]
        documents.append(DocumentQueries(
            meta=query['document'],
            patches=extract_patches(document_pages),
            queries=document_queries,
        ))

    with torch.inference_mode(), accelerator.autocast():
        # ``generate`` requires the target device; use the one the model was
        # prepared on.
        result = generate(model, documents, accelerator.device)

    for document in result:
        print(f'Document: {document.meta}')
        for query in document.queries:
            print(f'Query: {query.meta}')
            print(f'Answer: {query.output}')
            print()