Vasudevakrishna committed on
Commit 6cda9a1 · verified · 1 Parent(s): 110f44d

Create app.py

Files changed (1)
  1. app.py +57 -0
app.py ADDED
@@ -0,0 +1,57 @@
+ import torch
+ import gradio as gr
+ from transformers import AutoTokenizer, AutoProcessor
+ from model import CustomClipPhi2
+
+ clip_model_name = "openai/clip-vit-base-patch32"
+ phi_model_name = "microsoft/phi-2"
+
+ tokenizer = AutoTokenizer.from_pretrained(phi_model_name, trust_remote_code=True)
+ tokenizer.pad_token = tokenizer.eos_token
+ processor = AutoProcessor.from_pretrained(clip_model_name)
+
+ IMAGE_TOKEN_ID = 23903  # phi-2 token id for the word "Comments", used as the image marker
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ max_tokens = 30
+ model = CustomClipPhi2(tokenizer, phi_model_name, clip_model_name, clip_embed=768, phi_embed=2560)
+
+ def generate(image):
+     # Preprocess the uploaded PIL image into CLIP pixel values.
+     inputs = processor(images=image, return_tensors="pt")
+     clip_outputs = model.clip_model(**inputs)
+     # Remove the CLS token; keep only the patch embeddings.
+     patch_embeds = clip_outputs.last_hidden_state[:, 1:, :]
+     image_embeddings = model.projection_layer(patch_embeds).to(torch.float16)
+
+     batch_size = patch_embeds.size(0)
+     predicted_caption = torch.full((batch_size, max_tokens), model.EOS_TOKEN_ID, dtype=torch.long, device=device)
+     img_token_tensor = torch.tensor(IMAGE_TOKEN_ID).repeat(batch_size, 1)
+     img_token_embeds = model.phi2_model.model.embed_tokens(img_token_tensor.to(image_embeddings.device))
+     combined_embeds = torch.cat([image_embeddings, img_token_embeds], dim=1)
+
+     # Greedy decoding: at each step, take the argmax token and append its embedding.
+     for pos in range(max_tokens - 1):
+         model_output_logits = model.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
+         predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
+         predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
+         predicted_caption[:, pos] = predicted_word_token.view(1, -1).to(device)
+         next_token_embeds = model.phi2_model.model.embed_tokens(predicted_word_token)
+         combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
+     # Decode the generated token ids into a caption string for the textbox.
+     return tokenizer.batch_decode(predicted_caption, skip_special_tokens=True)[0]
+
+ # Create a Gradio interface
+ iface = gr.Interface(
+     fn=generate,  # Function called on user input
+     inputs=gr.Image(
+         width=416, height=416,
+         type="pil", image_mode='RGB', label="Upload Image"
+     ),
+     outputs=gr.Textbox(
+         label="Response from AI Model: ",
+     ),
+     examples=['car.jpg'],
+ )
+
+ # Launch the Gradio app
+ iface.launch()