sajofu committed on
Commit
a87dbf7
·
verified ·
1 Parent(s): 6177680

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -83
app.py CHANGED
@@ -6,122 +6,87 @@ import torch
6
  from transformers import AutoModel, AutoProcessor
7
  from transformers import StoppingCriteria, TextIteratorStreamer, StoppingCriteriaList
8
 
9
- # Set the device for computation
10
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
11
 
12
- # Load the model and processor from Hugging Face
13
- # trust_remote_code=True is necessary for this model
14
  model = AutoModel.from_pretrained("unum-cloud/uform-gen2-dpo", trust_remote_code=True).to(device)
15
  processor = AutoProcessor.from_pretrained("unum-cloud/uform-gen2-dpo", trust_remote_code=True)
16
 
17
- # Define a custom stopping criteria to stop generation when the model outputs the end-of-text token
18
  class StopOnTokens(StoppingCriteria):
19
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
20
- # The stop token ID for <|endoftext|>
21
- stop_ids = [151645]
22
  for stop_id in stop_ids:
23
- # Check if the last generated token is a stop token
24
  if input_ids[0][-1] == stop_id:
25
  return True
26
  return False
27
 
28
  @torch.no_grad()
29
  def response(message, history, image):
30
- """
31
- This function generates the model's response. It handles both text-only and multimodal inputs,
32
- builds the conversation history, and streams the response back to the UI.
33
- """
34
  stop = StopOnTokens()
35
 
36
- # 1. Build the conversation history
37
  messages = [{"role": "system", "content": "You are a helpful assistant."}]
 
38
  for user_msg, assistant_msg in history:
39
  messages.append({"role": "user", "content": user_msg})
40
- if assistant_msg:
41
- messages.append({"role": "assistant", "content": assistant_msg})
42
 
43
- # 2. Prepare the prompt and model inputs for the current turn
44
- prompt = message
45
- model_kwargs = {}
46
-
47
- # If an image is provided, process it and prepend the <image> token to the prompt
48
- if image is not None:
49
- prompt = f"<image>{message}"
50
- # Process the image using the model's image_processor
51
- processed_images = processor.image_processor(image, return_tensors="pt")['pixel_values'].to(device)
52
- model_kwargs['images'] = processed_images
53
 
54
- messages.append({"role": "user", "content": prompt})
55
 
56
- # 3. Tokenize the conversation using the chat template
57
- inputs = processor.tokenizer.apply_chat_template(
58
  messages,
59
  add_generation_prompt=True,
60
  return_tensors="pt"
61
- ).to(device)
62
- model_kwargs['input_ids'] = inputs
63
-
64
- # 4. Create an attention mask if an image is present
65
- if image is not None:
66
- # The attention mask needs to be manually created to account for the image tokens
67
- attention_mask = torch.ones(
68
- 1, inputs.shape[1] + processor.num_image_latents - 1,
69
- dtype=torch.long,
70
- device=device
71
- )
72
- model_kwargs['attention_mask'] = attention_mask
73
-
74
- # 5. Set up the streamer for text generation
75
- streamer = TextIteratorStreamer(
76
- processor.tokenizer,
77
- timeout=30.,
78
- skip_prompt=True,
79
- skip_special_tokens=True
80
  )
81
 
 
 
 
 
 
 
 
 
 
82
  generate_kwargs = dict(
83
- **model_kwargs,
84
  streamer=streamer,
85
  max_new_tokens=1024,
86
  stopping_criteria=StoppingCriteriaList([stop])
87
- )
88
-
89
- # Run generation in a separate thread to not block the UI
90
  t = Thread(target=model.generate, kwargs=generate_kwargs)
91
  t.start()
92
 
93
- # 6. Stream the response to the Gradio UI
94
- # Append the original user message (without <image> token) to the history for display
95
  history.append([message, ""])
96
  partial_response = ""
97
  for new_token in streamer:
98
- # The model might output this token string instead of the ID
99
- if new_token == '<|endoftext|>':
100
- break
101
  partial_response += new_token
102
  history[-1][1] = partial_response
103
- # Yield updates to the chatbot and buttons
104
- yield history, gr.update(visible=False), gr.update(visible=True, interactive=True)
105
 
106
 
107
- with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as demo:
108
- gr.Markdown("# UForm-Gen2 DPO Chat Demo")
109
  with gr.Row():
110
- image = gr.Image(type="pil", label="Upload Image (Optional)")
111
 
112
  with gr.Column():
113
- chat = gr.Chatbot(label="Conversation", show_label=False, elem_id="chatbot")
114
- message = gr.Textbox(
115
- interactive=True,
116
- show_label=False,
117
- placeholder="Type your message or ask about the image...",
118
- container=False
119
- )
120
 
121
  with gr.Row():
122
- gr.ClearButton([chat, message, image], value="🗑️ New Chat")
123
- stop = gr.Button("⏹️ Stop", variant="stop", visible=False)
124
- submit = gr.Button("▶️ Submit", variant="primary")
125
 
126
  with gr.Row():
127
  gr.Examples(
@@ -131,7 +96,7 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as demo:
131
  ["images/child.jpg", "Describe the image in one sentence."],
132
  ],
133
  [image, message],
134
- label="Image Captioning Examples"
135
  )
136
  gr.Examples(
137
  [
@@ -140,34 +105,26 @@ with gr.Blocks(theme=gr.themes.Default(primary_hue="blue")) as demo:
140
  ["images/three_people.jpg", "What are these people doing?"]
141
  ],
142
  [image, message],
143
- label="Visual Question Answering (VQA) Examples"
144
  )
