Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
Ovis-U1-3B multimodal DEMO (CPU / GPU adaptive version)
Dependencies: Python 3.10+, torch 2.*, transformers 4.41.*, gradio 4.*
| """ | |
# -----------------------------------------------
# 1. Prepare the environment before any transformers / flash_attn import
# -----------------------------------------------
| import os, sys, types, subprocess, random, numpy as np, torch | |
| import importlib.util # โ ๆฐๅข๏ผ็จไบ็ๆ ModuleSpec | |
# Pick the compute device once. bfloat16 is only selected on CUDA; the CPU
# path falls back to float32.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.bfloat16 if DEVICE == "cuda" else torch.float32

# -------- CPU environment: shadow flash-attn with a stub --------
if DEVICE == "cpu":
    # Best-effort removal of an installed flash-attn. pip is only spawned when
    # the package is actually discoverable; its outcome is deliberately
    # ignored because the stub registered below wins in sys.modules anyway.
    try:
        _flash_attn_present = importlib.util.find_spec("flash_attn") is not None
    except (ImportError, ValueError):
        # Import state is unusual; let pip decide whether anything is installed.
        _flash_attn_present = True
    if _flash_attn_present:
        subprocess.run(
            [sys.executable, "-m", "pip", "uninstall", "-y", "flash-attn"],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
            check=False,
        )

    # Build empty stand-in modules for the import paths that get shadowed.
    fake_flash_attn = types.ModuleType("flash_attn")
    fake_layers = types.ModuleType("flash_attn.layers")
    fake_rotary = types.ModuleType("flash_attn.layers.rotary")

    def _cpu_apply_rotary_emb(x, cos, sin):
        """Apply a rotary position embedding using plain tensor ops.

        Even and odd positions of the last axis of ``x`` are treated as the
        two halves of each rotated pair; the result has the same shape and
        dtype as ``x``.

        NOTE(review): ``cos`` / ``sin`` are combined with ``x[..., ::2]`` by
        ordinary broadcasting, and only the ``(x, cos, sin)`` call form is
        supported. Confirm this matches the pairing convention and the
        arguments the model's remote code actually uses.
        """
        x1, x2 = x[..., ::2], x[..., 1::2]
        rot_x1 = x1 * cos - x2 * sin
        rot_x2 = x1 * sin + x2 * cos
        out = torch.empty_like(x)
        out[..., ::2] = rot_x1
        out[..., 1::2] = rot_x2
        return out

    fake_rotary.apply_rotary_emb = _cpu_apply_rotary_emb
    fake_layers.rotary = fake_rotary
    fake_flash_attn.layers = fake_layers

    # importlib.util.find_spec() raises ValueError for a module in sys.modules
    # whose __spec__ is None, so every stub (not just the top-level one) gets
    # a valid spec before it is registered.
    for _stub in (fake_flash_attn, fake_layers, fake_rotary):
        _stub.__spec__ = importlib.util.spec_from_loader(_stub.__name__, loader=None)

    sys.modules.update({
        "flash_attn": fake_flash_attn,
        "flash_attn.layers": fake_layers,
        "flash_attn.layers.rotary": fake_rotary,
    })
else:
    # GPU environment: try to install flash-attn; failure is non-fatal.
    try:
        subprocess.run(
            [sys.executable, "-m", "pip", "install",
             "flash-attn==2.6.3", "--no-build-isolation"],
            # Extend the current environment instead of replacing it: passing
            # only the new variable would drop PATH, HOME, CUDA settings, etc.
            env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
            check=True,
        )
    except (subprocess.CalledProcessError, OSError):
        # NOTE(review): the message below looks mojibake-damaged (UTF-8 text
        # decoded with the wrong code page); restore it from the original
        # source file. It is kept as found here.
        print("[WARN] flash-attn ๅฎ่ฃ ๅคฑ่ดฅ๏ผGPU ๅ ้ๅ่ฝๅ้ใ")
# -----------------------------------------------
# 2. Regular dependencies
# -----------------------------------------------
| from PIL import Image | |
| import gradio as gr | |
| import spaces | |
| from transformers import AutoModelForCausalLM | |
| from test_img_edit import pipe_img_edit | |
| from test_img_to_txt import pipe_txt_gen | |
| from test_txt_to_img import pipe_t2i | |
# -----------------------------------------------
# 3. Utility functions & constants
# -----------------------------------------------
# Inclusive upper bound for seeds: used by randomize_seed_fn when drawing a
# random seed and as the maximum of the UI seed sliders.
MAX_SEED = 10_000
def set_global_seed(seed: int = 42):
    """Seed Python's ``random``, NumPy and PyTorch with the same value.

    Args:
        seed: Seed to apply. It is coerced with ``int()`` first because
            ``np.random.seed`` and ``torch.manual_seed`` reject float values
            such as ``42.0`` (which a UI slider may deliver; TODO confirm).

    Raises:
        TypeError / ValueError: If ``seed`` cannot be converted to ``int``.
    """
    seed = int(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Seed every visible CUDA device, not only the current one.
        torch.cuda.manual_seed_all(seed)
def randomize_seed_fn(seed: int, randomize: bool) -> int:
    """Return ``seed`` unchanged, or a fresh random seed when ``randomize`` is truthy.

    The random seed is drawn uniformly from ``0..MAX_SEED`` (both inclusive).
    """
    if not randomize:
        return seed
    return random.randint(0, MAX_SEED)
# -----------------------------------------------
# 4. Load the model
# -----------------------------------------------
# Hub access token (None when the HF_TOKEN variable is unset) and model repo id.
HF_TOKEN = os.environ.get("HF_TOKEN")
MODEL_ID = "AIDC-AI/Ovis-U1-3B"

print(f"[INFO] Loading {MODEL_ID} on {DEVICE} โฆ")

# Load the checkpoint with the dtype chosen for the detected device; the
# repository's own modelling code is executed (trust_remote_code=True).
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    token=HF_TOKEN,
    trust_remote_code=True,
    torch_dtype=DTYPE,
    device_map="auto",
    low_cpu_mem_usage=True,
)
# Switch to inference mode; the name `model` is bound to what eval() returns.
model = model.eval()

print("[INFO] Model ready!")
# -----------------------------------------------
# 5. Inference wrappers
# -----------------------------------------------
def process_txt_to_img(prompt, height, width, steps, seed, cfg,
                       progress=gr.Progress(track_tqdm=True)):
    """Seed the global RNGs, then run text-to-image generation via ``pipe_t2i``.

    ``progress`` is not referenced in the body; presumably Gradio uses the
    default to mirror tqdm progress in the UI (verify against Gradio docs).
    Returns whatever ``pipe_t2i`` returns.
    """
    set_global_seed(seed)
    generated = pipe_t2i(model, prompt, height, width, steps, cfg=cfg, seed=seed)
    return generated
def process_img_to_txt(prompt, img, progress=gr.Progress(track_tqdm=True)):
    """Run image understanding via ``pipe_txt_gen`` and return its result.

    Note the argument order swap: ``pipe_txt_gen`` receives the image before
    the prompt. ``progress`` is not referenced in the body.
    """
    answer = pipe_txt_gen(model, img, prompt)
    return answer
def process_img_txt_to_img(prompt, img, steps, seed, txt_cfg, img_cfg,
                           progress=gr.Progress(track_tqdm=True)):
    """Seed the global RNGs, then run image editing via ``pipe_img_edit``.

    ``pipe_img_edit`` receives the image before the prompt, followed by the
    step count and the text / image guidance values. ``progress`` is not
    referenced in the body. Returns whatever ``pipe_img_edit`` returns.
    """
    set_global_seed(seed)
    edited = pipe_img_edit(model, img, prompt, steps, txt_cfg, img_cfg, seed=seed)
    return edited
# -----------------------------------------------
# 6. Gradio UI (same as the original version; change markers omitted here)
# -----------------------------------------------
# NOTE(review): the nesting below was reconstructed from the statement order;
# confirm the layout against the running app.
with gr.Blocks(title="Ovis-U1-3B (CPU/GPU adaptive)") as demo:
    gr.Markdown("# Ovis-U1-3B\nๅคๆจกๆๆๆฌ-ๅพๅ DEMO๏ผCPU/GPU ่ช้ๅบ็๏ผ")

    with gr.Row():
        # Left column: one tab per task, plus the global reset button.
        with gr.Column():
            with gr.Tabs():
                # Tab 1: image + text -> image (editing)
                with gr.TabItem("Image + Text โ Image"):
                    edit_image_input = gr.Image(label="Input Image", type="pil")
                    with gr.Row():
                        edit_prompt_input = gr.Textbox(
                            show_label=False,
                            placeholder="Describe the editing instructionโฆ",
                        )
                        run_edit_image_btn = gr.Button("Run", scale=0)
                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Row():
                            edit_img_guidance = gr.Slider(
                                label="Image Guidance", minimum=1, maximum=10, value=1.5, step=0.1)
                            edit_txt_guidance = gr.Slider(
                                label="Text Guidance", minimum=1, maximum=30, value=6.0, step=0.5)
                            edit_steps = gr.Slider(
                                label="Steps", minimum=40, maximum=100, value=50, step=1)
                            edit_seed = gr.Slider(
                                label="Seed", minimum=0, maximum=MAX_SEED, value=42, step=1)
                            edit_random = gr.Checkbox(label="Randomize seed", value=False)

                # Tab 2: text -> image (generation)
                with gr.TabItem("Text โ Image"):
                    prompt_gen = gr.Textbox(
                        show_label=False,
                        placeholder="Describe the image you wantโฆ",
                    )
                    run_gen_btn = gr.Button("Run", scale=0)
                    with gr.Accordion("Advanced Settings", open=False):
                        with gr.Row():
                            height_slider = gr.Slider(
                                label="height", minimum=256, maximum=1536, value=1024, step=32)
                            width_slider = gr.Slider(
                                label="width", minimum=256, maximum=1536, value=1024, step=32)
                            guidance_slider = gr.Slider(
                                label="Guidance Scale", minimum=1, maximum=30, value=5, step=0.5)
                            steps_slider = gr.Slider(
                                label="Steps", minimum=40, maximum=100, value=50, step=1)
                            seed_slider = gr.Slider(
                                label="Seed", minimum=0, maximum=MAX_SEED, value=42, step=1)
                            random_check = gr.Checkbox(label="Randomize seed", value=False)

                # Tab 3: image -> text (understanding)
                with gr.TabItem("Image โ Text"):
                    understand_img = gr.Image(label="Input Image", type="pil")
                    understand_prompt = gr.Textbox(
                        show_label=False,
                        placeholder="Describe the question about imageโฆ",
                    )
                    run_understand = gr.Button("Run", scale=0)

            clear_btn = gr.Button("Clear All")

        # Right column: shared outputs; only one of the two is visible at a time.
        with gr.Column():
            gallery = gr.Gallery(label="Generated Images", columns=2, visible=True)
            txt_out = gr.Textbox(label="Generated Text", visible=False, lines=5, interactive=False)

    # ---- Event handlers: each returns (gallery update, text box update) ----
    def run_tab1(prompt, img, steps, seed, txt_cfg, img_cfg, progress=gr.Progress(track_tqdm=True)):
        """Edit the uploaded image; show a hint in the text box when no image is given."""
        if img is not None:
            edited = process_img_txt_to_img(prompt, img, steps, seed, txt_cfg, img_cfg, progress)
            return gr.update(value=edited, visible=True), gr.update(value="", visible=False)
        return gr.update(value=[], visible=False), gr.update(value="Please upload an image.", visible=True)

    def run_tab2(prompt, h, w, steps, seed, guidance, progress=gr.Progress(track_tqdm=True)):
        """Generate images from the prompt and show them in the gallery."""
        generated = process_txt_to_img(prompt, h, w, steps, seed, guidance, progress)
        return gr.update(value=generated, visible=True), gr.update(value="", visible=False)

    def run_tab3(img, prompt, progress=gr.Progress(track_tqdm=True)):
        """Answer the prompt about the uploaded image; the gallery is hidden either way."""
        if img is not None:
            answer = process_img_to_txt(prompt, img, progress)
            return gr.update(value=[], visible=False), gr.update(value=answer, visible=True)
        return gr.update(value=[], visible=False), gr.update(value="Please upload an image.", visible=True)

    result_widgets = [gallery, txt_out]

    # Tab 1 bindings: button click and Enter in the prompt box behave the same.
    # The seed is (optionally) re-rolled first, then the edit runs.
    tab1_inputs = [edit_prompt_input, edit_image_input, edit_steps, edit_seed,
                   edit_txt_guidance, edit_img_guidance]
    for trigger in (run_edit_image_btn.click, edit_prompt_input.submit):
        trigger(randomize_seed_fn, [edit_seed, edit_random], [edit_seed]).then(
            run_tab1, tab1_inputs, result_widgets)

    # Tab 2 bindings: same two-step chain for generation.
    tab2_inputs = [prompt_gen, height_slider, width_slider, steps_slider,
                   seed_slider, guidance_slider]
    for trigger in (run_gen_btn.click, prompt_gen.submit):
        trigger(randomize_seed_fn, [seed_slider, random_check], [seed_slider]).then(
            run_tab2, tab2_inputs, result_widgets)

    # Tab 3 bindings: no seed involved.
    for trigger in (run_understand.click, understand_prompt.submit):
        trigger(run_tab3, [understand_img, understand_prompt], result_widgets)

    # Reset
    def clear_all():
        """Return updates restoring every input to its initial value and the outputs to their initial visibility."""
        # One default per widget, in the same order as the outputs list below.
        input_defaults = (
            None, "", 1.5, 6.0, 50, 42, False,    # tab 1
            "", 1024, 1024, 5, 50, 42, False,      # tab 2
            None, "",                              # tab 3
        )
        resets = [gr.update(value=default) for default in input_defaults]
        resets.append(gr.update(value=[], visible=True))
        resets.append(gr.update(value="", visible=False))
        return tuple(resets)

    clear_btn.click(clear_all, [], [
        edit_image_input, edit_prompt_input, edit_img_guidance, edit_txt_guidance,
        edit_steps, edit_seed, edit_random,
        prompt_gen, height_slider, width_slider,
        guidance_slider, steps_slider, seed_slider, random_check,
        understand_img, understand_prompt,
        gallery, txt_out,
    ])
# -----------------------------------------------
# 7. Launch
# -----------------------------------------------
# Launch the Gradio app only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == "__main__":
    demo.launch()