# CoVT / app.py
import time

# os.environ["GRADIO_TEMP_DIR"] = (
#     "/home/agent_vision@BEIJAFLORE.COM/fmorel/CoVT-main/CoVT-main/gradio/temp"
# )
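# Note: if you uncomment the override above, you also need `import os`; the path is
# machine-specific, so point it at a writable temp directory on your own machine.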
import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, Qwen2_5_VLForConditionalGeneration

# ================= Configuration Area =================
# You can change these defaults as you like
DEFAULT_MODEL_NAME = "Wakals/CoVT-7B-seg_depth_dino"
DEFAULT_CKPT_PATH = None  # Or set to your local checkpoint path
# ======================================================

# Global cache for the model and processor to avoid re-loading on every call
_cached_model = None
_cached_processor = None

def load_model_and_processor(
    model_name: str,
    ckpt: str = None,
):
    """
    Load a single CoVT-7B model and its corresponding processor.

    If `ckpt` is given, it takes precedence over `model_name`.
    """
    if ckpt is not None:
        print(f"Loading model from ckpt: {ckpt}")
        source = ckpt
    else:
        print(f"Loading model from hub: {model_name}")
        source = model_name
    # device_map="auto" places the model across available devices (requires accelerate)
    model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
        source, torch_dtype=torch.bfloat16, device_map="auto"
    ).eval()
    processor = AutoProcessor.from_pretrained(
        source, min_pixels=256 * 28 * 28, max_pixels=1280 * 28 * 28
    )
    return model, processor

def get_cached_model_and_processor(
    model_name: str = DEFAULT_MODEL_NAME,
    ckpt: str = DEFAULT_CKPT_PATH,
):
    """
    Lazy-load and cache the model and processor so they are not reloaded on every request.
    """
    global _cached_model, _cached_processor
    # If already loaded, just return them
    if _cached_model is not None and _cached_processor is not None:
        return _cached_model, _cached_processor
    # Otherwise load and cache
    _cached_model, _cached_processor = load_model_and_processor(
        model_name=model_name,
        ckpt=ckpt,
    )
    return _cached_model, _cached_processor

def run_single_inference(
    model,
    processor,
    image,  # can be either a PIL.Image or a path string
    question: str,
    max_new_tokens: int = 512,
    temperature: float = 0.0,
    top_p: float = 0.9,
    do_sample: bool = False,
    seed: int = 42,
):
    """
    Single inference: given one image and one question, return the answer and elapsed time.
    """
    # 1) Prepare the conversation
    # Gradio usually hands us a PIL image, but a path string is also supported for compatibility.
    if isinstance(image, str):
        pil_image = Image.open(image).convert("RGB")
        image_ref = image  # path for the "image" field
    elif isinstance(image, Image.Image):
        pil_image = image.convert("RGB")
        # With a PIL image, the chat template only needs a placeholder here;
        # the actual pixels are passed via the processor's `images` argument below.
        image_ref = "gradio_image"  # not a real path, just a placeholder
    else:
        raise ValueError("image must be a PIL.Image or a path string.")
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image_ref},
                {"type": "text", "text": question},
            ],
        }
    ]
    # 2) Apply the chat template
    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    # 3) Encode image and text
    inputs = processor(text=[prompt], images=[pil_image], return_tensors="pt")
    # Move inputs to the same device as the model
    device = model.device
    inputs = {
        k: (v.to(device) if isinstance(v, torch.Tensor) else v)
        for k, v in inputs.items()
    }
    print(f"Running generation on device: {device.type}")
    # 4) Timing + generation
    if device.type == "cuda":
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    # Seed the RNG so sampling runs are reproducible (the seed has no effect
    # when do_sample is False, i.e. greedy decoding).
    torch.manual_seed(seed)
    start = time.time()
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=do_sample,
            pad_token_id=processor.tokenizer.eos_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
        )
    if device.type == "cuda":
        torch.cuda.synchronize()
    end = time.time()
    elapsed = end - start
    # 5) Decode only the newly generated tokens
    input_len = inputs["input_ids"].shape[1]
    new_tokens = generated_ids[0, input_len:]
    answer = processor.decode(new_tokens, skip_special_tokens=True)
    return answer, elapsed

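# A minimal sketch of calling the helper outside Gradio (the image path and
# question below are made up for illustration):
#
#     model, processor = get_cached_model_and_processor()
#     answer, secs = run_single_inference(
#         model, processor, "example.jpg", "What is in this image?"
#     )
#     print(f"{answer} ({secs:.2f}s)")
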
def gradio_inference(
    image,
    question,
    max_new_tokens,
    temperature,
    top_p,
    seed,
):
    """
    Wrapper for Gradio that calls the inference logic and returns the answer + time cost.
    """
    if image is None:
        return "Please upload an image.", 0.0
    # Get (or load) the model and processor
    model, processor = get_cached_model_and_processor()
    # Run inference
    answer, elapsed = run_single_inference(
        model=model,
        processor=processor,
        image=image,  # PIL image from the gr.Image component (type="pil")
        question=question,
        max_new_tokens=int(max_new_tokens),
        temperature=float(temperature),
        top_p=float(top_p),
        do_sample=(temperature > 0.0),  # sample only when temperature > 0
        seed=int(seed),
    )
    return answer, elapsed

# ===================== Gradio UI =====================
def build_demo():
    with gr.Blocks() as demo:
        gr.Markdown(
            "# CoVT-7B Gradio Demo\n"
            "Upload an image and enter a question to run visual question answering."
        )
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image", type="pil")
                question_input = gr.Textbox(label="Question", value="", lines=2)
                max_new_tokens = gr.Slider(
                    label="max_new_tokens", minimum=1, maximum=1024, value=512, step=1
                )
                temperature = gr.Slider(
                    label="temperature", minimum=0.0, maximum=1.0, value=0.0, step=0.01
                )
                top_p = gr.Slider(
                    label="top_p", minimum=0.1, maximum=1.0, value=0.9, step=0.01
                )
                seed = gr.Slider(
                    label="random_seed", minimum=0, maximum=1000, value=42, step=1
                )
                run_button = gr.Button("Run Inference")
            with gr.Column():
                answer_output = gr.Textbox(label="Answer", lines=10)
                elapsed_output = gr.Number(label="Elapsed time (seconds)")
        run_button.click(
            fn=gradio_inference,
            inputs=[
                image_input,
                question_input,
                max_new_tokens,
                temperature,
                top_p,
                seed,
            ],
            outputs=[answer_output, elapsed_output],
        )
    return demo

if __name__ == "__main__":
    demo = build_demo()
    # You can set share=True if you want a public link
    demo.launch()
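    # A few optional launch settings (not used in the call above, just a sketch of
    # common alternatives): bind to all interfaces and a fixed port, e.g.
    #     demo.launch(server_name="0.0.0.0", server_port=7860)
    # You can also call get_cached_model_and_processor() once before launching to
    # pre-load the model so the first request does not pay the loading cost.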