Update app.py
Browse files
app.py
CHANGED
|
@@ -23,6 +23,7 @@ import tempfile
|
|
| 23 |
import subprocess
|
| 24 |
import warnings
|
| 25 |
from threading import Thread
|
|
|
|
| 26 |
|
| 27 |
# Environment setup
|
| 28 |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
|
@@ -40,7 +41,6 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStream
|
|
| 40 |
# =============================================================================
|
| 41 |
# SAM-2 Alias Patch & Installer
|
| 42 |
# =============================================================================
|
| 43 |
-
# Alias sam_2 package to sam2 namespace
|
| 44 |
try:
|
| 45 |
import sam_2, importlib
|
| 46 |
sys.modules['sam2'] = sam_2
|
|
@@ -50,23 +50,20 @@ except ImportError:
|
|
| 50 |
pass
|
| 51 |
|
| 52 |
def check_and_install_sam2():
|
| 53 |
-
"""Ensure SAM-2 is installed and aliased as sam2."""
|
| 54 |
try:
|
| 55 |
from sam2.build_sam import build_sam2
|
| 56 |
return True
|
| 57 |
except ImportError:
|
| 58 |
-
# Clone repo
|
| 59 |
repo_dir = Path("segment-anything-2")
|
| 60 |
if not repo_dir.exists():
|
| 61 |
subprocess.run(["git","clone","https://github.com/facebookresearch/segment-anything-2.git"], check=True)
|
| 62 |
-
|
| 63 |
-
cwd = os.getcwd()
|
| 64 |
os.chdir(repo_dir)
|
| 65 |
subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."], check=True)
|
| 66 |
os.chdir(cwd)
|
| 67 |
-
# Re-alias
|
| 68 |
try:
|
| 69 |
-
import sam_2
|
|
|
|
| 70 |
sys.modules['sam2'] = sam_2
|
| 71 |
for sub in ['build_sam','automatic_mask_generator','modeling.sam2_base']:
|
| 72 |
sys.modules[f'sam2.{sub}'] = importlib.import_module(f'sam_2.{sub}')
|
|
@@ -75,7 +72,7 @@ def check_and_install_sam2():
|
|
| 75 |
return False
|
| 76 |
|
| 77 |
SAM2_AVAILABLE = check_and_install_sam2()
|
| 78 |
-
|
| 79 |
if SAM2_AVAILABLE:
|
| 80 |
from sam2.build_sam import build_sam2
|
| 81 |
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
|
@@ -126,9 +123,7 @@ class MedicalVLMAgent:
|
|
| 126 |
user_cont.append({"type":"text","text": text or ""})
|
| 127 |
msgs.append({"role":"user","content":user_cont})
|
| 128 |
prompt = self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
|
| 129 |
-
|
| 130 |
-
inputs = self.processor(text=[prompt], images=img_in, videos=vid_in,
|
| 131 |
-
padding=True, return_tensors='pt').to(self.device)
|
| 132 |
out = self.model.generate(**inputs, max_new_tokens=128)
|
| 133 |
resp = out[0][inputs.input_ids.shape[1]:]
|
| 134 |
return self.processor.decode(resp, skip_special_tokens=True).strip()
|
|
@@ -139,15 +134,14 @@ class MedicalVLMAgent:
|
|
| 139 |
_sam2_model, _mask_generator = (None, None)
|
| 140 |
if SAM2_AVAILABLE:
|
| 141 |
try:
|
| 142 |
-
# Initialize model
|
| 143 |
CKPT="checkpoints/sam2.1_hiera_large.pt"; CFG="configs/sam2.1/sam2.1_hiera_l.yaml"
|
| 144 |
os.chdir("segment-anything-2/sam2/sam2")
|
| 145 |
_sam2_model = build_sam2(CFG, CKPT, device=get_device(), apply_postprocessing=False)
|
| 146 |
_mask_generator = SAM2AutomaticMaskGenerator(_sam2_model)
|
| 147 |
-
except Exception:
|
|
|
|
| 148 |
_mask_generator = None
|
| 149 |
|
| 150 |
-
|
| 151 |
def segmentation_interface(image):
|
| 152 |
if image is None: return None, "Upload an image"
|
| 153 |
if not _mask_generator: return None, "SAM-2 unavailable"
|
|
@@ -157,7 +151,7 @@ def segmentation_interface(image):
|
|
| 157 |
for ann in sorted(anns, key=lambda x: x['area'], reverse=True):
|
| 158 |
m = ann['segmentation']; color=np.random.randint(0,255,3)
|
| 159 |
overlay[m] = (overlay[m]*0.5 + color*0.5).astype(np.uint8)
|
| 160 |
-
return Image.fromarray(overlay), f"{len(anns)} masks"
|
| 161 |
|
| 162 |
# =============================================================================
|
| 163 |
# Fallback segmentation
|
|
@@ -176,9 +170,7 @@ def fallback_segmentation(image):
|
|
| 176 |
# =============================================================================
|
| 177 |
try:
|
| 178 |
chex_tok = AutoTokenizer.from_pretrained("StanfordAIMI/CheXagent-2-3b", trust_remote_code=True)
|
| 179 |
-
chex_model = AutoModelForCausalLM.from_pretrained(
|
| 180 |
-
"StanfordAIMI/CheXagent-2-3b", device_map='auto', trust_remote_code=True
|
| 181 |
-
)
|
| 182 |
if torch.cuda.is_available(): chex_model = chex_model.half()
|
| 183 |
chex_model.eval(); CHEX_AVAILABLE=True
|
| 184 |
except Exception:
|
|
@@ -188,13 +180,11 @@ except Exception:
|
|
| 188 |
def report_generation(im1, im2):
|
| 189 |
if not CHEX_AVAILABLE: yield "CheXagent unavailable"; return
|
| 190 |
streamer = TextIteratorStreamer(chex_tok, skip_prompt=True)
|
| 191 |
-
|
| 192 |
-
yield "Report not implemented in snippet"
|
| 193 |
|
| 194 |
@torch.no_grad()
|
| 195 |
def phrase_grounding(image, prompt):
|
| 196 |
if not CHEX_AVAILABLE: return "CheXagent unavailable", None
|
| 197 |
-
# simple box
|
| 198 |
w,h=image.size; draw=ImageDraw.Draw(image)
|
| 199 |
draw.rectangle([(w*0.25,h*0.25),(w*0.75,h*0.75)], outline='red', width=3)
|
| 200 |
return prompt, image
|
|
@@ -202,36 +192,27 @@ def phrase_grounding(image, prompt):
|
|
| 202 |
# =============================================================================
|
| 203 |
# Gradio UI
|
| 204 |
# =============================================================================
|
|
|
|
| 205 |
def create_ui():
|
| 206 |
-
# Load agents
|
| 207 |
try:
|
| 208 |
-
|
| 209 |
-
|
| 210 |
except:
|
| 211 |
-
|
| 212 |
-
|
| 213 |
with gr.Blocks() as demo:
|
| 214 |
gr.Markdown("# Medical AI Assistant")
|
| 215 |
-
gr.Markdown(f"- Qwen
|
| 216 |
-
f"- SAM-2: {'β
' if _mask_generator else 'β'} "
|
| 217 |
-
f"- CheXagent: {'β
' if CHEX_AVAILABLE else 'β'}")
|
| 218 |
with gr.Tab("Medical Q&A"):
|
| 219 |
-
txt=gr.Textbox(); img=gr.Image(type='pil'); out=gr.Textbox();
|
| 220 |
-
btn.click(med_agent.run, [txt,img], out)
|
| 221 |
with gr.Tab("Segmentation"):
|
| 222 |
-
|
| 223 |
-
if _mask_generator: fn=segmentation_interface
|
| 224 |
-
else: fn=fallback_segmentation
|
| 225 |
-
gr.Button("Segment").click(fn, segin, [segout, stat])
|
| 226 |
with gr.Tab("CheXagent Report"):
|
| 227 |
-
c1=gr.Image(type='pil');
|
| 228 |
-
gr.Interface(fn=report_generation, inputs=[c1,c2], outputs=rout, live=True).render()
|
| 229 |
with gr.Tab("CheXagent Grounding"):
|
| 230 |
-
gi=gr.Image(type='pil'); gp=gr.Textbox(); gout=gr.Textbox(); goimg=gr.Image()
|
| 231 |
-
gr.Interface(fn=phrase_grounding, inputs=[gi,gp], outputs=[gout,goimg]).render()
|
| 232 |
return demo
|
| 233 |
|
| 234 |
if __name__ == "__main__":
|
| 235 |
-
ui
|
| 236 |
-
|
| 237 |
|
|
|
|
| 23 |
import subprocess
|
| 24 |
import warnings
|
| 25 |
from threading import Thread
|
| 26 |
+
from pathlib import Path
|
| 27 |
|
| 28 |
# Environment setup
|
| 29 |
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
|
|
|
|
| 41 |
# =============================================================================
|
| 42 |
# SAM-2 Alias Patch & Installer
|
| 43 |
# =============================================================================
|
|
|
|
| 44 |
try:
|
| 45 |
import sam_2, importlib
|
| 46 |
sys.modules['sam2'] = sam_2
|
|
|
|
| 50 |
pass
|
| 51 |
|
| 52 |
def check_and_install_sam2():
|
|
|
|
| 53 |
try:
|
| 54 |
from sam2.build_sam import build_sam2
|
| 55 |
return True
|
| 56 |
except ImportError:
|
|
|
|
| 57 |
repo_dir = Path("segment-anything-2")
|
| 58 |
if not repo_dir.exists():
|
| 59 |
subprocess.run(["git","clone","https://github.com/facebookresearch/segment-anything-2.git"], check=True)
|
| 60 |
+
cwd = Path.cwd()
|
|
|
|
| 61 |
os.chdir(repo_dir)
|
| 62 |
subprocess.run([sys.executable, "-m", "pip", "install", "-e", "."], check=True)
|
| 63 |
os.chdir(cwd)
|
|
|
|
| 64 |
try:
|
| 65 |
+
import sam_2
|
| 66 |
+
importlib.reload(sam_2)
|
| 67 |
sys.modules['sam2'] = sam_2
|
| 68 |
for sub in ['build_sam','automatic_mask_generator','modeling.sam2_base']:
|
| 69 |
sys.modules[f'sam2.{sub}'] = importlib.import_module(f'sam_2.{sub}')
|
|
|
|
| 72 |
return False
|
| 73 |
|
| 74 |
SAM2_AVAILABLE = check_and_install_sam2()
|
| 75 |
+
print(f"SAM-2 Available: {SAM2_AVAILABLE}")
|
| 76 |
if SAM2_AVAILABLE:
|
| 77 |
from sam2.build_sam import build_sam2
|
| 78 |
from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
|
|
|
|
| 123 |
user_cont.append({"type":"text","text": text or ""})
|
| 124 |
msgs.append({"role":"user","content":user_cont})
|
| 125 |
prompt = self.processor.apply_chat_template(msgs, tokenize=False, add_generation_prompt=True)
|
| 126 |
+
inputs = self.processor(text=[prompt], images=[], videos=[], padding=True, return_tensors='pt').to(self.device)
|
|
|
|
|
|
|
| 127 |
out = self.model.generate(**inputs, max_new_tokens=128)
|
| 128 |
resp = out[0][inputs.input_ids.shape[1]:]
|
| 129 |
return self.processor.decode(resp, skip_special_tokens=True).strip()
|
|
|
|
| 134 |
# Module-level SAM-2 handles; both stay None when SAM-2 is missing or fails to load.
_sam2_model, _mask_generator = (None, None)
if SAM2_AVAILABLE:
    # The checkpoint and config paths below are relative, so the working
    # directory is switched to the cloned repo only for the duration of the
    # build and restored afterwards (success or failure), matching how
    # check_and_install_sam2 saves and restores the cwd.
    _prev_cwd = os.getcwd()
    try:
        CKPT = "checkpoints/sam2.1_hiera_large.pt"
        CFG = "configs/sam2.1/sam2.1_hiera_l.yaml"
        os.chdir("segment-anything-2/sam2/sam2")
        _sam2_model = build_sam2(CFG, CKPT, device=get_device(), apply_postprocessing=False)
        _mask_generator = SAM2AutomaticMaskGenerator(_sam2_model)
    except Exception as e:
        # Best-effort: report the failure and leave the generator unset so the
        # rest of the module can tell SAM-2 is unavailable.
        print(f"SAM-2 init error: {e}")
        _mask_generator = None
    finally:
        os.chdir(_prev_cwd)
|
| 144 |
|
|
|
|
| 145 |
def segmentation_interface(image):
|
| 146 |
if image is None: return None, "Upload an image"
|
| 147 |
if not _mask_generator: return None, "SAM-2 unavailable"
|
|
|
|
| 151 |
for ann in sorted(anns, key=lambda x: x['area'], reverse=True):
|
| 152 |
m = ann['segmentation']; color=np.random.randint(0,255,3)
|
| 153 |
overlay[m] = (overlay[m]*0.5 + color*0.5).astype(np.uint8)
|
| 154 |
+
return Image.fromarray(overlay), f"{len(anns)} masks found"
|
| 155 |
|
| 156 |
# =============================================================================
|
| 157 |
# Fallback segmentation
|
|
|
|
| 170 |
# =============================================================================
|
| 171 |
try:
|
| 172 |
chex_tok = AutoTokenizer.from_pretrained("StanfordAIMI/CheXagent-2-3b", trust_remote_code=True)
|
| 173 |
+
chex_model = AutoModelForCausalLM.from_pretrained("StanfordAIMI/CheXagent-2-3b", device_map='auto', trust_remote_code=True)
|
|
|
|
|
|
|
| 174 |
if torch.cuda.is_available(): chex_model = chex_model.half()
|
| 175 |
chex_model.eval(); CHEX_AVAILABLE=True
|
| 176 |
except Exception:
|
|
|
|
| 180 |
def report_generation(im1, im2):
    """Yield report text for two images (placeholder implementation).

    This is a generator. It yields exactly one message: a notice that
    CheXagent is unavailable when ``CHEX_AVAILABLE`` is falsy, otherwise a
    notice that streaming is not implemented yet. Neither ``im1`` nor ``im2``
    is read by the current body.
    """
    if CHEX_AVAILABLE:
        # The streamer is constructed but not consumed yet; it is kept so the
        # placeholder performs the same setup as before.
        streamer = TextIteratorStreamer(chex_tok, skip_prompt=True)
        yield "Report streaming not fully implemented"
    else:
        yield "CheXagent unavailable"
|
|
|
|
| 184 |
|
| 185 |
@torch.no_grad()
def phrase_grounding(image, prompt):
    """Return the prompt together with the image annotated by a placeholder box.

    Args:
        image: Image to annotate; it must provide ``.size`` and ``.copy()``
            and be accepted by ``ImageDraw.Draw``. May be ``None`` when
            nothing was uploaded.
        prompt: Text that is echoed back unchanged as the first result.

    Returns:
        A ``(text, image)`` pair. ``image`` is ``None`` when CheXagent is
        unavailable or no input image was given; otherwise it is a copy of
        the input with a red rectangle covering the central half of each
        dimension.
    """
    if not CHEX_AVAILABLE:
        return "CheXagent unavailable", None
    if image is None:
        # Nothing to draw on; report it instead of failing on ``image.size``.
        return "Upload an image", None

    # Draw on a copy so the caller's image object is never modified in place.
    annotated = image.copy()
    w, h = annotated.size
    draw = ImageDraw.Draw(annotated)
    draw.rectangle([(w * 0.25, h * 0.25), (w * 0.75, h * 0.75)], outline='red', width=3)
    return prompt, annotated
|
|
|
|
| 192 |
# =============================================================================
|
| 193 |
# Gradio UI
|
| 194 |
# =============================================================================
|
| 195 |
+
|
| 196 |
def create_ui():
    """Assemble the Gradio app with one tab per capability.

    The Qwen agent is loaded on a best-effort basis: if loading fails, the
    error is printed and the "Medical Q&A" tab answers with a fixed
    "unavailable" message instead of crashing while the UI is being built.
    The segmentation tab uses ``segmentation_interface`` when a SAM-2 mask
    generator exists and ``fallback_segmentation`` otherwise.

    Returns:
        The constructed ``gr.Blocks`` object (not launched).
    """
    try:
        m, p, d = load_qwen_model_and_processor()
        med = MedicalVLMAgent(m, p, d)
    except Exception as e:
        # Keep the UI usable without Qwen, but do not hide why it is missing.
        print(f"Qwen init error: {e}")
        med = None

    def _mark(ok):
        # Status glyph shown in the header line.
        return '✅' if ok else '❌'

    if med is not None:
        qa_fn = med.run
    else:
        # Stand-in handler so the button can still be wired when the agent
        # failed to load (calling ``med.run`` on None would raise).
        def qa_fn(text, image):
            return "Qwen model unavailable"

    with gr.Blocks() as demo:
        gr.Markdown("# Medical AI Assistant")
        gr.Markdown(
            f"- Qwen: {_mark(med is not None)} "
            f"- SAM-2: {_mark(_mask_generator)} "
            f"- CheX: {_mark(CHEX_AVAILABLE)}"
        )

        with gr.Tab("Medical Q&A"):
            txt = gr.Textbox()
            img = gr.Image(type='pil')
            out = gr.Textbox()
            gr.Button("Ask").click(qa_fn, [txt, img], out)

        with gr.Tab("Segmentation"):
            seg = gr.Image(type='pil')
            so = gr.Image()
            ss = gr.Textbox()
            fn = segmentation_interface if _mask_generator else fallback_segmentation
            gr.Button("Segment").click(fn, seg, [so, ss])

        with gr.Tab("CheXagent Report"):
            c1 = gr.Image(type='pil')
            c2 = gr.Image(type='pil')
            rout = gr.Markdown()
            gr.Interface(report_generation, [c1, c2], rout, live=True).render()

        with gr.Tab("CheXagent Grounding"):
            gi = gr.Image(type='pil')
            gp = gr.Textbox()
            gout = gr.Textbox()
            goimg = gr.Image()
            gr.Interface(phrase_grounding, [gi, gp], [gout, goimg]).render()

    return demo
|
| 214 |
|
| 215 |
if __name__ == "__main__":
    # Script entry point: build the interface, then start serving it.
    ui = create_ui()
    ui.launch(
        server_name='0.0.0.0',
        server_port=7860,
        share=True,
    )
|
| 217 |
+
|
| 218 |
|