# app.py — Hugging Face Space "test" (commit 6e89446, verified)
# Author: aqibmumtazbits
import torch
import gradio as gr
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"

# Resolve the execution device and dtype once at import time: bfloat16 on
# GPU (memory-friendly, supported by the model), float32 on CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.bfloat16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32

# Instruction sent with every image unless the user edits the prompt box.
DEFAULT_PROMPT = (
    "Do you see any abnormality in the chest? Write briefly. "
    "If yes, also tell where the abnormality is in which part of the chest. "
    "The chest parts include lungs, heart and vessels, spine, diaphragm, "
    "soft tissues, Mediastinum and bones of chest shown in image. "
    "Respond only in English. Do NOT use any other language. "
    "**Do not use Chinese language.**"
)
# ---------------------------------------------------------------------------
# Load model & processor
# ---------------------------------------------------------------------------
print(f"Loading model: {MODEL_ID}")
# device_map="auto" lets accelerate place/shard the weights across available
# GPUs; on CPU we pass None and move the model explicitly below instead.
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    device_map="auto" if torch.cuda.is_available() else None,
)
if not torch.cuda.is_available():
    model = model.to(DEVICE)
# Skip video_processor attribute to avoid torchvision dependency:
# temporarily remove "video_processor" from the processor class's attribute
# list so from_pretrained() does not construct a video processor (which would
# import torchvision), then restore the original list so the class is left
# unmodified for any other user.
_orig_attrs = Qwen2_5_VLProcessor.attributes[:]
Qwen2_5_VLProcessor.attributes = [a for a in _orig_attrs if a != "video_processor"]
processor = Qwen2_5_VLProcessor.from_pretrained(MODEL_ID)
Qwen2_5_VLProcessor.attributes = _orig_attrs
print("Model loaded successfully.")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def pad_to_square(image: Image.Image, fill: tuple[int, int, int] = (0, 0, 0)) -> Image.Image:
    """Pad *image* to a square canvas, centering the original content.

    Args:
        image: Source image; returned unchanged if already square.
        fill: RGB background color for the padded borders (default black,
            matching typical radiograph backgrounds).

    Returns:
        A new RGB square image of side ``max(width, height)``, or the
        original image object when no padding is needed.
    """
    width, height = image.size
    if width == height:
        return image
    side = max(width, height)
    canvas = Image.new("RGB", (side, side), fill)
    # Center the original; integer division may leave the extra pixel on the
    # right/bottom edge for odd differences.
    canvas.paste(image, ((side - width) // 2, (side - height) // 2))
    return canvas
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def predict(image: Image.Image, prompt: str, max_new_tokens: int, temperature: float) -> str:
    """Run the VLM on a chest X-ray and return the generated report text.

    Args:
        image: Uploaded X-ray (any mode; converted to RGB and padded square).
        prompt: User instruction; falls back to DEFAULT_PROMPT when empty or
            None (Gradio delivers None for a cleared textbox).
        max_new_tokens: Generation length cap (slider value, may be float).
        temperature: Sampling temperature; 0 selects greedy decoding.

    Returns:
        The decoded model output, or a usage message when no image was given.
    """
    if image is None:
        return "Please upload a chest X-ray image."
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = pad_to_square(image)
    # Guard against None as well as whitespace-only prompts.
    if not prompt or not prompt.strip():
        prompt = DEFAULT_PROMPT
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[text], images=[image], return_tensors="pt", padding=True
    ).to(model.device)
    # Only pass sampling flags when actually sampling; supplying a
    # temperature with do_sample=False triggers transformers warnings.
    gen_kwargs = {"max_new_tokens": int(max_new_tokens)}
    if temperature > 0:
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = temperature
    else:
        gen_kwargs["do_sample"] = False
    with torch.inference_mode():
        generated_ids = model.generate(**inputs, **gen_kwargs)
    # Strip the prompt tokens so only newly generated text is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    return processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(
    title="Chest X-Ray Analysis — Qwen2.5-VL-3B",
    theme=gr.themes.Soft(),
) as demo:
    # Header banner.
    gr.Markdown(
        "# Chest X-Ray Analysis\n"
        "Upload a chest X-ray and get an automated report "
        "powered by **Qwen2.5-VL-3B-Instruct**."
    )
    with gr.Row():
        # Left column: inputs and generation controls.
        with gr.Column(scale=1):
            xray_in = gr.Image(type="pil", label="Upload Chest X-Ray")
            prompt_box = gr.Textbox(label="Prompt", value=DEFAULT_PROMPT, lines=4)
            with gr.Row():
                tokens_slider = gr.Slider(
                    minimum=64,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Max New Tokens",
                )
                temp_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.5,
                    value=0.3,
                    step=0.05,
                    label="Temperature (0 = greedy)",
                )
            analyze_btn = gr.Button("Analyze", variant="primary")
        # Right column: generated report.
        with gr.Column(scale=1):
            report_box = gr.Textbox(label="Model Report", lines=20)
    analyze_btn.click(
        predict,
        inputs=[xray_in, prompt_box, tokens_slider, temp_slider],
        outputs=report_box,
    )
    # Footer disclaimer.
    gr.Markdown(
        "---\n"
        "*Research purposes only — not a substitute for professional medical diagnosis.*"
    )

if __name__ == "__main__":
    demo.launch()