JerryAnto committed on
Commit
97f080a
·
1 Parent(s): c5dbe05

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +31 -17
app.py CHANGED
@@ -30,26 +30,31 @@ model.to(device)
30
 
31
 
32
 
 
 
 
 
 
 
33
  max_length = 16
34
  num_beams = 4
35
  gen_kwargs = {"max_length": max_length, "num_beams": num_beams}
36
- def predict_step(image_paths):
37
- images = []
38
- for image_path in image_paths:
39
- i_image = Image.open(image_path)
40
  if i_image.mode != "RGB":
41
- i_image = i_image.convert(mode="RGB")
42
 
43
- images.append(i_image)
 
44
 
45
- pixel_values = feature_extractor(images=images, return_tensors="pt").pixel_values
46
- pixel_values = pixel_values.to(device)
47
 
48
- output_ids = model.generate(pixel_values, **gen_kwargs)
 
 
 
 
49
 
50
- preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
51
- preds = [pred.strip() for pred in preds]
52
- return preds
53
 
54
  #predict_step(['/content/drive/MyDrive/caption generator/horses.png'])
55
 
@@ -64,20 +69,29 @@ inputs = [
64
  gr.inputs.Image(type="pil", label="Original Image")
65
  ]
66
 
 
 
 
 
 
 
 
67
# Gradio UI components and page metadata for the captioning demo
# (pre-update version of this configuration).
outputs = [
    gr.outputs.Textbox(label='Caption')
]

title = "Image Captioning using ViT + GPT2"
description = "ViT and GPT2 are used to generate Image Caption for the uploaded image. COCO Dataset was used for training. This image captioning model might have some biases that we couldn't figure during our stress testing, so if you find any bias (gender, race and so on) please use `Flag` button to flag the image with bias"
article = " <a href='https://huggingface.co/sachin/vit2distilgpt2'>Model Repo on Hugging Face Model Hub</a>"
# NOTE(review): 'football_player' has no file extension — presumably it
# should be 'football_player.png'; confirm the example file exists.
examples = [
    ["horses.png"],
    ["persons.jpeg"],
    ['football_player']
]
80
 
 
 
81
  gr.Interface(
82
  predict_step,
83
  inputs,
 
30
 
31
 
32
 
33
+
34
# Run inference on the GPU when CUDA is available, otherwise on CPU.
# `model` is the vision-encoder-decoder model loaded earlier in this file.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
39
# Beam-search generation settings for the caption decoder.
max_length = 16
num_beams = 4
gen_kwargs = {"max_length": max_length, "num_beams": num_beams}

def predict_step1(image_paths):
    """Generate a caption for a single image file path.

    Despite the plural parameter name, this variant takes ONE file path
    (the matching Gradio input is ``type="filepath"``).

    Parameters
    ----------
    image_paths : str
        Path to an image file readable by PIL.

    Returns
    -------
    list of str
        The decoded caption(s), stripped of surrounding whitespace.

    Relies on module-level ``feature_extractor``, ``model``, ``tokenizer``,
    ``device`` and ``gen_kwargs``.
    """
    # BUG FIX: the committed code called PIL.Image.open, but the rest of
    # this file uses Image.open (i.e. `from PIL import Image`), so the bare
    # name `PIL` is undefined at runtime — TODO confirm the file's imports.
    i_image = Image.open(image_paths)
    # The ViT feature extractor expects 3-channel RGB input.
    if i_image.mode != "RGB":
        i_image = i_image.convert(mode="RGB")

    pixel_values = feature_extractor(images=i_image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    output_ids = model.generate(pixel_values, **gen_kwargs)

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    preds = [pred.strip() for pred in preds]
    return preds
55
+
56
+
57
 
 
 
 
58
 
59
  #predict_step(['/content/drive/MyDrive/caption generator/horses.png'])
60
 
 
69
  gr.inputs.Image(type="pil", label="Original Image")
70
  ]
71
 
72
# Gradio UI components and page metadata for the captioning demo.
# NOTE(review): the committed hunk contained a dangling `outputs = [`
# immediately before this import — a syntax error — which has been removed.
# The import itself should ideally move to the top of the file.
import gradio as gr

inputs = [
    gr.inputs.Image(type="filepath", label="Original Image")
]

outputs = [
    gr.outputs.Textbox(label='Caption')
]

title = "Image Captioning"
description = "ViT and GPT2 are used to generate Image Caption for the uploaded image."
article = " <a href='https://huggingface.co/nlpconnect/vit-gpt2-image-captioning'>Model Repo on Hugging Face Model Hub</a>"
examples = [
    ["horses.png"],
    ['persons.png'],
    ['football_player.png']
]
92
 
93
+
94
+
95
  gr.Interface(
96
  predict_step,
97
  inputs,