File size: 6,039 Bytes
9c1d8c1
012f71e
2744429
 
 
 
 
 
6d24703
 
 
 
 
 
 
 
 
9c1d8c1
 
 
 
 
 
45aa061
 
9c1d8c1
86a7ed9
 
 
5aa29ec
af1845c
903717c
 
45aa061
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
903717c
 
 
 
45aa061
 
7953ca8
 
 
 
903717c
 
 
 
45aa061
903717c
 
45aa061
903717c
 
 
 
45aa061
 
903717c
 
 
 
 
45aa061
903717c
 
 
45aa061
903717c
 
 
 
45aa061
 
 
 
 
 
 
 
903717c
 
 
 
 
 
 
 
 
 
 
 
45aa061
903717c
 
 
 
 
 
 
 
 
45aa061
 
 
 
 
 
 
903717c
45aa061
 
903717c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os

# Hugging Face Spaces only allows writes under /tmp, so point every
# cache/config location this app touches at a writable path.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface"
os.environ["HF_HOME"] = "/tmp/huggingface"
os.environ["XDG_CACHE_HOME"] = "/tmp"
os.environ["STREAMLIT_HOME"] = "/tmp"

# Streamlit expects a writable .streamlit config directory; build one by hand.
_config_dir = "/tmp/.streamlit"
os.makedirs(_config_dir, exist_ok=True)
_config_path = os.path.join(_config_dir, "config.toml")
with open(_config_path, "w") as cfg:
    cfg.write("[general]\ncachePath = '/tmp'\n")

# Tell Streamlit where that config lives.
os.environ["STREAMLIT_CONFIG_FILE"] = _config_path

import streamlit as st
import torch, transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from PIL import Image
from torchvision import transforms
from io import BytesIO
from pathlib import Path
import pandas as pd

# Workaround for a Streamlit/PyTorch incompatibility: Streamlit's module
# watcher walks torch.classes.__path__, which raises at watch time.
# Emptying the path list disables that probing.  (Removed stale commented-out
# sidebar debug output that used to live here.)
torch.classes.__path__ = []

# Run inference on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Hard-coded sample of the MIMIC-CXR test split: five image paths paired
# with their reference (ground-truth) radiology report text.
data = {
    "path": [
        'test/s55512076.jpg',
        'test/s55786650.jpg',
        'test/s56188631.jpg',
        'test/s53690114.jpg',
        'test/s52070116.jpg'],

    "text": ['Comparison is made to prior study performed a day earlier. Lines and tubes are in unchanged standard position. Multifocal consolidations in the right upper and lower lobes bilaterally left greater than right are unchanged. Severe cardiomegaly is stable. There are no new lung abnormalities. Probably small right pleural effusion is unchanged.',
            'As compared to the previous radiograph, there is no relevant change. The monitoring and support devices are constant. Low lung volumes, borderline size of the cardiac silhouette. Mild pulmonary edema. Moderate retrocardiac atelectasis. No evidence of pneumonia.',
            'AP chest compared to ___ through ___. Elevation of the right lung base and hemidiaphragm has been pronounced since at least ___, accounting for atelectasis at the lung base. The right upper lung and the entire left lung are clear and the left lung is hyperinflated suggesting airway obstruction or emphysema. Heart is normal size. There is no pneumonia or pulmonary edema. No pleural effusion or pneumothorax.',
            'Compared to prior study there is no significant interval change.',
            'In comparison to prior radiograph of 1 day earlier, there has been improved aeration at both lung bases. No other relevant change since recent study.'],
}

# Materialise the sample as a DataFrame indexed 0..4 for the selectbox below.
mimic_df_test = pd.DataFrame.from_dict(data)

def load_images(path):
  """Open the image at *path* and return it as an RGB PIL image."""
  return Image.open(path).convert('RGB')

