Daniton committed on
Commit
cf909ce
·
1 Parent(s): 04f02d5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -26
app.py CHANGED
@@ -11,43 +11,29 @@ model, _, transform = open_clip.create_model_and_transforms(
11
  pretrained="laion2b_s13b_b90k"
12
  )
13
  model.to(device)
14
- model.eval() # Ensure the model is in evaluation mode
15
 
16
- def output_generate(image_batch):
 
17
  with torch.no_grad(), torch.cuda.amp.autocast():
18
- generated = model.generate(image_batch.to(device), seq_len=20)
19
- captions = open_clip.decode(generated.detach()).split("<end_of_text>")
20
- captions = [c.replace("<start_of_text>", "") for c in captions]
21
- return captions
22
 
23
- def inference_caption(image_batch):
 
24
  with torch.no_grad(), torch.cuda.amp.autocast():
25
  generated = model.generate(
26
- image_batch.to(device),
27
  generation_type="beam_search",
28
  top_p=1.0,
29
  min_seq_len=20,
30
  seq_len=30,
31
  repetition_penalty=1.2
32
  )
33
- captions = open_clip.decode(generated.detach()).split("<end_of_text>")
34
- captions = [c.replace("<start_of_text>", "") for c in captions]
35
- return captions
36
 
37
- # Use Gradio's batching feature to process multiple images at once
38
- image_input = gr.inputs.Image(type="pil", label="Input Image", max_shape=(224, 224))
39
  caption_output = gr.outputs.Textbox(label="Caption Output")
40
- caption_interface = gr.Interface(
41
- fn=inference_caption,
42
- inputs=gr.inputs.Image(type="pil", label="Input Image", max_shape=(224, 224), multiple=True),
43
- outputs=gr.outputs.Textbox(label="Caption Output", type="auto", multiple=True),
44
- capture_session=True,
45
- title="CoCa: Contrastive Captioners",
46
- description="An open source implementation of CoCa: Contrastive Captioners are Image-Text Foundation Models.",
47
- examples=[path.as_posix() for path in sorted(pathlib.Path("images").glob("*.jpg"))],
48
- allow_flagging=False,
49
- batching=True, # Enable Gradio's batching feature
50
- batch_size=8, # Process 8 images at once
51
- )
52
 
53
- caption_interface.launch()
 
11
  pretrained="laion2b_s13b_b90k"
12
  )
13
  model.to(device)
14
+ model.eval()
15
 
16
def output_generate(image):
    """Caption a single PIL image with plain sampling (seq_len=20).

    Returns the decoded caption text with the special start/end tokens removed.
    Relies on the module-level `model`, `transform`, and `device`.
    """
    # Preprocess into a 1-image batch on the target device.
    pixel_batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad(), torch.cuda.amp.autocast():
        token_ids = model.generate(pixel_batch, seq_len=20)
    # Decode the first (only) sequence, then strip the special tokens.
    decoded = open_clip.decode(token_ids[0].detach())
    caption = decoded.split("<end_of_text>")[0]
    return caption.replace("<start_of_text>", "")
 
 
21
 
22
def inference_caption(image):
    """Caption a single PIL image using beam search decoding.

    Uses a longer minimum sequence and a repetition penalty for higher-quality
    captions than `output_generate`. Returns the caption text with the special
    start/end tokens removed. Relies on the module-level `model`, `transform`,
    and `device`.
    """
    pixel_batch = transform(image).unsqueeze(0).to(device)
    with torch.no_grad(), torch.cuda.amp.autocast():
        token_ids = model.generate(
            pixel_batch,
            generation_type="beam_search",
            top_p=1.0,
            min_seq_len=20,
            seq_len=30,
            repetition_penalty=1.2,
        )
    # Decode the first (only) sequence, then strip the special tokens.
    decoded = open_clip.decode(token_ids[0].detach())
    return decoded.split("<end_of_text>")[0].replace("<start_of_text>", "")
 
 
34
 
35
# Gradio UI wiring: one image in, one caption out.
# NOTE(review): gr.inputs / gr.outputs / capture_session are the legacy
# pre-3.x Gradio API — kept as-is to match the pinned dependency.
image_input = gr.inputs.Image(type="pil")
caption_output = gr.outputs.Textbox(label="Caption Output")

# Bundled example images, in a stable sorted order.
example_paths = [path.as_posix() for path in sorted(pathlib.Path("images").glob("*.jpg"))]

caption_interface = gr.Interface(
    fn=inference_caption,
    inputs=image_input,
    outputs=caption_output,
    capture_session=True,
    title="CoCa: Contrastive Captioners",
    description="An open source implementation of CoCa: Contrastive Captioners are Image-Text Foundation Models.",
    examples=example_paths,
    allow_flagging=False,
)

caption_interface.launch()