Daniton committed on
Commit
5c90ea8
·
1 Parent(s): 2b26d59

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -2
app.py CHANGED
@@ -11,17 +11,19 @@ model, _, transform = open_clip.create_model_and_transforms(
11
  pretrained="laion2b_s13b_b90k"
12
  )
13
  model.to(device)
 
 
14
 
15
  def output_generate(image):
16
  im = transform(image).unsqueeze(0).to(device)
17
  with torch.no_grad(), torch.cuda.amp.autocast():
18
- generated = model.generate(im, seq_len=20)
19
  return open_clip.decode(generated[0].detach()).split("<end_of_text>")[0].replace("<start_of_text>", "")
20
 
21
  def inference_caption(image):
22
  im = transform(image).unsqueeze(0).to(device)
23
  with torch.no_grad(), torch.cuda.amp.autocast():
24
- generated = model.generate(
25
  im,
26
  generation_type="beam_search",
27
  top_p=1.0,
@@ -31,6 +33,8 @@ def inference_caption(image):
31
  )
32
  return open_clip.decode(generated[0].detach()).split("<end_of_text>")[0].replace("<start_of_text>", "")
33
 
 
 
34
  image_input = gr.inputs.Image(type="pil")
35
  caption_output = gr.outputs.Textbox(label="Caption Output")
36
  caption_interface = gr.Interface(fn=inference_caption, inputs=image_input, outputs=caption_output, capture_session=True, title="CoCa: Contrastive Captioners", description="An open source implementation of CoCa: Contrastive Captioners are Image-Text Foundation Models.", examples=[path.as_posix() for path in sorted(pathlib.Path("images").glob("*.jpg"))], allow_flagging=False)
 
11
  pretrained="laion2b_s13b_b90k"
12
  )
13
  model.to(device)
14
+ model.eval()
15
+ traced_model = torch.jit.trace(model, torch.zeros((1, 3, 64, 64)).to(device))
16
 
17
  def output_generate(image):
18
  im = transform(image).unsqueeze(0).to(device)
19
  with torch.no_grad(), torch.cuda.amp.autocast():
20
+ generated = traced_model.generate(im, seq_len=20)
21
  return open_clip.decode(generated[0].detach()).split("<end_of_text>")[0].replace("<start_of_text>", "")
22
 
23
  def inference_caption(image):
24
  im = transform(image).unsqueeze(0).to(device)
25
  with torch.no_grad(), torch.cuda.amp.autocast():
26
+ generated = traced_model.generate(
27
  im,
28
  generation_type="beam_search",
29
  top_p=1.0,
 
33
  )
34
  return open_clip.decode(generated[0].detach()).split("<end_of_text>")[0].replace("<start_of_text>", "")
35
 
36
+ transform = open_clip.get_transforms("coca_ViT-B-32", image_size=128)
37
+
38
  image_input = gr.inputs.Image(type="pil")
39
  caption_output = gr.outputs.Textbox(label="Caption Output")
40
  caption_interface = gr.Interface(fn=inference_caption, inputs=image_input, outputs=caption_output, capture_session=True, title="CoCa: Contrastive Captioners", description="An open source implementation of CoCa: Contrastive Captioners are Image-Text Foundation Models.", examples=[path.as_posix() for path in sorted(pathlib.Path("images").glob("*.jpg"))], allow_flagging=False)