marc-thibault-h commited on
Commit
b4735a2
·
verified ·
1 Parent(s): 5978474

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +398 -0
app.py ADDED
@@ -0,0 +1,398 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ from datetime import datetime
4
+ from typing import Any, Literal
5
+
6
+ import gradio as gr
7
+ import numpy as np
8
+ import requests
9
+ import spaces
10
+ import torch
11
+ from PIL import Image
12
+ from pydantic import BaseModel, Field
13
+ from transformers import AutoProcessor
14
+ from transformers.models.auto.modeling_auto import AutoModelForImageTextToText
15
+ from transformers.models.qwen2_vl.image_processing_qwen2_vl import smart_resize
16
+
17
+ # --- Configuration ---
18
+ MODEL_ID = "Hcompany/Holo1-7B" # TODO update
19
+ # TODO implement model wait?
20
+
21
+ # --- Model and Processor Loading (Load once) ---
22
+ print(f"Loading model and processor for {MODEL_ID}...")
23
+ model = None
24
+ processor = None
25
+ model_loaded = False
26
+ load_error_message = ""
27
+
28
+ # TODO need to install flash-attn like in Holo1?
29
+
30
+
31
+ try:
32
+ model = AutoModelForImageTextToText.from_pretrained(
33
+ MODEL_ID, torch_dtype=torch.bfloat16, trust_remote_code=True
34
+ ).to("cuda")
35
+ processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
36
+
37
+ model_loaded = True
38
+ print("Model and processor loaded successfully.")
39
+ except Exception as e:
40
+ load_error_message = (
41
+ f"Error loading model/processor: {e}\n"
42
+ "This might be due to network issues, an incorrect model ID, or missing dependencies (like flash_attention_2 if enabled by default in some config).\n"
43
+ "Ensure you have a stable internet connection and the necessary libraries installed."
44
+ )
45
+ print(load_error_message)
46
+
47
+
48
+ title = "Holo1.5-7B: Navigation VLM Demo"
49
+
50
+ description = """
51
+ This demo showcases [**Holo1.5-7B**](https://huggingface.co/Hcompany/Holo1.5-7B), a new version of the Action Vision-Language Model developed by HCompany, fine-tuned from Qwen/Qwen2.5-VL-7B-Instruct.
52
+ It's designed to perform complex navigation tasks in Web, Android, and Desktop interfaces.
53
+ **How to use:**
54
+ 1. Upload an image (e.g., a screenshot of a UI, see example below).
55
+ 2. Provide a textual task (e.g., "Book a hotel in Paris on August 3rd for 3 nights").
56
+ 3. The model will predict the next action to take.
57
+ The model processor resizes your input image. Coordinates are relative to this resized image.
58
+ """ # TODO polish
59
+
60
+
61
+ def array_to_image_path(image_array):
62
+ if image_array is None:
63
+ raise ValueError("No image provided. Please upload an image before submitting.")
64
+ # Convert numpy array to PIL Image
65
+ img = Image.fromarray(np.uint8(image_array))
66
+
67
+ # Generate a unique filename using timestamp
68
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
69
+ filename = f"image_{timestamp}.png"
70
+
71
+ # Save the image
72
+ img.save(filename)
73
+
74
+ # Get the full path of the saved image
75
+ full_path = os.path.abspath(filename)
76
+
77
+ return full_path
78
+
79
+
80
+ SYSTEM_PROMPT: str = """Imagine you are a robot browsing the web, just like humans. Now you need to complete a task.
81
+ In each iteration, you will receive an Observation that includes the last screenshots of a web browser and the current memory of the agent.
82
+ You have also information about the step that the agent is trying to achieve to solve the task.
83
+ Carefully analyze the visual information to identify what to do, then follow the guidelines to choose the following action.
84
+ You should detail your thought (i.e. reasoning steps) before taking the action.
85
+ Also detail in the notes field of the action the extracted information relevant to solve the task.
86
+ Once you have enough information in the notes to answer the task, return an answer action with the detailed answer in the notes field.
87
+ This will be evaluated by an evaluator and should match all the criteria or requirements of the task.
88
+ Guidelines:
89
+ - store in the notes all the relevant information to solve the task that fulfill the task criteria. Be precise
90
+ - Use both the task and the step information to decide what to do
91
+ - if you want to write in a text field and the text field already has text, designate the text field by the text it contains and its type
92
+ - If there is a cookies notice, always accept all the cookies first
93
+ - The observation is the screenshot of the current page and the memory of the agent.
94
+ - If you see relevant information on the screenshot to answer the task, add it to the notes field of the action.
95
+ - If there is no relevant information on the screenshot to answer the task, add an empty string to the notes field of the action.
96
+ - If you see buttons that allow to navigate directly to relevant information, like jump to ... or go to ... , use them to navigate faster.
97
+ - In the answer action, give as many details a possible relevant to answering the task.
98
+ - if you want to write, don't click before. Directly use the write action
99
+ - to write, identify the web element which is type and the text it already contains
100
+ - If you want to use a search bar, directly write text in the search bar
101
+ - Don't scroll too much. Don't scroll if the number of scrolls is greater than 3
102
+ - Don't scroll if you are at the end of the webpage
103
+ - Only refresh if you identify a rate limit problem
104
+ - If you are looking for a single flights, click on round-trip to select 'one way'
105
+ - Never try to login, enter email or password. If there is a need to login, then go back.
106
+ - If you are facing a captcha on a website, try to solve it.
107
+ - if you have enough information in the screenshot and in the notes to answer the task, return an answer action with the detailed answer in the notes field
108
+ - The current date is {timestamp}.
109
+ # <output_json_format>
110
+ # ```json
111
+ # {output_format}
112
+ # ```
113
+ # </output_json_format>
114
+ """
115
+
116
+
117
+ class ClickElementAction(BaseModel):
118
+ """Click at absolute coordinates of a web element with its description"""
119
+
120
+ action: Literal["click_element"] = Field(description="Click at absolute coordinates of a web element")
121
+ element: str = Field(description="text description of the element")
122
+ x: int = Field(description="The x coordinate, number of pixels from the left edge.")
123
+ y: int = Field(description="The y coordinate, number of pixels from the top edge.")
124
+
125
+ def log(self):
126
+ return f"I have clicked on the element '{self.element}' at absolute coordinates {self.x}, {self.y}"
127
+
128
+
129
+ class WriteElementAction(BaseModel):
130
+ """Write content at absolute coordinates of a web element identified by its description, then press Enter."""
131
+
132
+ action: Literal["write_element_abs"] = Field(description="Write content at absolute coordinates of a web page")
133
+ content: str = Field(description="Content to write")
134
+ element: str = Field(description="Text description of the element")
135
+ x: int = Field(description="The x coordinate, number of pixels from the left edge.")
136
+ y: int = Field(description="The y coordinate, number of pixels from the top edge.")
137
+
138
+ def log(self):
139
+ return f"I have written '{self.content}' in the element '{self.element}' at absolute coordinates {self.x}, {self.y}"
140
+
141
+
142
+ class ScrollAction(BaseModel):
143
+ """Scroll action with no required element"""
144
+
145
+ action: Literal["scroll"] = Field(description="Scroll the page or a specific element")
146
+ direction: Literal["down", "up", "left", "right"] = Field(description="The direction to scroll in")
147
+
148
+ def log(self):
149
+ return f"I have scrolled {self.direction}"
150
+
151
+
152
+ class GoBackAction(BaseModel):
153
+ """Action to navigate back in browser history"""
154
+
155
+ action: Literal["go_back"] = Field(description="Navigate to the previous page")
156
+
157
+ def log(self):
158
+ return "I have gone back to the previous page"
159
+
160
+
161
+ class RefreshAction(BaseModel):
162
+ """Action to refresh the current page"""
163
+
164
+ action: Literal["refresh"] = Field(description="Refresh the current page")
165
+
166
+ def log(self):
167
+ return "I have refreshed the page"
168
+
169
+
170
+ class GotoAction(BaseModel):
171
+ """Action to go to a particular URL"""
172
+
173
+ action: Literal["goto"] = Field(description="Goto a particular URL")
174
+ url: str = Field(description="A url starting with http:// or https://")
175
+
176
+ def log(self):
177
+ return f"I have navigated to the URL {self.url}"
178
+
179
+
180
+ class WaitAction(BaseModel):
181
+ """Action to wait for a particular amount of time"""
182
+
183
+ action: Literal["wait"] = Field(description="Wait for a particular amount of time")
184
+ seconds: int = Field(default=2, ge=0, le=10, description="The number of seconds to wait")
185
+
186
+ def log(self):
187
+ return f"I have waited for {self.seconds} seconds"
188
+
189
+
190
+ class RestartAction(BaseModel):
191
+ """Restart the task from the beginning."""
192
+
193
+ action: Literal["restart"] = "restart"
194
+
195
+ def log(self):
196
+ return "I have restarted the task from the beginning"
197
+
198
+
199
+ class AnswerAction(BaseModel):
200
+ """Return a final answer to the task. This is the last action to call in an episode."""
201
+
202
+ action: Literal["answer"] = "answer"
203
+ content: str = Field(description="The answer content")
204
+
205
+ def log(self):
206
+ return f"I have answered the task with '{self.content}'"
207
+
208
+
209
+ ActionSpace = (
210
+ ClickElementAction
211
+ | WriteElementAction
212
+ | ScrollAction
213
+ | GoBackAction
214
+ | RefreshAction
215
+ | WaitAction
216
+ | RestartAction
217
+ | AnswerAction
218
+ | GotoAction
219
+ )
220
+
221
+
222
+ class NavigationStep(BaseModel):
223
+ note: str = Field(
224
+ default="",
225
+ description="Task-relevant information extracted from the previous observation. Keep empty if no new info.",
226
+ )
227
+ thought: str = Field(description="Reasoning about next steps (<4 lines)")
228
+ action: ActionSpace = Field(description="Next action to take")
229
+
230
+
231
+ def get_navigation_prompt(task, image, step=1):
232
+ """
233
+ Get the prompt for the navigation task.
234
+ - task: The task to complete
235
+ - image: The current screenshot of the web page
236
+ - step: The current step of the task
237
+ """
238
+ system_prompt = SYSTEM_PROMPT.format(
239
+ output_format=NavigationStep.model_json_schema(),
240
+ timestamp="2025-06-04 14:16:03",
241
+ )
242
+ return [
243
+ {
244
+ "role": "system",
245
+ "content": [
246
+ {"type": "text", "text": system_prompt},
247
+ ],
248
+ },
249
+ {
250
+ "role": "user",
251
+ "content": [
252
+ {"type": "text", "text": f"<task>\n{task}\n</task>\n"},
253
+ {"type": "text", "text": f"<observation step={step}>\n"},
254
+ {"type": "text", "text": "<screenshot>\n"},
255
+ {
256
+ "type": "image",
257
+ "image": image,
258
+ },
259
+ {"type": "text", "text": "\n</screenshot>\n"},
260
+ {"type": "text", "text": "\n</observation>\n"},
261
+ ],
262
+ },
263
+ ]
264
+
265
+
266
+ def array_to_image(image_array: np.ndarray) -> Image.Image:
267
+ if image_array is None:
268
+ raise ValueError("No image provided. Please upload an image before submitting.")
269
+ # Convert numpy array to PIL Image
270
+ img = Image.fromarray(np.uint8(image_array))
271
+ return img
272
+
273
+
274
+ @spaces.GPU(duration=20)
275
+ def run_inference_navigation(messages_for_template: list[dict[str, Any]], pil_image_for_processing: Image.Image) -> str:
276
+ model.to("cuda")
277
+ torch.cuda.set_device(0)
278
+ """
279
+ Runs inference using the Holo1 model.
280
+ - messages_for_template: The prompt structure, potentially including the PIL image object
281
+ (which apply_chat_template converts to an image tag).
282
+ - pil_image_for_processing: The actual PIL image to be processed into tensors.
283
+ """
284
+ # 1. Apply chat template to messages. This will create the text part of the prompt,
285
+ # including image tags if the image was part of `messages_for_template`.
286
+ text_prompt = processor.apply_chat_template(messages_for_template, tokenize=False, add_generation_prompt=True)
287
+
288
+ # 2. Process text and image together to get model inputs
289
+ inputs = processor(
290
+ text=[text_prompt],
291
+ images=[pil_image_for_processing], # Provide the actual image data here
292
+ padding=True,
293
+ return_tensors="pt",
294
+ )
295
+ inputs = inputs.to(model.device)
296
+
297
+ # 3. Generate response
298
+ # Using do_sample=False for more deterministic output, as in the model card's structured output example
299
+ generated_ids = model.generate(**inputs, max_new_tokens=128, do_sample=False)
300
+
301
+ # 4. Trim input_ids from generated_ids to get only the generated part
302
+ generated_ids_trimmed = [out_ids[len(in_ids) :] for in_ids, out_ids in zip(inputs.input_ids, generated_ids)]
303
+
304
+ # 5. Decode the generated tokens
305
+ decoded_output = processor.batch_decode(
306
+ generated_ids_trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
307
+ )
308
+
309
+ return decoded_output[0] if decoded_output else ""
310
+
311
+
312
+ # --- Gradio processing function ---
313
+ def navigate(input_numpy_image: np.ndarray, task: str) -> str:
314
+ # if not model_loaded or not processor or not model:
315
+ # return f"Model not loaded. Error: {load_error_message}", None
316
+ # if not input_pil_image:
317
+ # return "No image provided. Please upload an image.", None
318
+ # if not task or task.strip() == "":
319
+ # return "No task provided. Please type an task.", input_pil_image.copy().convert("RGB")
320
+
321
+ # 1. Prepare image: Resize according to model's image processor's expected properties
322
+ # This ensures predicted coordinates match the (resized) image dimensions.
323
+ input_pil_image = array_to_image(input_numpy_image)
324
+ assert isinstance(input_pil_image, Image.Image)
325
+ image_proc_config = processor.image_processor
326
+ try:
327
+ resized_height, resized_width = smart_resize(
328
+ input_pil_image.height,
329
+ input_pil_image.width,
330
+ factor=image_proc_config.patch_size * image_proc_config.merge_size,
331
+ min_pixels=image_proc_config.min_pixels,
332
+ max_pixels=image_proc_config.max_pixels,
333
+ )
334
+ # Using LANCZOS for resampling as it's generally good for downscaling.
335
+ # The model card used `resample=None`, which might imply nearest or default.
336
+ # For visual quality in the demo, LANCZOS is reasonable.
337
+ resized_image = input_pil_image.resize(
338
+ size=(resized_width, resized_height),
339
+ resample=Image.Resampling.LANCZOS, # type: ignore
340
+ )
341
+ except Exception as e:
342
+ print(f"Error resizing image: {e}")
343
+ return f"Error resizing image: {e}", input_pil_image.copy().convert("RGB")
344
+
345
+ # 2. Create the prompt using the resized image (for correct image tagging context) and task
346
+ prompt = get_navigation_prompt(task, resized_image, step=1)
347
+
348
+ print("Prompt:")
349
+ print(prompt)
350
+
351
+ # 3. Run inference
352
+ # Pass `messages` (which includes the image object for template processing)
353
+ # and `resized_image` (for actual tensor conversion).
354
+ try:
355
+ navigation_str = run_inference_navigation(prompt, resized_image)
356
+ except Exception as e:
357
+ print(f"Error during model inference: {e}")
358
+ return f"Error during model inference: {e}", resized_image.copy().convert("RGB")
359
+
360
+ return navigation_str
361
+
362
+
363
+ # --- Load Example Data ---
364
+ example_image_url = "https://huggingface.co/Hcompany/Holo1-7B/resolve/main/calendar_example.jpg" # TODO update
365
+ example_image = Image.open(requests.get(example_image_url, stream=True).raw)
366
+ example_task = "Book a hotel in Paris on August 3rd for 3 nights"
367
+
368
+
369
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
370
+ gr.Markdown(f"<h1 style='text-align: center;'>{title}</h1>")
371
+ gr.Markdown(description)
372
+
373
+ with gr.Row():
374
+ with gr.Column():
375
+ input_image_component = gr.Image(label="Input UI Image", height=400)
376
+ # model_selector = gr.Dropdown(choices=list(models.keys()), label="Model", value="Qwen/Qwen2-VL-7B-Instruct") #TODO separate spaces for models?
377
+ task_component = gr.Textbox(
378
+ label="task",
379
+ placeholder="e.g., Book a hotel in Paris on August 3rd for 3 nights",
380
+ info="Type the task you want the model to complete.",
381
+ )
382
+ submit_button = gr.Button("Navigate", variant="primary")
383
+
384
+ with gr.Column():
385
+ output_coords_component = gr.Textbox(label="Navigation Step")
386
+
387
+ submit_button.click(navigate, [input_image_component, task_component], [output_coords_component])
388
+
389
+ gr.Examples(
390
+ examples=[[example_image, example_task]],
391
+ inputs=[input_image_component, task_component],
392
+ outputs=[output_coords_component],
393
+ fn=navigate,
394
+ cache_examples="lazy",
395
+ )
396
+
397
+ demo.queue(api_open=False)
398
+ demo.launch(debug=True)