import argparse
import dataclasses
import json
from typing import Any, Generator

import numpy as np
import torch
from transformers import T5TokenizerFast

from pix2struct.modeling import Pix2StructModel
from pix2struct.processing import extract_patches

def ask_generator(tokenizer, question, max_length=256):
    """Coroutine: yields decoder input ids, receives the next-token logits.

    The decoded answer is delivered through StopIteration.value; an answer
    exceeding max_length tokens is abandoned and returned as ''.
    """
    end_token_id = tokenizer.convert_tokens_to_ids(['</s>'])[0]
    # Prompt: <pad> as the decoder start token, the question, then the output marker.
    input_ids = [
        *tokenizer.convert_tokens_to_ids(['<pad>']),
        *tokenizer.encode(question, add_special_tokens=False),
        *tokenizer.convert_tokens_to_ids(['▁<output>']),
    ]
    generated_token_ids = []
    too_long = False
    while True:
        logits = yield input_ids
        next_token_id = torch.argmax(logits).item()
        if next_token_id == end_token_id:
            break
        if len(generated_token_ids) >= max_length:
            too_long = True
            break
        generated_token_ids.append(next_token_id)
        # After the prompt step, only the newly chosen token is fed back in.
        input_ids = [next_token_id]
    if too_long:
        return ''
    return tokenizer.decode(generated_token_ids)
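
# A minimal sketch of the coroutine protocol ask_generator follows, using a toy
# generator so it runs without a tokenizer or model (all names and values here
# are illustrative, not part of the real pipeline): the caller primes with
# next() to get the prompt ids, feeds logits back with .send(), and reads the
# final answer from StopIteration.value.
def _demo_coroutine_protocol():
    def toy_generator():
        logits = yield [0, 1, 2]  # prompt ids
        while torch.argmax(logits).item() != 9:  # 9 stands in for </s>
            logits = yield [torch.argmax(logits).item()]
        return 'done'

    gen = toy_generator()
    ids = next(gen)  # prime the coroutine: receive the prompt ids
    try:
        while True:
            # Pretend the model always scores token 9 highest, so it stops.
            ids = gen.send(torch.tensor([0.0] * 9 + [1.0]))
    except StopIteration as e:
        assert e.value == 'done'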

@dataclasses.dataclass
class DocumentQuery:
    meta: Any
    generator: Generator
    output: Any = None


@dataclasses.dataclass
class DocumentQueries:
    meta: Any
    patches: torch.Tensor
    queries: list[DocumentQuery]
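
# Sketch of how the dataclasses above fit together. The patch tensor is a zero
# placeholder and its shape is an assumption for illustration only; real
# patches come from extract_patches on page images.
def _demo_document_queries(tokenizer):
    question = 'What is the total?'
    return DocumentQueries(
        meta='example-doc',
        patches=torch.zeros(16, 770),  # (num_patches, patch_dim), placeholder
        queries=[DocumentQuery(
            meta=question,
            generator=ask_generator(tokenizer, question),
        )],
    )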

def debug(*x):
    pass
    # print(*x)

def generate(
    model: Pix2StructModel,
    documents: list[DocumentQueries],
    device: torch.device,
    init_cache_size: int = 512,
) -> list[DocumentQueries]:
    # Encode every document once: concatenate all patch sequences into a single
    # varlen batch described by cumulative sequence lengths.
    documents_patches = [document.patches for document in documents]
    documents_patches_lens = [patches.size(0) for patches in documents_patches]
    documents_patches = torch.cat(documents_patches, dim=0).to(device)
    documents_patches_cu_seq_lens = torch.tensor(
        [0, *np.cumsum(documents_patches_lens)],
        dtype=torch.int32,
        device=device,
    )
    documents_patches_max_seq_len = max(documents_patches_lens)
    encoder_cache = model.get_encoder_kv_cache(
        flattened_patches=documents_patches,
        flattened_patches_cu_seq_lens=documents_patches_cu_seq_lens,
        flattened_patches_max_seq_len=documents_patches_max_seq_len,
    )
    # One decoder KV-cache slot per query; all queries decode in one batch.
    total_queries = sum(len(document.queries) for document in documents)
    decoder_k_cache, decoder_v_cache = model.decoder.get_decoder_kv_cache(
        device, total_queries, init_cache_size, dtype=torch.bfloat16,
    )
    decoder_cache_seqlens = torch.zeros((total_queries,), dtype=torch.int32, device=device)
    # Prime every query coroutine to collect its initial prompt ids. Queries of
    # the same document share that document's encoder cache entry via the batch index.
    input_ids = []
    encoder_cache_batch_idx = []
    encoder_cache_seqlens = []
    for doc_idx, document in enumerate(documents):
        for query in document.queries:
            if query.output is None:
                input_ids.append(next(query.generator))
                encoder_cache_batch_idx.append(doc_idx)
                encoder_cache_seqlens.append(encoder_cache['encoder_cache_seqlens'][doc_idx])
    input_ids_lens = [len(ids) for ids in input_ids]
    input_ids_max_seq_len = max(input_ids_lens)
    input_ids = [ids + [0] * (input_ids_max_seq_len - len(ids)) for ids in input_ids]
    input_ids = torch.tensor(input_ids, dtype=torch.long).to(device)
    encoder_cache_batch_idx = torch.tensor(encoder_cache_batch_idx, dtype=torch.int32).to(device)
    encoder_cache_seqlens = torch.tensor(encoder_cache_seqlens, dtype=torch.int32).to(device)
    while any(query.output is None for document in documents for query in document.queries):
        debug('Generating')
        debug('input_ids', input_ids)
        debug('input_ids_lens', input_ids_lens)
        debug('decoder_k_cache', decoder_k_cache[0].size(), decoder_k_cache[0].dtype)
        debug('decoder_v_cache', decoder_v_cache[0].size(), decoder_v_cache[0].dtype)
        debug('decoder_cache_seqlens', decoder_cache_seqlens)
        debug('encoder_k_cache', encoder_cache['encoder_k_cache'][0].size(), encoder_cache['encoder_k_cache'][0].dtype)
        debug('encoder_v_cache', encoder_cache['encoder_v_cache'][0].size(), encoder_cache['encoder_v_cache'][0].dtype)
        debug('encoder_cache_seqlens', encoder_cache_seqlens)
        debug('encoder_cache_batch_idx', encoder_cache_batch_idx)
        logits = model.decoder.predict(
            input_ids=input_ids,
            decoder_k_cache=decoder_k_cache,
            decoder_v_cache=decoder_v_cache,
            decoder_cache_seqlens=decoder_cache_seqlens,
            encoder_k_cache=encoder_cache['encoder_k_cache'],
            encoder_v_cache=encoder_cache['encoder_v_cache'],
            encoder_cache_seqlens=encoder_cache_seqlens,
            encoder_cache_batch_idx=encoder_cache_batch_idx,
        )
        decoder_cache_seqlens += torch.tensor(input_ids_lens, dtype=torch.int32).to(device)
        # Feed each active query the logits of its last real (unpadded) token;
        # the coroutine either yields the next input ids or finishes via StopIteration.
        input_ids = []
        encoder_cache_batch_idx = []
        encoder_cache_seqlens = []
        remove_cache_batch_idx = []
        batch_idx = -1
        for doc_idx, document in enumerate(documents):
            for query in document.queries:
                if query.output is not None:
                    # This one is done, so it wasn't included in the input_ids.
                    continue
                batch_idx += 1
                next_token_logits = logits[batch_idx, input_ids_lens[batch_idx] - 1, :]
                try:
                    input_ids.append(query.generator.send(next_token_logits))
                    encoder_cache_batch_idx.append(doc_idx)
                    encoder_cache_seqlens.append(encoder_cache['encoder_cache_seqlens'][doc_idx])
                except StopIteration as e:
                    debug('Document', document.meta, 'Query', query.meta, 'Result', e.value)
                    query.output = e.value
                    remove_cache_batch_idx.append(batch_idx)
        if len(input_ids) == 0:
            break
        if len(remove_cache_batch_idx) > 0:
            # Compact the decoder cache: drop the rows of finished queries so
            # the next predict() batch contains only active ones.
            debug('Removing cache', remove_cache_batch_idx)
            cache_mask = torch.ones((decoder_cache_seqlens.size(0),), dtype=torch.bool, device=device)
            debug('cache_mask', cache_mask.size())
            cache_mask[remove_cache_batch_idx] = False
            decoder_k_cache = [k[cache_mask] for k in decoder_k_cache]
            decoder_v_cache = [v[cache_mask] for v in decoder_v_cache]
            decoder_cache_seqlens = decoder_cache_seqlens[cache_mask]
        input_ids_lens = [len(ids) for ids in input_ids]
        input_ids_max_seq_len = max(input_ids_lens)
        input_ids = [ids + [0] * (input_ids_max_seq_len - len(ids)) for ids in input_ids]
        input_ids = torch.tensor(input_ids, dtype=torch.long).to(device)
        encoder_cache_batch_idx = torch.tensor(encoder_cache_batch_idx, dtype=torch.int32).to(device)
        encoder_cache_seqlens = torch.tensor(encoder_cache_seqlens, dtype=torch.int32).to(device)
    return documents
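
# Sketch of the cache-compaction step inside generate(): rows of finished
# queries are dropped from the decoder KV cache by boolean-masking the batch
# dimension, so later predict() calls only pay for still-active queries.
# Shapes here are arbitrary stand-ins.
def _demo_cache_compaction():
    k_cache = torch.arange(12, dtype=torch.float32).reshape(4, 3)  # 4 queries in flight
    seqlens = torch.tensor([5, 2, 7, 1], dtype=torch.int32)
    mask = torch.ones(4, dtype=torch.bool)
    mask[[1, 3]] = False  # queries 1 and 3 just finished
    k_cache, seqlens = k_cache[mask], seqlens[mask]
    assert k_cache.size(0) == 2 and seqlens.tolist() == [5, 7]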

def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model', type=str, required=True)
    parser.add_argument('--tokenizer', type=str, required=True)
    parser.add_argument('--queries', type=str, required=True)
    args = parser.parse_args()

    from accelerate import Accelerator
    accelerator = Accelerator()

    model = Pix2StructModel.load(args.model)
    model = accelerator.prepare(model)
    model.eval()
    tokenizer = T5TokenizerFast.from_pretrained(args.tokenizer)

    # Build one DocumentQueries per input document, with a live coroutine per question.
    documents = []
    for query in json.loads(args.queries):
        document_pages = [np.array(page) for page in query['document']]
        document_queries = [
            DocumentQuery(
                meta=question,
                generator=ask_generator(tokenizer, question),
                output=None,
            )
            for question in query['questions']
        ]
        documents.append(DocumentQueries(
            meta=query['document'],
            patches=extract_patches(document_pages),
            queries=document_queries,
        ))

    with torch.inference_mode():
        with accelerator.autocast():
            result = generate(model, documents, accelerator.device)

    for document in result:
        print(f'Document: {document.meta}')
        for query in document.queries:
            print(f'Query: {query.meta}')
            print(f'Answer: {query.output}')
        print()


if __name__ == '__main__':
    main()
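
# Example invocation (script name, paths, and payload are hypothetical; each
# page in "document" is a nested list of pixel values, matching np.array(page)
# in main above):
#
#   python batch_qa.py \
#       --model /path/to/pix2struct.ckpt \
#       --tokenizer /path/to/t5-tokenizer \
#       --queries '[{"document": [PAGE_PIXELS], "questions": ["What is the invoice date?"]}]'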