Spaces:
Sleeping
Sleeping
Florian committed on
Commit ·
0c57df9
1
Parent(s): f04ff47
first commit
Browse files
app.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os
import time

import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

# GRADIO_TEMP_DIR must be set BEFORE gradio is imported so that uploaded
# files are staged in a writable location — hence the unusual import order.
# NOTE(review): hard-coded absolute path — confirm it exists on the deploy host.
os.environ["GRADIO_TEMP_DIR"] = (
    "/home/agent_vision@BEIJAFLORE.COM/fmorel/CoVT-main/CoVT-main/gradio/temp"
)
import gradio as gr

# ================= Configuration Area =================
# You can change these defaults as you like
DEFAULT_MODEL_NAME = "Wakals/CoVT-7B-seg_depth_dino"
DEFAULT_CKPT_PATH = None  # Or set to your local checkpoint path
# ======================================================

# Global cache for model and processor to avoid re-loading every call.
# Populated lazily by get_cached_model_and_processor().
_cached_model = None
_cached_processor = None
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def load_model_and_processor(
    model_name: str,
    ckpt: str = None,
):
    """
    Load a single CoVT-7B model and its corresponding processor.

    Args:
        model_name: Hub repository id used when no local checkpoint is given.
        ckpt: Optional local checkpoint path; takes precedence over model_name.

    Returns:
        (model, processor): the model in eval mode, sharded across available
        devices via ``device_map="auto"``; the processor with the matching
        pixel bounds.
    """
    # The two original branches were identical except for the load source,
    # so resolve the source once and load through a single path.
    if ckpt is not None:
        print(f"Loading model from ckpt: {ckpt}")
        source = ckpt
    else:
        print(f"Loading model from hub: {model_name}")
        source = model_name

    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        source, torch_dtype=torch.bfloat16, device_map="auto"
    ).eval()
    # Pixel bounds are expressed in multiples of the 28x28 vision patch size,
    # following the Qwen2.5-VL preprocessing convention.
    processor = AutoProcessor.from_pretrained(
        source, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28
    )

    return model, processor
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_cached_model_and_processor(
    model_name: str = DEFAULT_MODEL_NAME,
    ckpt: str = DEFAULT_CKPT_PATH,
):
    """
    Return the process-wide (model, processor) pair, loading it on first use.

    NOTE: after the first load the arguments are ignored — whatever was
    loaded first is returned, even if a different model_name/ckpt is passed.
    """
    global _cached_model, _cached_processor

    # Hit disk/hub only when nothing is cached yet.
    if _cached_model is None or _cached_processor is None:
        _cached_model, _cached_processor = load_model_and_processor(
            model_name=model_name,
            ckpt=ckpt,
        )

    return _cached_model, _cached_processor
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def run_single_inference(
    model,
    processor,
    image,  # can be either a PIL.Image or a path string
    question: str,
    max_new_tokens: int = 512,
    temperature: float = 0.0,
    top_p: float = 0.9,
    do_sample: bool = False,
    seed: int = 42,
):
    """
    Single inference: given one image and one question, return answer and elapsed time.

    Args:
        model: A loaded Qwen2.5-VL-style generation model.
        processor: The matching AutoProcessor (chat template + image encoding).
        image: PIL.Image or a filesystem path string.
        question: User question to ask about the image.
        max_new_tokens: Generation length cap.
        temperature, top_p, do_sample: Sampling controls forwarded to generate().
        seed: Seed applied to torch RNGs and the per-call generator.

    Returns:
        (answer, elapsed): decoded answer text and generation wall time in seconds.

    Raises:
        ValueError: if ``image`` is neither a PIL.Image nor a path string.
    """
    # 1) Prepare conversation
    # For Gradio we usually get a PIL image, but we also support a path string for compatibility.
    if isinstance(image, str):
        pil_image = Image.open(image).convert("RGB")
        image_ref = image  # path for the "image" field
    elif isinstance(image, Image.Image):
        pil_image = image.convert("RGB")
        # When using PIL image in chat template, you can pass a placeholder
        # and rely on 'images' argument in processor; here we still need a "dummy" reference.
        image_ref = (
            "gradio_image"  # this is not used as a real path, just a placeholder
        )
    else:
        raise ValueError("image must be a PIL.Image or a path string.")

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_ref},
                {"type": "text", "text": question},
            ],
        }
    ]

    # 2) Apply chat template (text only; the real pixels go through `images=` below)
    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    # 3) Encode image and text
    inputs = processor(text=[prompt], images=[pil_image], return_tensors="pt")

    # Move inputs to the same device as the model
    device = model.device
    inputs = {
        k: (v.to(device) if isinstance(v, torch.Tensor) else v)
        for k, v in inputs.items()
    }

    # 3.5) Set random seed and generator for reproducibility when sampling
    seed = int(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    try:
        # Some torch builds reject a device kwarg here, hence the fallback.
        generator = torch.Generator(device=device)
    except TypeError:
        generator = torch.Generator()
    generator.manual_seed(seed)

    # 4) Timing + generation
    if device.type == "cuda":
        torch.cuda.empty_cache()
        # Synchronize so time.time() below measures only the generate() call.
        torch.cuda.synchronize()

    start = time.time()
    with torch.no_grad():
        # NOTE(review): `generator=` is not a documented GenerationMixin.generate
        # kwarg — confirm the installed transformers version accepts it rather
        # than forwarding it to the model forward pass.
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=do_sample,
            generator=generator,
            pad_token_id=processor.tokenizer.eos_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
        )
    if device.type == "cuda":
        torch.cuda.synchronize()
    end = time.time()

    elapsed = end - start

    # 5) Decode only newly generated tokens (generate() returns prompt + answer)
    input_len = inputs["input_ids"].shape[1]
    new_tokens = generated_ids[0, input_len:]
    answer = processor.decode(new_tokens, skip_special_tokens=True)

    return answer, elapsed
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def gradio_inference(
    image,
    question,
    max_new_tokens,
    temperature,
    top_p,
    seed,
):
    """
    Gradio callback: run VQA on one image/question pair.

    Returns the model's answer text and the generation wall time in seconds.
    """
    # Guard clause: nothing to do without an image.
    if image is None:
        return "Please upload an image.", 0.0

    # Model/processor are cached across requests, so this is cheap after warmup.
    model, processor = get_cached_model_and_processor()

    # Sampling is enabled only when a non-zero temperature is requested.
    sample = temperature > 0.0

    answer, elapsed = run_single_inference(
        model=model,
        processor=processor,
        image=image,  # PIL image as delivered by gr.Image(type="pil")
        question=question,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        do_sample=sample,
        seed=int(seed),
    )

    return answer, elapsed
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
# ===================== Gradio UI =====================
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
def build_demo():
    """
    Build the Gradio Blocks UI.

    Left column: image + question inputs, generation controls, and an example.
    Right column: answer text and elapsed-time outputs.

    Returns:
        The constructed gr.Blocks demo (not yet launched).
    """
    with gr.Blocks() as demo:
        gr.Markdown(
            "# CoVT-7B Gradio Demo\n"
            "Upload an image and input a question to run visual question answering."
        )

        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image", type="pil")
                question_input = gr.Textbox(label="Question", value="", lines=2)
                max_new_tokens = gr.Slider(
                    label="max_new_tokens", minimum=1, maximum=1024, value=512, step=1
                )
                temperature = gr.Slider(
                    label="temperature", minimum=0.0, maximum=1.0, value=0.0, step=0.01
                )
                top_p = gr.Slider(
                    label="top_p", minimum=0.1, maximum=1.0, value=0.9, step=0.01
                )
                seed = gr.Slider(
                    label="random_seed", minimum=0, maximum=1000, value=42, step=1
                )

                gr.Markdown("### Example")
                example_image_path = os.path.abspath(
                    os.path.join(
                        os.path.dirname(__file__), "..", "assets", "clouds.png"
                    )
                )
                # Robustness fix: the original crashed at startup when the
                # bundled example asset was missing; skip the example instead.
                if os.path.isfile(example_image_path):
                    example_image = Image.open(example_image_path).convert("RGB")
                    gr.Examples(
                        examples=[
                            [
                                example_image,
                                "Describe the scene in the picture in detail, and find out how many clouds are in the sky. Use segmentation, depth map, and perception feature information of the image to answer this question.",
                            ]
                        ],
                        inputs=[image_input, question_input],
                        examples_per_page=1,
                    )
                else:
                    print(f"Example image not found, skipping example: {example_image_path}")
                # -----------------------------------------

                run_button = gr.Button("Run Inference")

            with gr.Column():
                answer_output = gr.Textbox(label="Answer", lines=10)
                elapsed_output = gr.Number(label="Elapsed time (seconds)")

        run_button.click(
            fn=gradio_inference,
            inputs=[
                image_input,
                question_input,
                max_new_tokens,
                temperature,
                top_p,
                seed,
            ],
            outputs=[answer_output, elapsed_output],
        )

    return demo
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
if __name__ == "__main__":
    # Build the UI and start the local server.
    # Pass share=True to launch() if you want a public link.
    demo = build_demo()
    demo.launch()
|