sajofu committed
Commit 9661f49 · verified · 1 Parent(s): dbae598

Update app.py

Files changed (1):
  1. app.py +102 -114
app.py CHANGED
@@ -1,130 +1,118 @@
-import sys
-from threading import Thread
-
-import gradio as gr
-import torch
-from transformers import AutoModel, AutoProcessor
-from transformers import StoppingCriteria, TextIteratorStreamer, StoppingCriteriaList
-
-device = "cuda:0" if torch.cuda.is_available() else "cpu"
-
-model = AutoModel.from_pretrained("unum-cloud/uform-gen2-dpo", trust_remote_code=True).to(device)
-processor = AutoProcessor.from_pretrained("unum-cloud/uform-gen2-dpo", trust_remote_code=True)
-
-class StopOnTokens(StoppingCriteria):
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
-        stop_ids = [151645]
-        for stop_id in stop_ids:
-            if input_ids[0][-1] == stop_id:
-                return True
-        return False
-
-@torch.no_grad()
-def response(message, history, image):
-    stop = StopOnTokens()
-
-    messages = [{"role": "system", "content": "You are a helpful assistant."}]
-
-    for user_msg, assistant_msg in history:
-        messages.append({"role": "user", "content": user_msg})
-        messages.append({"role": "assistant", "content": assistant_msg})
-
-    if len(messages) == 1:
-        message = f" <image>{message}"
-
-    messages.append({"role": "user", "content": message})
-
-    model_inputs = processor.tokenizer.apply_chat_template(
-        messages,
-        add_generation_prompt=True,
-        return_tensors="pt"
-    )
-
-    image = (
-        processor.feature_extractor(image)
-        .unsqueeze(0)
-    )
-
-    attention_mask = torch.ones(
-        1, model_inputs.shape[1] + processor.num_image_latents - 1
-    )
-
-    model_inputs = {
-        "input_ids": model_inputs,
-        "images": image,
-        "attention_mask": attention_mask
-    }
-
-    model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
-
-    streamer = TextIteratorStreamer(processor.tokenizer, timeout=30., skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        model_inputs,
-        streamer=streamer,
-        max_new_tokens=1024,
-        stopping_criteria=StoppingCriteriaList([stop])
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    history.append([message, ""])
-    partial_response = ""
-    for new_token in streamer:
-        partial_response += new_token
-        history[-1][1] = partial_response
-        yield history, gr.Button(visible=False), gr.Button(visible=True, interactive=True)
-
-
-with gr.Blocks() as demo:
-    with gr.Row():
-        image = gr.Image(type="pil")
-
-        with gr.Column():
-            chat = gr.Chatbot(show_label=False)
-            message = gr.Textbox(interactive=True, show_label=False, container=False)
-
-            with gr.Row():
-                gr.ClearButton([chat, message])
-                stop = gr.Button(value="Stop", variant="stop", visible=False)
-                submit = gr.Button(value="Submit", variant="primary")
-
-    with gr.Row():
-        gr.Examples(
-            [
-                ["images/interior.jpg", "Describe the image accurately."],
-                ["images/cat.jpg", "Describe the image in three sentences."],
-                ["images/child.jpg", "Describe the image in one sentence."],
-            ],
-            [image, message],
-            label="Captioning"
-        )
-        gr.Examples(
-            [
-                ["images/scream.jpg", "What is the main emotion of this image?"],
-                ["images/louvre.jpg", "Where is this landmark located?"],
-                ["images/three_people.jpg", "What are these people doing?"]
-            ],
-            [image, message],
-            label="VQA"
-        )
-
-    response_handler = (
-        response,
-        [message, chat, image],
-        [chat, submit, stop]
-    )
-    postresponse_handler = (
-        lambda: (gr.Button(visible=False), gr.Button(visible=True)),
-        None,
-        [stop, submit]
-    )
-
-    event1 = message.submit(*response_handler)
-    event1.then(*postresponse_handler)
-    event2 = submit.click(*response_handler)
-    event2.then(*postresponse_handler)
-
-    stop.click(None, None, None, cancels=[event1, event2])
-
-demo.queue()
-demo.launch()
+import gradio as gr
+import torch
+from PIL import Image
+from transformers import AutoModel, AutoProcessor
+
+# Set device
+device = "cuda:0" if torch.cuda.is_available() else "cpu"
+
+# Load the model and processor
+# We use trust_remote_code=True because this model has custom code.
+model = AutoModel.from_pretrained(
+    "unum-cloud/uform-gen2-dpo",
+    trust_remote_code=True,
+    torch_dtype=torch.bfloat16
+).to(device)
+processor = AutoProcessor.from_pretrained(
+    "unum-cloud/uform-gen2-dpo",
+    trust_remote_code=True
+)
+
+def transcribe_image(image):
+    """
+    Generates a caption for the given image.
+    """
+    if image is None:
+        return "Please upload an image."
+
+    prompt = "a photo of"
+    inputs = processor(
+        text=[prompt],
+        images=[image],
+        return_tensors="pt"
+    ).to(device)
+
+    with torch.inference_mode():
+        output = model.generate(
+            **inputs,
+            do_sample=False,
+            use_cache=True,
+            max_new_tokens=128,
+            eos_token_id=32001,
+            pad_token_id=processor.tokenizer.pad_token_id
+        )
+
+    prompt_len = inputs["input_ids"].shape[1]
+    decoded_text = processor.batch_decode(output[:, prompt_len:])[0]
+
+    # Remove the end-of-sequence token
+    result = decoded_text.replace("<|im_end|>", "").strip()
+    return result
+
+def visual_question_answer(image, question):
+    """
+    Answers a question about the given image.
+    """
+    if image is None:
+        return "Please upload an image."
+    if not question:
+        return "Please ask a question."
+
+    # The model expects the prompt to be in a specific format.
+    prompt = f"<|im_start|>question\n{question}<|im_end|><|im_start|>answer\n"
+
+    inputs = processor(
+        text=[prompt],
+        images=[image],
+        return_tensors="pt"
+    ).to(device)
+
+    with torch.inference_mode():
+        output = model.generate(
+            **inputs,
+            do_sample=False,
+            use_cache=True,
+            max_new_tokens=128,
+            eos_token_id=32001,
+            pad_token_id=processor.tokenizer.pad_token_id
+        )
+
+    prompt_len = inputs["input_ids"].shape[1]
+    decoded_text = processor.batch_decode(output[:, prompt_len:])[0]
+
+    # Remove the end-of-sequence token
+    result = decoded_text.replace("<|im_end|>", "").strip()
+    return result
+
+# Create the Gradio interface
+with gr.Blocks() as demo:
+    gr.Markdown("# Image Transcription and Visual Question Answering")
+    gr.Markdown("Powered by the unum-cloud/uform-gen2-dpo model.")
+
+    with gr.Tab("Image Transcription"):
+        with gr.Row():
+            transcribe_image_input = gr.Image(type="pil", label="Upload Image")
+            transcribe_output = gr.Textbox(label="Generated Caption")
+        transcribe_button = gr.Button("Generate Caption")
+
+    with gr.Tab("Visual Question Answering"):
+        with gr.Row():
+            vqa_image_input = gr.Image(type="pil", label="Upload Image")
+            with gr.Column():
+                vqa_question_input = gr.Textbox(label="Ask a question")
+                vqa_output = gr.Textbox(label="Answer")
+        vqa_button = gr.Button("Get Answer")
+
+    # Connect the functions to the Gradio components
+    transcribe_button.click(
+        fn=transcribe_image,
+        inputs=[transcribe_image_input],
+        outputs=[transcribe_output]
+    )
+    vqa_button.click(
+        fn=visual_question_answer,
+        inputs=[vqa_image_input, vqa_question_input],
+        outputs=[vqa_output]
+    )
+
+demo.launch()
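
For reference, a minimal sketch (not part of the commit) of how the new inference path can be driven without the Gradio UI. It reuses only calls that appear in the diff above; "example.jpg" and the question text are placeholders, and the model is loaded in its default dtype here rather than bfloat16 to avoid dtype mismatches on CPU-only machines.

# Hypothetical standalone driver for the VQA path added in this commit.
from PIL import Image
import torch
from transformers import AutoModel, AutoProcessor

device = "cuda:0" if torch.cuda.is_available() else "cpu"
# trust_remote_code=True is required: the model ships custom modeling code.
model = AutoModel.from_pretrained("unum-cloud/uform-gen2-dpo", trust_remote_code=True).to(device)
processor = AutoProcessor.from_pretrained("unum-cloud/uform-gen2-dpo", trust_remote_code=True)

image = Image.open("example.jpg")  # placeholder path
# Same chat-style prompt format the new visual_question_answer() builds.
prompt = "<|im_start|>question\nWhat is shown here?<|im_end|><|im_start|>answer\n"
inputs = processor(text=[prompt], images=[image], return_tensors="pt").to(device)

with torch.inference_mode():
    output = model.generate(
        **inputs,
        do_sample=False,
        max_new_tokens=128,
        eos_token_id=32001,
        pad_token_id=processor.tokenizer.pad_token_id,
    )

# Decode only the newly generated tokens and trim the end-of-turn marker,
# mirroring the post-processing in the committed functions.
prompt_len = inputs["input_ids"].shape[1]
answer = processor.batch_decode(output[:, prompt_len:])[0].replace("<|im_end|>", "").strip()
print(answer)

The captioning path is the same sketch with prompt = "a photo of" in place of the chat-style question prompt.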