# import os
# os.environ["HF_HOME"] = "/tmp/hf_cache"
import streamlit as st
from PIL import Image
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor
import numpy as np
import supervision as sv
import albumentations as A
import cv2
import yaml
# Set Streamlit page configuration for a wide layout
st.set_page_config(layout="wide")

# Custom CSS for better layout and mobile responsiveness
st.markdown("""
    <style>
        .main {
            max-width: 1200px;  /* Max width for content */
            margin: 0 auto;
        }
        .block-container {
            padding-top: 2rem;
            padding-bottom: 2rem;
            padding-left: 3rem;
            padding-right: 3rem;
        }
        .title {
            font-size: 3.2rem;
            text-align: center;
            background: linear-gradient(135deg, #0575e6 0%, #ff0080 50%, #7928ca 100%);
            background-size: 200% 200%;
            animation: gradientShift 8s ease infinite;
            -webkit-background-clip: text;
            -webkit-text-fill-color: transparent;
            background-clip: text;
        }
        @keyframes gradientShift {
            0% { background-position: 0% 50%; }
            50% { background-position: 100% 50%; }
            100% { background-position: 0% 50%; }
        }
        .subheader {
            font-size: 1.5rem;
            margin-bottom: 20px;
        }
        .btn {
            font-size: 1.1rem;
            padding: 10px 20px;
            background-color: #FF6347;
            color: white;
            border-radius: 5px;
            border: none;
            cursor: pointer;
        }
        .btn:hover {
            background-color: #FF4500;
        }
        .column-spacing {
            display: flex;
            justify-content: space-between;
        }
        .col-half {
            width: 48%;
        }
        .col-full {
            width: 100%;
        }
        .instructions {
            padding: 20px;
            background-color: #f9f9f9;
            border-radius: 8px;
            box-shadow: 0 2px 10px rgba(0, 0, 0, 0.1);
        }
    </style>
""", unsafe_allow_html=True)
# Load model and processor (cached so Streamlit reruns don't reload the weights)
@st.cache_resource
def load_model():
    MODEL_NAME = "Anonymous-AC/K2Sight-Lite"
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # The checkpoint builds on Florence-2; patch the vision config so the
    # DaViT backbone is resolved correctly when loading with remote code.
    config_model = AutoConfig.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True)
    config_model.vision_config.model_type = "davit"
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, trust_remote_code=True, config=config_model).to(DEVICE)
    BASE_PROCESSOR = "microsoft/Florence-2-base-ft"
    processor = AutoProcessor.from_pretrained(BASE_PROCESSOR, trust_remote_code=True)
    processor.image_processor.size = 512
    processor.image_processor.crop_size = 512
    return model, processor, DEVICE

model, processor, DEVICE = load_model()
# Load disease definitions and example prompts from YAML config files
@st.cache_data
def load_definitions():
    vindr_path = 'configs/vindr_definition.yaml'
    padchest_path = 'configs/padchest_definition.yaml'
    prompt_path = 'examples/prompt.yaml'
    with open(vindr_path, 'r') as file:
        vindr_definitions = yaml.safe_load(file)
    with open(padchest_path, 'r') as file:
        padchest_definitions = yaml.safe_load(file)
    with open(prompt_path, 'r') as file:
        prompt_definitions = yaml.safe_load(file)
    return vindr_definitions, padchest_definitions, prompt_definitions

vindr_definitions, padchest_definitions, prompt_definitions = load_definitions()
dataset_options = {"Vindr": vindr_definitions, "PadChest": padchest_definitions}

def load_example_images():
    return list(prompt_definitions.keys())

example_images = load_example_images()
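# Assumed YAML layout (inferred from how the files are used below, not from the
# files themselves): the definition files map a disease name to a textual
# definition, and prompt.yaml maps an example image path to its disease labels:
#
#   vindr_definition.yaml:  {"Aortic enlargement": "An abnormal widening of ...", ...}
#   examples/prompt.yaml:   {"examples/img_001.png": ["Cardiomegaly"], ...}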
# Letterbox an image to 512x512: scale the longest side down, pad the short
# side with black, then resize (the final Resize is a no-op for size_mode=512).
def apply_transform(image, size_mode=512):
    pad_resize_transform = A.Compose([
        A.LongestMaxSize(max_size=size_mode, interpolation=cv2.INTER_AREA),
        A.PadIfNeeded(min_height=size_mode, min_width=size_mode, border_mode=cv2.BORDER_CONSTANT, value=(0, 0, 0)),
        A.Resize(height=512, width=512, interpolation=cv2.INTER_AREA),
    ])
    image_np = np.array(image)
    transformed = pad_resize_transform(image=image_np)
    return transformed["image"]
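# For example (a sketch, assuming a hypothetical 1024x768 RGB input):
#   1024x768 -> LongestMaxSize(512) -> 512x384
#   512x384  -> PadIfNeeded(512, 512) -> 512x512 (black bars top and bottom)
# The aspect ratio is preserved, which keeps anatomy undistorted for grounding.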
# Streamlit UI: title and welcome text
st.markdown("<h1 class='title'>Knowledge to Sight: Reasoning over Visual Attributes via Knowledge Decomposition for Abnormality Grounding</h1>", unsafe_allow_html=True)
st.markdown(
    "<p style='text-align: center; font-size: 18px;'>Welcome to a simple demo of our work! Choose an example or upload your own image to get started!</p>",
    unsafe_allow_html=True
)

# Display example images first
st.subheader("Example Images")
selected_example = st.selectbox("Choose an example", example_images)
image = Image.open(selected_example).convert("RGB")
example_diseases = prompt_definitions.get(selected_example, [])
st.write("**Associated Diseases:**", ", ".join(example_diseases))
# Layout: original image on the left, instructions on the right
col1, col2 = st.columns([1, 2])

with col1:
    st.image(image, caption=f"Original Example Image: {selected_example}", width=400)

with col2:
    st.subheader("✏️ Instructions to Get Started:")
    st.write("""
    - **Choose an Example**: Select an example image from the dataset to view its associated diseases.
    - **Upload Your Own Image**: Upload an image of your choice to analyze it for abnormalities.
    - **Select Dataset**: Choose between the available datasets (Vindr or PadChest) for disease information.
    - **Select Disease**: Pick the disease to analyze from the list of diseases in the selected dataset.
    - **Run Inference**: Click the "Run Inference on Example" button to process the image and display the results.
    """)
    st.subheader("⚠️ Warning:")
    st.write("""
    - **🚫 Please avoid uploading non-frontal chest X-ray images.** Our model has been trained on **frontal chest X-rays** only.
    - This demo is intended for **🔬 research purposes only** and should **❌ not be used for medical diagnosis**.
    - The model's responses may contain **<span style='color:#dc3545; font-weight:bold;'>hallucinations or incorrect information</span>**.
    - Always consult a **<span style='color:#dc3545; font-weight:bold;'>👨‍⚕️ medical professional</span>** for accurate diagnosis and advice.
    """, unsafe_allow_html=True)
# Run inference on the selected example
if st.button("Run Inference on Example", key="example"):
    if image is None:
        st.error("❌ Please select an example image first.")
    else:
        # Use the selected example's first disease and its definition for inference
        disease_choice = example_diseases[0] if example_diseases else ""
        definition = vindr_definitions.get(disease_choice, padchest_definitions.get(disease_choice, ""))

        # Build the grounding prompt: the disease name plus its textual definition
        det_obj = f"{disease_choice} means {definition}."
        st.write(f"**Definition:** {definition}")
        prompt = f"Locate the phrases in the caption: {det_obj}."
        prompt = f"<CAPTION_TO_PHRASE_GROUNDING>{prompt}"
        # Prepare the image and text inputs
        np_image = np.array(image)
        inputs = processor(text=[prompt], images=[np_image], return_tensors="pt", padding=True).to(DEVICE)

        with st.spinner("Processing... ⏳"):
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                num_beams=3,
                output_scores=True,           # return per-step scores/logits
                return_dict_in_generate=True  # return both sequences and scores
            )
            # Per-token transition scores for the returned beams
            transition_scores = model.compute_transition_scores(
                outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
            )
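            # The block below follows the Hugging Face recipe for reconstructing
            # beam-search sequence scores from transition scores:
            #   score = sum_t log p(y_t) / L**length_penalty,  probability = exp(score)
            # where L is the number of generated tokens (counted as the entries
            # with negative log-probability). This is a rough sequence-level
            # confidence, not a calibrated clinical probability.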
            # Decode the generated sequence (special tokens kept for post-processing)
            generated_ids = outputs.sequences
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

            # Number of generated tokens per sequence
            output_length = np.sum(transition_scores.cpu().numpy() < 0, axis=1)
            length_penalty = model.generation_config.length_penalty

            # Length-normalized sequence score, converted from log-probability to probability
            reconstructed_scores = transition_scores.cpu().numpy().sum(axis=1) / (output_length**length_penalty)
            probabilities = np.exp(reconstructed_scores)

            st.markdown(f"**🎯 Probability of the Results:** <span style='color:#28a745; font-size:24px; font-weight:bold;'>{probabilities[0] * 100:.2f}%</span>", unsafe_allow_html=True)
            predictions = processor.post_process_generation(generated_text, task="<CAPTION_TO_PHRASE_GROUNDING>", image_size=(np_image.shape[1], np_image.shape[0]))
            detection = sv.Detections.from_lmm(sv.LMM.FLORENCE_2, predictions, resolution_wh=(np_image.shape[1], np_image.shape[0]))

            # Annotate the image with bounding boxes and labels
            bounding_box_annotator = sv.BoundingBoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
            label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
            image_with_predictions = bounding_box_annotator.annotate(np_image.copy(), detection)
            image_with_predictions = label_annotator.annotate(image_with_predictions, detection)
            annotated_image = Image.fromarray(image_with_predictions.astype(np.uint8))

            # Display the original and result images side by side
            col1, col2 = st.columns([1, 1])
            with col1:
                st.image(image, caption=f"Original Image: {selected_example}", width=400)
            with col2:
                st.image(annotated_image, caption="Inference Results 🖼️", width=400)

            # Display the raw generated text
            st.write("**Generated Text:**", generated_text)
# Upload-your-own-image section
st.subheader("Upload Your Own Image")
col1, col2 = st.columns([1, 1])
with col1:
    dataset_choice = st.selectbox("Select Dataset", options=list(dataset_options.keys()))
    disease_options = list(dataset_options[dataset_choice].keys())
with col2:
    disease_choice = st.selectbox("Select Disease 🦠", options=disease_options)

uploaded_file = st.file_uploader("Upload an Image", type=["png", "jpg", "jpeg"])
col1, col2 = st.columns([1, 2])
with col1:
    # Handle file upload
    if uploaded_file:
        image = Image.open(uploaded_file).convert("RGB")
        image = apply_transform(image)  # Letterbox the uploaded image to 512x512
        st.image(image, caption="Uploaded Image", width=400)

    # Fall back to the example's first disease if none is selected
    disease_choice = disease_choice if disease_choice else example_diseases[0]

    # Definition lookup priority: Vindr -> PadChest -> manual input
    definition = vindr_definitions.get(disease_choice, padchest_definitions.get(disease_choice, ""))
    if not definition:
        definition = st.text_input("Enter Definition Manually", value="")
with col2:
    # Instructions and warnings (repeated here for the upload workflow)
    st.subheader("✏️ Instructions to Get Started:")
    st.write("""
    - **Upload Your Own Image**: Upload an image of your choice to analyze it for abnormalities.
    - **Select Dataset**: Choose between the available datasets (Vindr or PadChest) for disease information.
    - **Select Disease**: Pick the disease to analyze from the list of diseases in the selected dataset.
    - **Run Inference**: Click the "Run Inference" button to process the image and display the results.
    """)
    st.subheader("⚠️ Warning:")
    st.write("""
    - **🚫 Please avoid uploading non-frontal chest X-ray images.** Our model has been trained on **frontal chest X-rays** only.
    - This demo is intended for **🔬 research purposes only** and should **❌ not be used for medical diagnosis**.
    - The model's responses may contain **<span style='color:#dc3545; font-weight:bold;'>hallucinations or incorrect information</span>**.
    - Always consult a **<span style='color:#dc3545; font-weight:bold;'>👨‍⚕️ medical professional</span>** for accurate diagnosis and advice.
    """, unsafe_allow_html=True)
| st.markdown(""" | |
| <img src="//www.clustrmaps.com/map_v2.png?d=uM9v_RTadJ3hLvNbBSQ2PZ0KNPABbilkZgDyiXmuC0M&cl=ffffff" | |
| style="position:absolute;top:-9999px;left:-9999px;width:1px;height:1px;visibility:hidden;opacity:0;pointer-events:none;z-index:-1;display:none;" /> | |
| """, unsafe_allow_html=True) | |
# Run inference on the uploaded (or currently selected example) image
if st.button("Run Inference 🏃‍♂️"):
    if image is None:
        st.error("❌ Please upload an image or select an example.")
    else:
        # Construct the grounding prompt with the disease definition
        det_obj = f"{disease_choice} means {definition}."
        st.write(f"**Definition:** {definition}")
        prompt = f"Locate the phrases in the caption: {det_obj}."
        prompt = f"<CAPTION_TO_PHRASE_GROUNDING>{prompt}"

        np_image = np.array(image)
        inputs = processor(text=[prompt], images=[np_image], return_tensors="pt", padding=True).to(DEVICE)
        with st.spinner("Processing... ⏳"):
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                pixel_values=inputs["pixel_values"],
                max_new_tokens=1024,
                num_beams=3,
                output_scores=True,           # return per-step scores/logits
                return_dict_in_generate=True  # return both sequences and scores
            )
            # Per-token transition scores for the returned beams (see the recipe above)
            transition_scores = model.compute_transition_scores(
                outputs.sequences, outputs.scores, outputs.beam_indices, normalize_logits=False
            )
            # Decode the generated sequence (special tokens kept for post-processing)
            generated_ids = outputs.sequences
            generated_text = processor.batch_decode(generated_ids, skip_special_tokens=False)[0]

            # Length-normalized sequence score -> probability, as in the example flow
            output_length = np.sum(transition_scores.cpu().numpy() < 0, axis=1)
            length_penalty = model.generation_config.length_penalty
            reconstructed_scores = transition_scores.cpu().numpy().sum(axis=1) / (output_length**length_penalty)
            probabilities = np.exp(reconstructed_scores)

            st.markdown(f"**🎯 Probability of the Results:** <span style='color:green; font-size:24px; font-weight:bold;'>{probabilities[0] * 100:.2f}%</span>", unsafe_allow_html=True)
            # Decode location tokens into detections; both calls take (width, height)
            predictions = processor.post_process_generation(generated_text, task="<CAPTION_TO_PHRASE_GROUNDING>", image_size=(np_image.shape[1], np_image.shape[0]))
            detection = sv.Detections.from_lmm(sv.LMM.FLORENCE_2, predictions, resolution_wh=(np_image.shape[1], np_image.shape[0]))

            # Annotate the image with bounding boxes and labels
            bounding_box_annotator = sv.BoundingBoxAnnotator(color_lookup=sv.ColorLookup.INDEX)
            label_annotator = sv.LabelAnnotator(color_lookup=sv.ColorLookup.INDEX)
            image_with_predictions = bounding_box_annotator.annotate(np_image.copy(), detection)
            image_with_predictions = label_annotator.annotate(image_with_predictions, detection)
            annotated_image = Image.fromarray(image_with_predictions.astype(np.uint8))

            # Display the original and result images side by side
            col1, col2 = st.columns([1, 1])
            with col1:
                st.image(image, caption="Uploaded Image", width=400)
            with col2:
                st.image(annotated_image, caption="Inference Results 🖼️", width=400)

            # Display the raw generated text
            st.write("**Generated Text:**", generated_text)