ryantong3 commited on
Commit
92db232
·
verified ·
1 Parent(s): cc6ac20

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -67
app.py CHANGED
@@ -1,71 +1,30 @@
1
- import torch
2
- import re
3
- import gradio as gr
4
- from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel
5
-
6
- device='cpu'
7
- encoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
8
- decoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
9
- model_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
10
- feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
11
- tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
12
- model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
13
-
14
-
15
- def predict(image,max_length=64, num_beams=3):
16
- image = image.convert('RGB')
17
- image = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
18
- clean_text = lambda x: x.replace('<|endoftext|>','').split('\n')[0]
19
- caption_ids = model.generate(image, max_length = max_length)[0]
20
- caption_text = clean_text(tokenizer.decode(caption_ids))
21
- return caption_text
22
-
23
-
24
-
25
- input = gr.inputs.Image(label="Upload any Image", type = 'pil', optional=True)
26
- output = gr.outputs.Textbox(type="auto",label="Captions")
27
- examples = [f"example{i}.jpg" for i in range(1,7)]
28
-
29
- title = "Image Captioning "
30
- description = "Made by : shreyasdixit.tech"
31
- interface = gr.Interface(
32
-
33
- fn=predict,
34
- description=description,
35
- inputs = input,
36
- theme="grass",
37
- outputs=output,
38
- examples = examples,
39
- title=title,
40
- )
41
- interface.launch(debug=True)
42
-
43
- # from PIL import Image
44
- # import requests
45
- # from io import BytesIO
46
- # from transformers import pipeline
47
- # import streamlit as st
48
- # def predict(image):
49
- # type_food = oracle(image, "What type of food is this?")
50
- # cal_est = oracle(image, "About how many calories are in this meal?")
51
- # guess1, guess2 = cal_est[0]['answer'], cal_est[1]['answer']
52
- # return f"This is {type_food[0]['answer']}. I estimate this to contain {min(guess1, guess2)}-{max(guess1, guess2)} calories"
53
- # oracle = pipeline(model="dandelin/vilt-b32-finetuned-vqa")
54
-
55
- # def main():
56
- # st.title("Image Question Answering App")
57
- # st.write("Upload an image and ask a question to get answers!")
58
 
59
- # oracle = pipeline(model="dandelin/vilt-b32-finetuned-vqa")
60
 
61
- # # File uploader for image
62
- # uploaded_image = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
63
 
64
- # if uploaded_image is not None:
65
- # image = Image.open(uploaded_image)
66
- # st.image(image, caption="Uploaded Image", use_column_width=True)
67
- # response = predict(image)
68
- # st.write(response)
69
 
70
- # if __name__ == "__main__":
71
- # main()
 
1
+ from PIL import Image
2
+ import requests
3
+ from io import BytesIO
4
+ from transformers import pipeline
5
+ import streamlit as st
6
+
7
+ def predict(image):
8
+ type_food = oracle(image, "What type of food is this?")
9
+ cal_est = oracle(image, "About how many calories are in this meal?")
10
+ guess1, guess2 = cal_est[0]['answer'], cal_est[1]['answer']
11
+ return f"This is {type_food[0]['answer']}. I estimate this to contain {min(guess1, guess2)}-{max(guess1, guess2)} calories"
12
+ oracle = pipeline(model="dandelin/vilt-b32-finetuned-vqa")
13
+
14
+ def main():
15
+ st.title("Image Question Answering App")
16
+ st.write("Upload an image and ask a question to get answers!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ oracle = pipeline(model="dandelin/vilt-b32-finetuned-vqa")
19
 
20
+ # File uploader for image
21
+ uploaded_image = st.file_uploader("Upload Image", type=["jpg", "jpeg", "png"])
22
 
23
+ if uploaded_image is not None:
24
+ image = Image.open(uploaded_image)
25
+ st.image(image, caption="Uploaded Image", use_column_width=True)
26
+ response = predict(image)
27
+ st.write(response)
28
 
29
+ if __name__ == "__main__":
30
+ main()