# Chest X-ray analysis demo — Gradio app powered by Qwen2.5-VL-3B-Instruct.
import torch
import gradio as gr
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Hugging Face Hub identifier of the vision-language model to load.
MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
# Prefer GPU when available; otherwise run on CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 halves memory on GPU; CPU inference stays in full float32.
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# Default instruction sent with every image when the user leaves the prompt
# blank. The explicit English-only directives steer the multilingual model
# away from answering in Chinese.
DEFAULT_PROMPT = (
    "Do you see any abnormality in the chest? Write briefly. "
    "If yes, also tell where the abnormality is in which part of the chest. "
    "The chest parts include lungs, heart and vessels, spine, diaphragm, "
    "soft tissues, Mediastinum and bones of chest shown in image. "
    "Respond only in English. Do NOT use any other language. "
    "**Do not use Chinese language.**"
)
# ---------------------------------------------------------------------------
# Load model & processor
# ---------------------------------------------------------------------------
print(f"Loading model: {MODEL_ID}")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    # device_map="auto" lets accelerate place/shard weights on GPU(s);
    # on CPU we pass None and move the model explicitly below.
    device_map="auto" if torch.cuda.is_available() else None,
)
if not torch.cuda.is_available():
    model = model.to(DEVICE)
# Skip video_processor attribute to avoid torchvision dependency:
# temporarily remove "video_processor" from the processor class's attribute
# list so from_pretrained() does not try to construct it, then restore the
# original list so other users of the class are unaffected.
_orig_attrs = Qwen2_5_VLProcessor.attributes[:]
Qwen2_5_VLProcessor.attributes = [a for a in _orig_attrs if a != "video_processor"]
processor = Qwen2_5_VLProcessor.from_pretrained(MODEL_ID)
Qwen2_5_VLProcessor.attributes = _orig_attrs
print("Model loaded successfully.")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def pad_to_square(image: Image.Image) -> Image.Image:
    """Return *image* centered on a black square canvas.

    Square inputs are returned unchanged. Otherwise the image is pasted
    onto a black RGB square whose side equals the longer dimension, with
    equal padding on both sides of the shorter dimension.
    """
    width, height = image.size
    if width == height:
        return image
    side = max(width, height)
    canvas = Image.new("RGB", (side, side), (0, 0, 0))
    offset = ((side - width) // 2, (side - height) // 2)
    canvas.paste(image, offset)
    return canvas
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def predict(image: Image.Image, prompt: str, max_new_tokens: int, temperature: float) -> str:
    """Generate a chest X-ray report for *image* with the loaded Qwen2.5-VL model.

    Args:
        image: Uploaded X-ray; converted to RGB and padded to a square
            before being handed to the processor. May be ``None`` when the
            user clicks Analyze without uploading.
        prompt: Instruction text; blank or ``None`` falls back to
            ``DEFAULT_PROMPT``.
        max_new_tokens: Upper bound on the number of generated tokens.
        temperature: Sampling temperature; values <= 0 select greedy decoding.

    Returns:
        The decoded model response, or a short hint string when no image
        was provided.
    """
    if image is None:
        return "Please upload a chest X-ray image."
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = pad_to_square(image)
    # Gradio can deliver None (cleared textbox) as well as whitespace-only
    # text; both fall back to the default instruction.
    if not prompt or not prompt.strip():
        prompt = DEFAULT_PROMPT
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    # Render the chat template to plain text; the image itself is supplied
    # separately to the processor below.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[text], images=[image], return_tensors="pt", padding=True
    ).to(model.device)
    sample = temperature > 0
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=sample,
            # temperature is ignored for greedy decoding; pass 1.0 there to
            # avoid a transformers warning about temperature with do_sample=False.
            temperature=temperature if sample else 1.0,
        )
    # Strip the prompt tokens so only the newly generated answer is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    return processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Two-column layout: inputs (image, prompt, generation controls) on the left,
# the generated report on the right.
with gr.Blocks(
    title="Chest X-Ray Analysis — Qwen2.5-VL-3B",
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown(
        "# Chest X-Ray Analysis\n"
        "Upload a chest X-ray and get an automated report "
        "powered by **Qwen2.5-VL-3B-Instruct**."
    )
    with gr.Row():
        with gr.Column(scale=1):
            # type="pil" delivers a PIL.Image directly to predict().
            image_input = gr.Image(type="pil", label="Upload Chest X-Ray")
            prompt_input = gr.Textbox(
                label="Prompt",
                value=DEFAULT_PROMPT,
                lines=4,
            )
            with gr.Row():
                max_tokens_slider = gr.Slider(
                    minimum=64, maximum=1024, value=512, step=64,
                    label="Max New Tokens",
                )
                temperature_slider = gr.Slider(
                    minimum=0.0, maximum=1.5, value=0.3, step=0.05,
                    label="Temperature (0 = greedy)",
                )
            submit_btn = gr.Button("Analyze", variant="primary")
        with gr.Column(scale=1):
            output_text = gr.Textbox(label="Model Report", lines=20)
    # Wire the button to the inference function; slider values map onto
    # predict()'s max_new_tokens and temperature parameters.
    submit_btn.click(
        predict,
        inputs=[image_input, prompt_input, max_tokens_slider, temperature_slider],
        outputs=output_text,
    )
    gr.Markdown(
        "---\n"
        "*Research purposes only — not a substitute for professional medical diagnosis.*"
    )
if __name__ == "__main__":
    demo.launch()