# Gradio Space: automated chest X-ray reporting with Qwen2.5-VL-3B-Instruct.
import torch

import gradio as gr
from PIL import Image
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Hugging Face Hub id of the vision-language model served by this app.
MODEL_ID = "Qwen/Qwen2.5-VL-3B-Instruct"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# bfloat16 halves memory on GPU; plain float32 on CPU where bf16 kernels
# may be unavailable or slow.
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# Instruction used whenever the user submits an empty prompt. The explicit
# English-only wording counters the model's tendency to answer in Chinese.
DEFAULT_PROMPT = (
    "Do you see any abnormality in the chest? Write briefly. "
    "If yes, also tell where the abnormality is in which part of the chest. "
    "The chest parts include lungs, heart and vessels, spine, diaphragm, "
    "soft tissues, Mediastinum and bones of chest shown in image. "
    "Respond only in English. Do NOT use any other language. "
    "**Do not use Chinese language.**"
)
# ---------------------------------------------------------------------------
# Load model & processor
# ---------------------------------------------------------------------------
print(f"Loading model: {MODEL_ID}")
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    # device_map="auto" lets accelerate shard/place the model on GPU(s);
    # None keeps everything on CPU so we can .to(DEVICE) manually below.
    device_map="auto" if torch.cuda.is_available() else None,
)
if not torch.cuda.is_available():
    model = model.to(DEVICE)
# Skip video_processor attribute to avoid torchvision dependency.
# NOTE: order matters here — the class attribute is patched before
# from_pretrained() and restored immediately afterwards so other code that
# instantiates the processor class is unaffected.
_orig_attrs = Qwen2_5_VLProcessor.attributes[:]
Qwen2_5_VLProcessor.attributes = [a for a in _orig_attrs if a != "video_processor"]
processor = Qwen2_5_VLProcessor.from_pretrained(MODEL_ID)
Qwen2_5_VLProcessor.attributes = _orig_attrs
print("Model loaded successfully.")
| # --------------------------------------------------------------------------- | |
| # Helpers | |
| # --------------------------------------------------------------------------- | |
def pad_to_square(image: Image.Image) -> Image.Image:
    """Center *image* on a black square canvas sized to its longer edge.

    Already-square images are returned unchanged (same object, no copy).
    """
    w, h = image.size
    if w == h:
        return image
    side = max(w, h)
    canvas = Image.new("RGB", (side, side), (0, 0, 0))
    # Center the original; integer division may leave a 1px bias for odd gaps.
    canvas.paste(image, ((side - w) // 2, (side - h) // 2))
    return canvas
| # --------------------------------------------------------------------------- | |
| # Inference | |
| # --------------------------------------------------------------------------- | |
def predict(image: Image.Image, prompt: str, max_new_tokens: int, temperature: float) -> str:
    """Run the VLM on a chest X-ray and return its generated report.

    Args:
        image: Uploaded X-ray (any PIL mode; converted to RGB and padded square).
        prompt: User instruction; empty/None falls back to DEFAULT_PROMPT.
        max_new_tokens: Generation budget (cast to int; Gradio sliders give floats).
        temperature: 0 means greedy decoding, >0 enables sampling.

    Returns:
        The decoded model output, or a usage hint when no image was given.
    """
    if image is None:
        return "Please upload a chest X-ray image."
    if image.mode != "RGB":
        image = image.convert("RGB")
    image = pad_to_square(image)
    # Guard against None as well as "": the original `prompt.strip()` would
    # raise AttributeError if the textbox value ever arrives as None.
    if not prompt or not prompt.strip():
        prompt = DEFAULT_PROMPT
    messages = [
        {
            "role": "user",
            "content": [
                {"type": "image", "image": image},
                {"type": "text", "text": prompt},
            ],
        }
    ]
    # Render the chat template to text; the image itself is passed separately
    # to the processor below.
    text = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(
        text=[text], images=[image], return_tensors="pt", padding=True
    ).to(model.device)
    # inference_mode is the modern no_grad: same outputs, disables autograd
    # tracking and view/version-counter bookkeeping entirely.
    with torch.inference_mode():
        generated_ids = model.generate(
            **inputs,
            max_new_tokens=int(max_new_tokens),
            do_sample=temperature > 0,
            # HF rejects temperature=0 when sampling; pass a harmless 1.0
            # in the greedy case where it is ignored anyway.
            temperature=temperature if temperature > 0 else 1.0,
        )
    # generate() returns prompt + completion; slice off the prompt tokens.
    generated_ids_trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated_ids)
    ]
    return processor.batch_decode(
        generated_ids_trimmed,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )[0]
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
# Two-column layout: inputs + controls on the left, generated report on the
# right. All generation parameters are wired straight into predict().
with gr.Blocks(
    title="Chest X-Ray Analysis — Qwen2.5-VL-3B",
    theme=gr.themes.Soft(),
) as demo:
    gr.Markdown(
        "# Chest X-Ray Analysis\n"
        "Upload a chest X-ray and get an automated report "
        "powered by **Qwen2.5-VL-3B-Instruct**."
    )
    with gr.Row():
        with gr.Column(scale=1):
            # type="pil" hands predict() a PIL.Image directly.
            image_input = gr.Image(type="pil", label="Upload Chest X-Ray")
            prompt_input = gr.Textbox(
                label="Prompt",
                value=DEFAULT_PROMPT,
                lines=4,
            )
            with gr.Row():
                max_tokens_slider = gr.Slider(
                    minimum=64, maximum=1024, value=512, step=64,
                    label="Max New Tokens",
                )
                temperature_slider = gr.Slider(
                    minimum=0.0, maximum=1.5, value=0.3, step=0.05,
                    label="Temperature (0 = greedy)",
                )
            submit_btn = gr.Button("Analyze", variant="primary")
        with gr.Column(scale=1):
            output_text = gr.Textbox(label="Model Report", lines=20)
    # Inputs are passed positionally, matching predict()'s signature.
    submit_btn.click(
        predict,
        inputs=[image_input, prompt_input, max_tokens_slider, temperature_slider],
        outputs=output_text,
    )
    gr.Markdown(
        "---\n"
        "*Research purposes only — not a substitute for professional medical diagnosis.*"
    )
if __name__ == "__main__":
    demo.launch()