@st.cache_resource
def load_caption_model():
    """Load the MedICap captioning model, image feature extractor and tokenizer.

    Prefers the local checkpoint in ``model2/`` when present and falls back
    to downloading ``aehrc/medicap`` from the Hugging Face Hub — the same
    local/hub fallback pattern used by ``load_qa_model``.  Cached by
    Streamlit so the model is only loaded once per process.

    Returns:
        tuple: (model on `device` in eval mode, feature extractor, tokenizer).
    """
    ckpt_name = 'aehrc/medicap'

    local_folder = "model2/"
    # Use the local snapshot if it exists, otherwise pull from the Hub.
    # (Previously the fallback was commented out, so a missing model2/
    # directory crashed the app.)
    if os.path.exists(local_folder):
        medicap = transformers.AutoModel.from_pretrained(local_folder, trust_remote_code=True)
    else:
        medicap = transformers.AutoModel.from_pretrained(ckpt_name, trust_remote_code=True)
    medicap = medicap.to(device)
    medicap.eval()  # inference only: freeze dropout/batch-norm behaviour

    # Image preprocessing (resize/normalisation) bundled with the checkpoint.
    medicap_transforms = transformers.AutoFeatureExtractor.from_pretrained(ckpt_name)

    # GPT-2-style tokenizer used to decode generated caption token ids.
    medicap_tokenizer = transformers.GPT2Tokenizer.from_pretrained(ckpt_name)

    return medicap, medicap_transforms, medicap_tokenizer

def generate_image_caption(image, model, transformer, tokenizer):
    """Generate a caption for a single PIL image via beam search.

    The *transformer* (feature extractor) turns the image into pixel tensors;
    the model decodes up to 128 tokens with 4 beams; special tokens are
    stripped from the returned string.
    """
    pixel_values = transformer(image, return_tensors="pt")["pixel_values"]
    generated = model.generate(
        pixel_values=pixel_values.to(device),
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        max_length=128,
        num_beams=4,
        output_attentions=False,
    )
    return tokenizer.decode(generated[0], skip_special_tokens=True)

@st.cache_resource
def load_qa_model():
    """Load the BioGPT PubMedQA model and tokenizer.

    Uses the local ``BioGPT-Large-PubMedQA/`` snapshot when present,
    otherwise downloads from the Hub.  Cached by Streamlit so this runs
    once per process; the model is moved to `device` and set to eval mode.
    """
    model_name = "microsoft/BioGPT-Large-PubMedQA"
    local_folder = "BioGPT-Large-PubMedQA/"

    # Prefer the local checkpoint; fall back to the Hub identifier.
    source = local_folder if os.path.exists(local_folder) else model_name
    biogpt_tokenizer = AutoTokenizer.from_pretrained(source)
    biogpt = AutoModelForCausalLM.from_pretrained(source)

    biogpt = biogpt.to(device)
    biogpt.eval()

    return biogpt, biogpt_tokenizer

def generate_answer(description, question, model, tokenizer):
    """Answer *question* with BioGPT, using *description* as the context.

    Builds a PubMedQA-style "question: ... context: ..." prompt and decodes
    the model's continuation with special tokens stripped.
    """
    prompt = f"question: {question} context: {description}"
    prompt_ids = tokenizer.encode(prompt, return_tensors='pt').to(device)

    output_ids = model.generate(
        prompt_ids,
        max_new_tokens=128,  # cap the length of the bot's reply
    )

    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

st.set_page_config(page_title="Image Caption + QA", layout="centered")
st.title("🖼️ Caption-Based Question Answering")

# Dropdown over the row indices of the hard-coded test set.
options = range(len(mimic_df_test))
choice = st.selectbox("Choose an action:", options)
if choice is not None:
    # Renamed from `data` to `row`: the old name shadowed the module-level
    # `data` dict the DataFrame was built from.
    row = mimic_df_test.iloc[choice]
    label = row['text']
    img = Image.open(Path(row['path']))
    st.image(img)
    st.subheader("📝 Original Description")
    st.info(label)

    # Generate a caption for the selected image with MedICap.
    medicap, medicap_transforms, medicap_tokenizer = load_caption_model()
    caption = generate_image_caption(img, medicap, medicap_transforms, medicap_tokenizer)

    st.subheader("📝 Generated Description")
    st.info(caption)

    # Question answering over the generated caption.
    st.markdown("---")

    st.subheader("❓ Ask a Question About the Image")
    question = st.text_input("Type your question")

    if question:
        biogpt, biogpt_tokenizer = load_qa_model()
        response = generate_answer(caption, question, biogpt, biogpt_tokenizer)
        st.success(f"{response}")

else:
    # There is no upload widget — the selectbox only yields None when it has
    # no options, so the old "Please upload an image file." text was wrong.
    st.info("Please select an image.")