| | import torch |
| |
|
def generate_text(model, image, tokenizer, image_transfrom, max_length=30):
    """Greedily decode a text caption for *image* with a CLIP-prefix GPT model.

    Pipeline: preprocess the image, embed it with ``model.clip``, project the
    embedding into GPT's token-embedding space via ``model.mapping_network``
    (reshaped to a prefix of ``model.max_length`` pseudo-tokens), then
    autoregressively pick the argmax token at the last position, feeding each
    chosen token's embedding back in, until the pad token appears or
    *max_length* tokens have been produced.

    Args:
        model: exposes ``eval()``, ``clip``, ``mapping_network``, ``gpt``
            (whose ``base_network.transformer.wte`` is the token-embedding
            table), plus ``max_length`` (prefix length) and
            ``gpt_embedding_size``.
        image: raw input accepted by *image_transfrom*.
        tokenizer: provides ``pad_token_id`` (used as the stop token) and
            ``decode``.
        image_transfrom: preprocessing callable. NOTE: name is misspelled but
            kept for backward compatibility with existing keyword callers.
        max_length: maximum number of tokens to generate.

    Returns:
        The decoded caption string (may be empty if the first prediction is
        the pad token).
    """
    model.eval()
    stop_token_id = tokenizer.pad_token_id
    output_ids = []

    # Run the ENTIRE forward pipeline under no_grad. The original version ran
    # the CLIP pass and the mapping-network projection outside the no_grad
    # block, needlessly building an autograd graph during inference.
    with torch.no_grad():
        img_tensor = image_transfrom(image).unsqueeze(0)  # add batch dim
        images_embedding = model.clip(img_tensor)
        # Project image embedding to a (batch, prefix_len, embed_dim) prefix.
        input_state = model.mapping_network(images_embedding).view(
            -1, model.max_length, model.gpt_embedding_size
        )

        for _ in range(max_length):
            logits = model.gpt(input_state, None).logits

            # Greedy decoding. The original divided by a temperature (0.9)
            # and applied softmax before taking the max — both transforms are
            # strictly monotone, so the selected token is identical to the
            # plain argmax over the raw logits; the dead computation is gone.
            next_token_id = int(logits[0, -1, :].argmax())

            if next_token_id == stop_token_id:
                break
            output_ids.append(next_token_id)

            # Feed the chosen token back in as its embedding (shape (1, 1)).
            token = torch.tensor([[next_token_id]])
            next_token_embed = model.gpt.base_network.transformer.wte(token)
            input_state = torch.cat((input_state, next_token_embed), dim=1)

    return tokenizer.decode(output_ids)