Spaces:
Runtime error
Runtime error
| """ | |
| Copyright (c) 2022, salesforce.com, inc. | |
| All rights reserved. | |
| SPDX-License-Identifier: BSD-3-Clause | |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| """ | |
| import torch | |
def prepare_qa_input(sample, num_captions, num_captions_fid):
    """Attach grouped question+caption strings to ``sample`` in place.

    For every ``(question, captions)`` pair taken in lockstep from
    ``sample['text_input']`` and ``sample['captions']``, only the first
    ``num_captions`` captions are used. Each caption is stripped and followed
    by ``'. '``. The captions are concatenated into consecutive groups of
    ``num_captions_fid``, and the final group may be smaller. Each group
    becomes one string of the form::

        <question, lower-cased and stripped> + " \\n " + <group, lower-cased and stripped>

    The separator is a literal backslash followed by ``n``, not a newline
    character. (Presumably this separator is the input format expected by the
    downstream QA reader; confirm against the model's tokenizer config.)

    Args:
        sample: mutable mapping holding the keys ``'text_input'`` (questions)
            and ``'captions'`` (one list of caption strings per question).
        num_captions: maximum number of captions used per question.
        num_captions_fid: number of captions merged into each output string.

    Returns:
        None. The result is stored as ``sample['question_captions']``: a list
        with one entry per question, each a list of group strings. A question
        with no captions gets an empty list.

    Raises:
        AssertionError: if a per-question ``captions`` entry is not a list.
    """
    sample_question_captions = []

    for question, captions in zip(sample['text_input'], sample['captions']):
        assert isinstance(captions, list)
        selected = captions[0:num_captions]
        # Flush on the last caption actually processed rather than on
        # ``num_captions``. Otherwise the trailing partial group is lost
        # whenever fewer than ``num_captions`` captions are supplied.
        last = len(selected)
        question_prefix = question.lower().strip() + " \\n "
        question_captions = []
        question_caption = ''
        for cap_id, cap_ in enumerate(selected):
            question_caption += (cap_.strip() + '. ')
            # A group is complete every ``num_captions_fid`` captions. The
            # final caption is handled by the branch below, so that it is
            # emitted exactly once.
            if (cap_id + 1) != last and ((cap_id + 1) % num_captions_fid == 0):
                question_captions.append(question_prefix + question_caption.lower().strip())
                question_caption = ''
            if (cap_id + 1) == last:
                question_captions.append(question_prefix + question_caption.lower().strip())
        sample_question_captions.append(question_captions)

    sample['question_captions'] = sample_question_captions