Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -15,18 +15,21 @@ from PIL import Image
|
|
| 15 |
|
| 16 |
from transformers.image_transforms import resize, to_channel_dimension_format
|
| 17 |
|
| 18 |
-
|
| 19 |
# `env=` replaces the child environment rather than extending it, so merge the
# flag into os.environ to keep PATH, HOME and proxy settings visible to pip.
subprocess.run('pip install flash-attn --no-build-isolation', env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
| 20 |
|
| 21 |
-
|
|
|
|
| 22 |
PROCESSOR = AutoProcessor.from_pretrained(
|
| 23 |
"HuggingFaceM4/VLM_WebSight_finetuned",
|
| 24 |
)
|
| 25 |
MODEL = AutoModelForCausalLM.from_pretrained(
|
| 26 |
"HuggingFaceM4/VLM_WebSight_finetuned",
|
| 27 |
trust_remote_code=True,
|
| 28 |
-
torch_dtype=torch.bfloat16,
|
| 29 |
).to(DEVICE)
|
|
|
|
|
|
|
| 30 |
if MODEL.config.use_resampler:
|
| 31 |
image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
|
| 32 |
else:
|
|
@@ -36,12 +39,9 @@ else:
|
|
| 36 |
BOS_TOKEN = PROCESSOR.tokenizer.bos_token
|
| 37 |
BAD_WORDS_IDS = PROCESSOR.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
|
| 38 |
|
| 39 |
-
|
| 40 |
## Utils
|
| 41 |
|
| 42 |
def convert_to_rgb(image):
|
| 43 |
-
# `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
|
| 44 |
-
# for transparent images. The call to `alpha_composite` handles this case
|
| 45 |
if image.mode == "RGB":
|
| 46 |
return image
|
| 47 |
|
|
@@ -51,8 +51,6 @@ def convert_to_rgb(image):
|
|
| 51 |
alpha_composite = alpha_composite.convert("RGB")
|
| 52 |
return alpha_composite
|
| 53 |
|
| 54 |
-
# The processor is the same as the Idefics processor except for the BICUBIC interpolation inside siglip,
|
| 55 |
-
# so this is a hack in order to redefine ONLY the transform method
|
| 56 |
def custom_transform(x):
|
| 57 |
x = convert_to_rgb(x)
|
| 58 |
x = to_numpy_array(x)
|
|
@@ -69,13 +67,7 @@ def custom_transform(x):
|
|
| 69 |
|
| 70 |
## End of Utils
|
| 71 |
|
| 72 |
-
|
| 73 |
-
IMAGE_GALLERY_PATHS = [
|
| 74 |
-
f"example_images/{ex_image}"
|
| 75 |
-
for ex_image in os.listdir(f"example_images")
|
| 76 |
-
]
|
| 77 |
-
|
| 78 |
-
|
| 79 |
def install_playwright():
|
| 80 |
try:
|
| 81 |
subprocess.run(["playwright", "install"], check=True)
|
|
@@ -85,17 +77,15 @@ def install_playwright():
|
|
| 85 |
|
| 86 |
install_playwright()
|
| 87 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 88 |
|
| 89 |
-
def add_file_gallery(
|
| 90 |
-
selected_state: gr.SelectData,
|
| 91 |
-
gallery_list: List[str]
|
| 92 |
-
):
|
| 93 |
return Image.open(gallery_list.root[selected_state.index].image.path)
|
| 94 |
|
| 95 |
-
|
| 96 |
-
def render_webpage(
|
| 97 |
-
html_css_code,
|
| 98 |
-
):
|
| 99 |
with sync_playwright() as p:
|
| 100 |
browser = p.chromium.launch(headless=True)
|
| 101 |
context = browser.new_context(
|
|
@@ -115,11 +105,8 @@ def render_webpage(
|
|
| 115 |
|
| 116 |
return Image.open(output_path_screenshot)
|
| 117 |
|
| 118 |
-
|
| 119 |
@spaces.GPU(duration=180)
|
| 120 |
-
def model_inference(
|
| 121 |
-
image,
|
| 122 |
-
):
|
| 123 |
if image is None:
|
| 124 |
raise ValueError("`image` is None. It should be a PIL image.")
|
| 125 |
|
|
@@ -132,10 +119,7 @@ def model_inference(
|
|
| 132 |
[image],
|
| 133 |
transform=custom_transform
|
| 134 |
)
|
| 135 |
-
inputs = {
|
| 136 |
-
k: v.to(DEVICE)
|
| 137 |
-
for k, v in inputs.items()
|
| 138 |
-
}
|
| 139 |
|
| 140 |
streamer = TextIteratorStreamer(
|
| 141 |
PROCESSOR.tokenizer,
|
|
@@ -147,16 +131,6 @@ def model_inference(
|
|
| 147 |
max_length=4096,
|
| 148 |
streamer=streamer,
|
| 149 |
)
|
| 150 |
-
# Regular generation version
|
| 151 |
-
# generation_kwargs.pop("streamer")
|
| 152 |
-
# generated_ids = MODEL.generate(**generation_kwargs)
|
| 153 |
-
# generated_text = PROCESSOR.batch_decode(
|
| 154 |
-
# generated_ids,
|
| 155 |
-
# skip_special_tokens=True
|
| 156 |
-
# )[0]
|
| 157 |
-
# rendered_page = render_webpage(generated_text)
|
| 158 |
-
# return generated_text, rendered_page
|
| 159 |
-
# Token streaming version
|
| 160 |
thread = Thread(
|
| 161 |
target=MODEL.generate,
|
| 162 |
kwargs=generation_kwargs,
|
|
@@ -172,20 +146,8 @@ def model_inference(
|
|
| 172 |
generated_text += new_text
|
| 173 |
yield generated_text, rendered_image
|
| 174 |
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
label="Extracted HTML",
|
| 178 |
-
elem_id="generated_html",
|
| 179 |
-
)
|
| 180 |
-
rendered_html = gr.Image(
|
| 181 |
-
label="Rendered HTML",
|
| 182 |
-
show_download_button=False,
|
| 183 |
-
show_share_button=False,
|
| 184 |
-
)
|
| 185 |
-
# rendered_html = gr.HTML(
|
| 186 |
-
# label="Rendered HTML"
|
| 187 |
-
# )
|
| 188 |
-
|
| 189 |
|
| 190 |
css = """
|
| 191 |
.gradio-container{max-width: 1000px!important}
|
|
@@ -193,7 +155,6 @@ h1{display: flex;align-items: center;justify-content: center;gap: .25em}
|
|
| 193 |
*{transition: width 0.5s ease, flex-grow 0.5s ease}
|
| 194 |
"""
|
| 195 |
|
| 196 |
-
|
| 197 |
with gr.Blocks(title="Screenshot to HTML", theme=gr.themes.Base(), css=css) as demo:
|
| 198 |
gr.Markdown(
|
| 199 |
"Since the model used for this demo *does not generate images*, it is more effective to input standalone website elements or sites with minimal image content."
|
|
@@ -208,15 +169,11 @@ with gr.Blocks(title="Screenshot to HTML", theme=gr.themes.Base(), css=css) as d
|
|
| 208 |
)
|
| 209 |
with gr.Group():
|
| 210 |
with gr.Row():
|
| 211 |
-
submit_btn = gr.Button(
|
| 212 |
-
value="▶️ Submit", visible=True, min_width=120
|
| 213 |
-
)
|
| 214 |
clear_btn = gr.ClearButton(
|
| 215 |
[imagebox, generated_html, rendered_html], value="🧹 Clear", min_width=120
|
| 216 |
)
|
| 217 |
-
regenerate_btn = gr.Button(
|
| 218 |
-
value="🔄 Regenerate", visible=True, min_width=120
|
| 219 |
-
)
|
| 220 |
with gr.Column(scale=4):
|
| 221 |
rendered_html.render()
|
| 222 |
|
|
@@ -235,11 +192,7 @@ with gr.Blocks(title="Screenshot to HTML", theme=gr.themes.Base(), css=css) as d
|
|
| 235 |
)
|
| 236 |
|
| 237 |
gr.on(
|
| 238 |
-
triggers=[
|
| 239 |
-
imagebox.upload,
|
| 240 |
-
submit_btn.click,
|
| 241 |
-
regenerate_btn.click,
|
| 242 |
-
],
|
| 243 |
fn=model_inference,
|
| 244 |
inputs=[imagebox],
|
| 245 |
outputs=[generated_html, rendered_html],
|
|
|
|
| 15 |
|
| 16 |
from transformers.image_transforms import resize, to_channel_dimension_format
|
| 17 |
|
| 18 |
+
# Install flash-attn with pip build isolation disabled and the CUDA build skipped
|
| 19 |
# `env=` replaces the child environment rather than extending it, so merge the
# flag into os.environ to keep PATH, HOME and proxy settings visible to pip.
subprocess.run('pip install flash-attn --no-build-isolation', env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
| 20 |
|
| 21 |
+
# Set the device to GPU if available, otherwise use CPU
|
| 22 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 23 |
PROCESSOR = AutoProcessor.from_pretrained(
|
| 24 |
"HuggingFaceM4/VLM_WebSight_finetuned",
|
| 25 |
)
|
| 26 |
MODEL = AutoModelForCausalLM.from_pretrained(
|
| 27 |
"HuggingFaceM4/VLM_WebSight_finetuned",
|
| 28 |
trust_remote_code=True,
|
| 29 |
+
torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
|
| 30 |
).to(DEVICE)
|
| 31 |
+
|
| 32 |
+
# Determine image sequence length
|
| 33 |
if MODEL.config.use_resampler:
|
| 34 |
image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
|
| 35 |
else:
|
|
|
|
| 39 |
BOS_TOKEN = PROCESSOR.tokenizer.bos_token
|
| 40 |
BAD_WORDS_IDS = PROCESSOR.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
|
| 41 |
|
|
|
|
| 42 |
## Utils
|
| 43 |
|
| 44 |
def convert_to_rgb(image):
|
|
|
|
|
|
|
| 45 |
if image.mode == "RGB":
|
| 46 |
return image
|
| 47 |
|
|
|
|
| 51 |
alpha_composite = alpha_composite.convert("RGB")
|
| 52 |
return alpha_composite
|
| 53 |
|
|
|
|
|
|
|
| 54 |
def custom_transform(x):
|
| 55 |
x = convert_to_rgb(x)
|
| 56 |
x = to_numpy_array(x)
|
|
|
|
| 67 |
|
| 68 |
## End of Utils
|
| 69 |
|
| 70 |
+
# Install Playwright
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
def install_playwright():
|
| 72 |
try:
|
| 73 |
subprocess.run(["playwright", "install"], check=True)
|
|
|
|
| 77 |
|
| 78 |
install_playwright()
|
| 79 |
|
| 80 |
+
IMAGE_GALLERY_PATHS = [
|
| 81 |
+
f"example_images/{ex_image}"
|
| 82 |
+
for ex_image in os.listdir(f"example_images")
|
| 83 |
+
]
|
| 84 |
|
| 85 |
+
def add_file_gallery(selected_state: gr.SelectData, gallery_list: List[str]):
|
|
|
|
|
|
|
|
|
|
| 86 |
return Image.open(gallery_list.root[selected_state.index].image.path)
|
| 87 |
|
| 88 |
+
def render_webpage(html_css_code):
|
|
|
|
|
|
|
|
|
|
| 89 |
with sync_playwright() as p:
|
| 90 |
browser = p.chromium.launch(headless=True)
|
| 91 |
context = browser.new_context(
|
|
|
|
| 105 |
|
| 106 |
return Image.open(output_path_screenshot)
|
| 107 |
|
|
|
|
| 108 |
@spaces.GPU(duration=180)
|
| 109 |
+
def model_inference(image):
|
|
|
|
|
|
|
| 110 |
if image is None:
|
| 111 |
raise ValueError("`image` is None. It should be a PIL image.")
|
| 112 |
|
|
|
|
| 119 |
[image],
|
| 120 |
transform=custom_transform
|
| 121 |
)
|
| 122 |
+
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
streamer = TextIteratorStreamer(
|
| 125 |
PROCESSOR.tokenizer,
|
|
|
|
| 131 |
max_length=4096,
|
| 132 |
streamer=streamer,
|
| 133 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
thread = Thread(
|
| 135 |
target=MODEL.generate,
|
| 136 |
kwargs=generation_kwargs,
|
|
|
|
| 146 |
generated_text += new_text
|
| 147 |
yield generated_text, rendered_image
|
| 148 |
|
| 149 |
+
generated_html = gr.Code(label="Extracted HTML", elem_id="generated_html")
|
| 150 |
+
rendered_html = gr.Image(label="Rendered HTML", show_download_button=False, show_share_button=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 151 |
|
| 152 |
css = """
|
| 153 |
.gradio-container{max-width: 1000px!important}
|
|
|
|
| 155 |
*{transition: width 0.5s ease, flex-grow 0.5s ease}
|
| 156 |
"""
|
| 157 |
|
|
|
|
| 158 |
with gr.Blocks(title="Screenshot to HTML", theme=gr.themes.Base(), css=css) as demo:
|
| 159 |
gr.Markdown(
|
| 160 |
"Since the model used for this demo *does not generate images*, it is more effective to input standalone website elements or sites with minimal image content."
|
|
|
|
| 169 |
)
|
| 170 |
with gr.Group():
|
| 171 |
with gr.Row():
|
| 172 |
+
submit_btn = gr.Button(value="▶️ Submit", visible=True, min_width=120)
|
|
|
|
|
|
|
| 173 |
clear_btn = gr.ClearButton(
|
| 174 |
[imagebox, generated_html, rendered_html], value="🧹 Clear", min_width=120
|
| 175 |
)
|
| 176 |
+
regenerate_btn = gr.Button(value="🔄 Regenerate", visible=True, min_width=120)
|
|
|
|
|
|
|
| 177 |
with gr.Column(scale=4):
|
| 178 |
rendered_html.render()
|
| 179 |
|
|
|
|
| 192 |
)
|
| 193 |
|
| 194 |
gr.on(
|
| 195 |
+
triggers=[imagebox.upload, submit_btn.click, regenerate_btn.click],
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
fn=model_inference,
|
| 197 |
inputs=[imagebox],
|
| 198 |
outputs=[generated_html, rendered_html],
|