145
 
146
- # Define the event handlers for submitting a message
147
  response_handler = (
148
  response,
149
  [message, chat, image],
150
  [chat, submit, stop]
151
  )
152
-
153
- # This handler runs after the generation is complete to reset the button states
154
  postresponse_handler = (
155
- lambda: (gr.update(visible=False), gr.update(visible=True, interactive=True)),
156
  None,
157
- [stop, submit],
158
  )
159
 
160
- # Register the event listeners
161
- # Trigger generation on both text submission (Enter key) and button click
162
  event1 = message.submit(*response_handler)
163
  event1.then(*postresponse_handler)
164
  event2 = submit.click(*response_handler)
165
  event2.then(*postresponse_handler)
166
 
167
- # The stop button cancels the generation events
168
  stop.click(None, None, None, cancels=[event1, event2])
169
 
170
- # Use a queue for smooth streaming and handling multiple users
171
  demo.queue()
172
- # Set share=True to create a public link, necessary for most cloud environments
173
- demo.launch(share=True)
 
6
  from transformers import AutoModel, AutoProcessor
7
  from transformers import StoppingCriteria, TextIteratorStreamer, StoppingCriteriaList
8
 
 
9
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
10
 
 
 
11
  model = AutoModel.from_pretrained("unum-cloud/uform-gen2-dpo", trust_remote_code=True).to(device)
12
  processor = AutoProcessor.from_pretrained("unum-cloud/uform-gen2-dpo", trust_remote_code=True)
13
 
 
14
  class StopOnTokens(StoppingCriteria):
15
  def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
16
+ stop_ids = [151645]
 
17
  for stop_id in stop_ids:
 
18
  if input_ids[0][-1] == stop_id:
19
  return True
20
  return False
21
 
22
  @torch.no_grad()
23
  def response(message, history, image):
 
 
 
 
24
  stop = StopOnTokens()
25
 
 
26
  messages = [{"role": "system", "content": "You are a helpful assistant."}]
27
+
28
  for user_msg, assistant_msg in history:
29
  messages.append({"role": "user", "content": user_msg})
30
+ messages.append({"role": "assistant", "content": assistant_msg})
 
31
 
32
+ if len(messages) == 1:
33
+ message = f" <image>{message}"
 
 
 
 
 
 
 
 
34
 
35
+ messages.append({"role": "user", "content": message})
36
 
37
+ model_inputs = processor.tokenizer.apply_chat_template(
 
38
  messages,
39
  add_generation_prompt=True,
40
  return_tensors="pt"
41
+ )
42
+
43
+ image = (
44
+ processor.feature_extractor(image)
45
+ .unsqueeze(0)
46
+ )
47
+
48
+ attention_mask = torch.ones(
49
+ 1, model_inputs.shape[1] + processor.num_image_latents - 1
 
 
 
 
 
 
 
 
 
 
50
  )
51
 
52
+ model_inputs = {
53
+ "input_ids": model_inputs,
54
+ "images": image,
55
+ "attention_mask": attention_mask
56
+ }
57
+
58
+ model_inputs = {k: v.to(device) for k, v in model_inputs.items()}
59
+
60
+ streamer = TextIteratorStreamer(processor.tokenizer, timeout=30., skip_prompt=True, skip_special_tokens=True)
61
  generate_kwargs = dict(
62
+ model_inputs,
63
  streamer=streamer,
64
  max_new_tokens=1024,
65
  stopping_criteria=StoppingCriteriaList([stop])
66
+ )
 
 
67
  t = Thread(target=model.generate, kwargs=generate_kwargs)
68
  t.start()
69
 
 
 
70
  history.append([message, ""])
71
  partial_response = ""
72
  for new_token in streamer:
 
 
 
73
  partial_response += new_token
74
  history[-1][1] = partial_response
75
+ yield history, gr.Button(visible=False), gr.Button(visible=True, interactive=True)
 
76
 
77
 
78
+ with gr.Blocks() as demo:
 
79
  with gr.Row():
80
+ image = gr.Image(type="pil")
81
 
82
  with gr.Column():
83
+ chat = gr.Chatbot(show_label=False)
84
+ message = gr.Textbox(interactive=True, show_label=False, container=False)
 
 
 
 
 
85
 
86
  with gr.Row():
87
+ gr.ClearButton([chat, message])
88
+ stop = gr.Button(value="Stop", variant="stop", visible=False)
89
+ submit = gr.Button(value="Submit", variant="primary")
90
 
91
  with gr.Row():
92
  gr.Examples(
 
96
  ["images/child.jpg", "Describe the image in one sentence."],
97
  ],
98
  [image, message],
99
+ label="Captioning"
100
  )
101
  gr.Examples(
102
  [
 
105
  ["images/three_people.jpg", "What are these people doing?"]
106
  ],
107
  [image, message],
108
+ label="VQA"
109
  )
110
 
 
111
  response_handler = (
112
  response,
113
  [message, chat, image],
114
  [chat, submit, stop]
115
  )
 
 
116
  postresponse_handler = (
117
+ lambda: (gr.Button(visible=False), gr.Button(visible=True)),
118
  None,
119
+ [stop, submit]
120
  )
121
 
 
 
122
  event1 = message.submit(*response_handler)
123
  event1.then(*postresponse_handler)
124
  event2 = submit.click(*response_handler)
125
  event2.then(*postresponse_handler)
126
 
 
127
  stop.click(None, None, None, cancels=[event1, event2])
128
 
 
129
  demo.queue()
130
+ demo.launch()