"""Gradio demo: chest X-ray analysis with Qwen2.5-VL-3B-Instruct.

Loads the vision-language model once at import time, then serves a simple
upload-and-describe UI.  Research use only — not a diagnostic tool.
"""

import torch
import gradio as gr
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 only on GPU; CPU inference stays in float32 for compatibility.
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32

DEFAULT_PROMPT = (
    "Do you see any abnormality in the chest? Write briefly. "
    "If yes, also tell where the abnormality is in which part of the chest. "
    "The chest parts include lungs, heart and vessels, spine, diaphragm, "
    "soft tissues, Mediastinum and bones of chest shown in image. "
    "Respond only in English. Do NOT use any other language. "
    "**Do not use Chinese language.**"
)

# ---------------------------------------------------------------------------
# Load model & processor
# ---------------------------------------------------------------------------
print(f"Loading model: {MODEL_ID}")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    # device_map="auto" shards onto available GPUs; on CPU we place manually.
    device_map="auto" if torch.cuda.is_available() else None,
)
if not torch.cuda.is_available():
    model = model.to(DEVICE)

# Temporarily remove "video_processor" from the processor's attribute list so
# that from_pretrained() does not require torchvision (we never process
# video).  Restore the class attribute in a finally-block so a failed load
# cannot leave the class permanently mutated for the rest of the process.
_orig_attrs = Qwen2_5_VLProcessor.attributes[:]
Qwen2_5_VLProcessor.attributes = [a for a in _orig_attrs if a != "video_processor"]
try:
    processor = Qwen2_5_VLProcessor.from_pretrained(MODEL_ID)
finally:
    Qwen2_5_VLProcessor.attributes = _orig_attrs
print("Model loaded successfully.")


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def pad_to_square(image: Image.Image) -> Image.Image:
    """Return *image* centered on a black square canvas.

    Padding (rather than stretching) preserves the aspect ratio of the
    radiograph before it is resized by the processor.  Already-square
    images are returned unchanged.
    """
    width, height = image.size
    if width == height:
        return image
    max_dim = max(width, height)
    new_image = Image.new("RGB", (max_dim, max_dim), (0, 0, 0))
    new_image.paste(image, ((max_dim - width) // 2, (max_dim - height) // 2))
    return new_image


# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def predict(image: Image.Image, prompt: str, max_new_tokens: int, temperature: float):
    """Run one round of VQA on *image* and return the decoded answer text.

    Args:
        image: uploaded X-ray (any mode; converted to RGB and square-padded).
        prompt: user instruction; falls back to DEFAULT_PROMPT when blank.
        max_new_tokens: generation length cap (cast to int for the slider).
        temperature: 0 means greedy decoding; >0 enables sampling.
    """
    if image is None:
        return "Please upload a chest X-ray image."
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = pad_to_square(image)
    if not prompt.strip():
        prompt = DEFAULT_PROMPT

    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[text], images=[image], return_tensors="pt", padding=True
    ).to(model.device)

    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=temperature > 0,
            # generate() rejects temperature<=0, so pass a neutral 1.0 when
            # sampling is disabled (greedy ignores it anyway).
            temperature=temperature if temperature > 0 else 1.0,
        )

    # Strip the prompt tokens so only the newly generated answer is decoded.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    return processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]


# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(
    title="Chest X-Ray Analysis — Qwen2.5-VL-3B",
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown(
        "# Chest X-Ray Analysis\n"
        "Upload a chest X-ray and get an automated report "
        "powered by **Qwen2.5-VL-3B-Instruct**."
    )
    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload Chest X-Ray")
            prompt_input = gr.Textbox(
                label="Prompt",
                value=DEFAULT_PROMPT,
                lines=4,
            )
            with gr.Row():
                max_tokens_slider = gr.Slider(
                    minimum=64,
                    maximum=1024,
                    value=512,
                    step=64,
                    label="Max New Tokens",
                )
                temperature_slider = gr.Slider(
                    minimum=0.0,
                    maximum=1.5,
                    value=0.3,
                    step=0.05,
                    label="Temperature (0 = greedy)",
                )
            submit_btn = gr.Button("Analyze", variant="primary")
        with gr.Column(scale=1):
            output_text = gr.Textbox(label="Model Report", lines=20)

    submit_btn.click(
        predict,
        inputs=[image_input, prompt_input, max_tokens_slider, temperature_slider],
        outputs=output_text,
    )
    gr.Markdown(
        "---\n"
        "*Research purposes only — not a substitute for professional medical diagnosis.*"
    )

if __name__ == "__main__":
    demo.launch()