import os

import gradio as gr
import torch
from PIL import Image
import numpy as np
import spaces
import math
import re
from einops import rearrange
from mmengine.config import Config
from xtuner.registry import BUILDER
from xtuner.model.utils import guess_load_checkpoint

import matplotlib

matplotlib.use("Agg")  # headless backend: render figures without a display server
import matplotlib.pyplot as plt

from scripts.camera.cam_dataset import Cam_Generator
from scripts.camera.visualization.visualize_batch import make_perspective_figures

NUM = r"[+-]?(?:\d+(?:\.\d+)?|\.\d+)(?:[eE][+-]?\d+)?"
CAM_PATTERN = re.compile(
    r"(?:camera parameters.*?:|roll.*?:)\s*(" + NUM + r")\s*,\s*(" + NUM + r")\s*,\s*(" + NUM + r")",
    re.IGNORECASE | re.DOTALL,
)
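# Usage sketch for the pattern above (it is not exercised elsewhere in this demo;
# the values are illustrative):
#   m = CAM_PATTERN.search("The camera parameters (roll, pitch, and field-of-view) are: 0.1, -0.1, 1.5.")
#   if m:
#       roll, pitch, fov = (float(g) for g in m.groups())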


def center_crop(image):
    """Crop the largest centered square from a PIL image."""
    w, h = image.size
    s = min(w, h)
    left = (w - s) // 2
    top = (h - s) // 2
    return image.crop((left, top, left + s, top + s))
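# For example, a 640x480 input yields the central 480x480 region.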


# Build the Puffin pipeline from its config and load the released checkpoint.
config = Config.fromfile("configs/pipelines/stage_2_base.py")
model = BUILDER.build(config.model).eval()
checkpoint_path = "checkpoints/Puffin-Base.pth"
state_dict = guess_load_checkpoint(checkpoint_path)
model.load_state_dict(state_dict, strict=False)

# bfloat16 on GPU keeps memory in check; fall back to float32 on CPU.
if torch.cuda.is_available():
    model = model.to(torch.bfloat16).cuda()
else:
    model = model.to(torch.float32)


@torch.inference_mode()
@spaces.GPU(duration=120)
def multimodal_understanding(image_src, question, seed, top_p, temperature,
                             progress=gr.Progress(track_tqdm=True)):
    # `question`, `top_p`, and `temperature` come from the UI but are not forwarded:
    # this demo always runs the fixed camera-reasoning prompt below.
    torch.cuda.empty_cache()
    torch.manual_seed(seed)
    np.random.seed(seed)

    prompt = ("Describe the image in detail. Then reason its spatial distribution and "
              "estimate its camera parameters (roll, pitch, and field-of-view).")

    # Preprocess: center-crop to a square, resize to 512x512, scale to [-1, 1], CHW.
    image = Image.fromarray(image_src).convert('RGB')
    image = center_crop(image)
    image = image.resize((512, 512))
    x = torch.from_numpy(np.array(image)).float()
    x = x / 255.0
    x = 2 * x - 1
    x = rearrange(x, 'h w c -> c h w')

    outputs = model.understand(prompt=[prompt], pixel_values=[x], progress_bar=False)
    text = outputs[0]

    # Recover the predicted camera maps from the response text:
    # channels 0-1 hold the up field, channels 2-3 the latitude field.
    gen = Cam_Generator(mode="base")
    cam = gen.get_cam(text)

    rgb = np.array(image).astype(np.float32) / 255.0
    image_tensor = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0)
    single_batch = {
        "image": image_tensor,
        "up_field": cam[:2].unsqueeze(0),
        "latitude_field": cam[2:].unsqueeze(0),
    }

    # Save the camera-map visualizations to disk. The directory and file stem are
    # demo defaults; the UI itself only returns the text response.
    save_dir = "outputs"
    stem = "understanding"
    os.makedirs(save_dir, exist_ok=True)
    figs = make_perspective_figures(single_batch, single_batch, n_pairs=1)
    for k, fig in figs.items():
        if "up_field" in k:
            suffix = "_up"
        elif "latitude_field" in k:
            suffix = "_lat"
        else:
            suffix = f"_{k}"
        out_path = os.path.join(save_dir, f"{stem}_camera_map_vis{suffix}.png")
        fig.tight_layout()
        fig.savefig(out_path, dpi=200, bbox_inches='tight', pad_inches=0)
        plt.close(fig)

    return text
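# Standalone usage sketch (outside the UI; inputs are hypothetical):
#   response = multimodal_understanding(np.array(Image.open("view.jpg")),
#                                       "Describe the image.", seed=42,
#                                       top_p=0.95, temperature=0.1)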


@torch.inference_mode()
@spaces.GPU(duration=120)
def generate_image(prompt_scene,
                   seed=42,
                   roll=0.1,
                   pitch=-0.1,
                   fov=1.5,
                   progress=gr.Progress(track_tqdm=True)):
    torch.cuda.empty_cache()
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)

    generator = torch.Generator().manual_seed(seed)

    # Encode the requested camera both as text and as a pixel-wise camera map.
    prompt_camera = (
        "The camera parameters (roll, pitch, and field-of-view) are: "
        f"{roll:.4f}, {pitch:.4f}, {fov:.4f}."
    )
    gen = Cam_Generator()
    cam_map = gen.get_cam(prompt_camera).to(model.device)
    cam_map = cam_map / (math.pi / 2)  # normalize radian-valued fields to [-1, 1]

    prompt = prompt_scene + " " + prompt_camera

    bsz = 4  # generate four candidates per prompt
    images, output_reasoning = model.generate(
        prompt=[prompt] * bsz,
        cfg_prompt=[""] * bsz,
        pixel_values_init=None,
        cfg_scale=4.5,
        num_steps=50,
        cam_values=[[cam_map]] * bsz,
        progress_bar=False,
        reasoning=False,
        prompt_reasoning=[""] * bsz,
        generator=generator,
        height=512,
        width=512,
    )

    # Map outputs from [-1, 1] floats to uint8 HWC arrays, then to PIL images.
    images = rearrange(images, 'b c h w -> b h w c')
    images = torch.clamp(127.5 * images + 128.0, 0, 255).to("cpu", dtype=torch.uint8).numpy()
    return [Image.fromarray(image) for image in images]
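# Standalone usage sketch (outside the UI), assuming the checkpoint above is in place:
#   imgs = generate_image("A quiet street at dusk.", seed=7, roll=0.0, pitch=-0.2, fov=1.2)
#   imgs[0].save("sample.png")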


css = '''
.gradio-container {max-width: 960px !important}
'''

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Puffin")

    with gr.Tab("Camera-controllable Image Generation"):
        gr.Markdown(value="## Camera-controllable Image Generation")

        prompt_input = gr.Textbox(label="Prompt")

        with gr.Accordion("Camera Parameters", open=True):
            with gr.Row():
                roll = gr.Slider(minimum=-0.7854, maximum=0.7854, value=0.1, step=0.1, label="roll value")
                pitch = gr.Slider(minimum=-0.7854, maximum=0.7854, value=-0.1, step=0.1, label="pitch value")
                fov = gr.Slider(minimum=0.3491, maximum=1.8326, value=1.5, step=0.1, label="fov value")
            seed_input = gr.Number(label="Seed (Optional)", precision=0, value=1234)
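            # The slider bounds above correspond to radians: roll/pitch span
            # ±0.7854 (about ±45°), field-of-view spans 0.3491 to 1.8326 (about 20° to 105°).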

        generation_button = gr.Button("Generate Images")

        image_output = gr.Gallery(label="Generated Images", columns=4, rows=1)

        examples_t2i = gr.Examples(
            label="Prompt examples",
            examples=[
                "A sunny day casts light on two warmly colored buildings—yellow with green accents and deeper orange—framed by a lush green tree, with a blue sign and street lamp adding details in the foreground.",
                "A high-vantage-point view of lush, autumn-colored mountains blanketed in green and gold, set against a clear blue sky with scattered white clouds, offering a tranquil and breathtaking vista of a serene valley below.",
                "A grand, historic castle with pointed spires and elaborate stone structures stands against a clear blue sky, flanked by a circular fountain, vibrant red flowers, and neatly trimmed hedges in a beautifully landscaped garden.",
                "A serene aerial view of a coastal landscape at sunrise/sunset, featuring warm pink and orange skies transitioning to cool blues, with calm waters stretching to rugged, snow-capped mountains in the background, creating a tranquil and picturesque scene.",
                "A worn, light-yellow-walled room with herringbone terracotta floors and three large arched windows framed in pink trim and white panes, showcasing signs of age and disrepair, overlooks a residential area through glimpses of greenery and neighboring buildings.",
            ],
            inputs=prompt_input,
        )

    with gr.Tab("Multimodal Understanding"):
        gr.Markdown(value="## Multimodal Understanding")
        image_input = gr.Image()
        with gr.Column():
            question_input = gr.Textbox(label="Question")

        understanding_button = gr.Button("Chat")
        understanding_output = gr.Textbox(label="Response")

        with gr.Accordion("Advanced options", open=False):
            und_seed_input = gr.Number(label="Seed", precision=0, value=42)
            top_p = gr.Slider(minimum=0, maximum=1, value=0.95, step=0.05, label="top_p")
            temperature = gr.Slider(minimum=0, maximum=1, value=0.1, step=0.05, label="temperature")

        examples_understanding = gr.Examples(
            label="Multimodal Understanding examples",
            examples=[
                ["Is the picture taken in winter?", "view.jpg"],
                ["Briefly describe the image.", "view.jpg"],
            ],
            inputs=[question_input, image_input],
        )

    generation_button.click(
        fn=generate_image,
        inputs=[prompt_input, seed_input, roll, pitch, fov],
        outputs=image_output,
    )

    understanding_button.click(
        fn=multimodal_understanding,
        inputs=[image_input, question_input, und_seed_input, top_p, temperature],
        outputs=understanding_output,
    )

demo.launch(share=True)