ryantong3 committed on
Commit
cc6ac20
·
verified ·
1 Parent(s): 33dd3de

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -25
app.py CHANGED
@@ -1,29 +1,71 @@
1
- from PIL import Image
2
- import requests
3
- from io import BytesIO
4
- from transformers import pipeline
5
- import streamlit as st
6
- def predict(image):
7
- type_food = oracle(image, "What type of food is this?")
8
- cal_est = oracle(image, "About how many calories are in this meal?")
9
- guess1, guess2 = cal_est[0]['answer'], cal_est[1]['answer']
10
- return f"This is {type_food[0]['answer']}. I estimate this to contain {min(guess1, guess2)}-{max(guess1, guess2)} calories"
11
- oracle = pipeline(model="dandelin/vilt-b32-finetuned-vqa")
12
-
13
- def main():
14
- st.title("Image Question Answering App")
15
- st.write("Upload an image and ask a question to get answers!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- oracle = pipeline(model="dandelin/vilt-b32-finetuned-vqa")
18
 
19
- # File uploader for image
20
- uploaded_image = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
21
 
22
- if uploaded_image is not None:
23
- image = Image.open(uploaded_image)
24
- st.image(image, caption="Uploaded Image", use_column_width=True)
25
- response = predict(image)
26
- st.write(response)
27
 
28
- if __name__ == "__main__":
29
- main()
 
1
+ import torch
2
+ import re
3
+ import gradio as gr
4
+ from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
5
+
6
+ device='cpu'
7
+ encoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
8
+ decoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
9
+ model_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
10
+ feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
11
+ tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
12
+ model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
13
+
14
+
15
+ def predict(image,max_length=64, num_beams=3):
16
+ image = image.convert('RGB')
17
+ image = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
18
+ clean_text = lambda x: x.replace('<|endoftext|>','').split('\n')[0]
19
+ caption_ids = model.generate(image, max_length = max_length)[0]
20
+ caption_text = clean_text(tokenizer.decode(caption_ids))
21
+ return caption_text
22
+
23
+
24
+
25
+ input = gr.inputs.Image(label="Upload any Image", type = 'pil', optional=True)
26
+ output = gr.outputs.Textbox(type="auto",label="Captions")
27
+ examples = [f"example{i}.jpg" for i in range(1,7)]
28
+
29
+ title = "Image Captioning "
30
+ description = "Made by : shreyasdixit.tech"
31
+ interface = gr.Interface(
32
+
33
+ fn=predict,
34
+ description=description,
35
+ inputs = input,
36
+ theme="grass",
37
+ outputs=output,
38
+ examples = examples,
39
+ title=title,
40
+ )
41
+ interface.launch(debug=True)
42
+
43
+ # from PIL import Image
44
+ # import requests
45
+ # from io import BytesIO
46
+ # from transformers import pipeline
47
+ # import streamlit as st
48
+ # def predict(image):
49
+ # type_food = oracle(image, "What type of food is this?")
50
+ # cal_est = oracle(image, "About how many calories are in this meal?")
51
+ # guess1, guess2 = cal_est[0]['answer'], cal_est[1]['answer']
52
+ # return f"This is {type_food[0]['answer']}. I estimate this to contain {min(guess1, guess2)}-{max(guess1, guess2)} calories"
53
+ # oracle = pipeline(model="dandelin/vilt-b32-finetuned-vqa")
54
+
55
+ # def main():
56
+ # st.title("Image Question Answering App")
57
+ # st.write("Upload an image and ask a question to get answers!")
58
 
59
+ # oracle = pipeline(model="dandelin/vilt-b32-finetuned-vqa")
60
 
61
+ # # File uploader for image
62
+ # uploaded_image = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
63
 
64
+ # if uploaded_image is not None:
65
+ # image = Image.open(uploaded_image)
66
+ # st.image(image, caption="Uploaded Image", use_column_width=True)
67
+ # response = predict(image)
68
+ # st.write(response)
69
 
70
+ # if __name__ == "__main__":
71
+ # main()