ImedHa commited on
Commit
02bb97a
·
verified ·
1 Parent(s): 505fb2e

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +128 -0
  2. requirements.txt +6 -3
app.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
import os
import sys
from contextlib import redirect_stdout
from io import StringIO

import streamlit as st
import torch
from PIL import Image
8
+
9
# --- TorchDynamo Fix for Unsloth/MedGemma ---
# Allow Dynamo to capture scalar outputs instead of graph-breaking on them.
import torch._dynamo
torch._dynamo.config.capture_scalar_outputs = True
# NOTE(review): torch.compiler.disable() called without a function argument
# returns a decorator rather than disabling compilation globally — confirm
# this bare call actually has the intended effect.
torch.compiler.disable()

# --- Dependency Handling ---
# The heavyweight ML dependencies are imported inside a guard so a missing
# install surfaces as a visible Streamlit error instead of a raw traceback;
# st.stop() halts script execution for this run.
try:
    from unsloth import FastVisionModel
    from transformers import TextStreamer
except ImportError as e:
    st.error(f"A required library is not installed. Please install dependencies. Error: {e}")
    st.stop()
21
+
22
@st.cache_resource
def load_medgemma_model():
    """Load the fine-tuned MedGemma vision-language model and its processor.

    Decorated with ``st.cache_resource`` so the weights are loaded at most
    once per server process.

    Returns:
        A ``(model, processor)`` pair on success, or ``(None, None)`` after
        reporting the failure through ``st.error``.
    """
    try:
        model, processor = FastVisionModel.from_pretrained(
            "fiqqy/MedGemma-MM-OR-FT10",
            load_in_4bit=False,
            use_gradient_checkpointing="unsloth",
        )
    except Exception as exc:
        # Surface the failure in the UI rather than crashing the app.
        st.error(f"Error loading MedGemma model: {exc}")
        return None, None
    return model, processor
35
+
36
def run_captioning(medgemma_model, processor, frames, instruction):
    """Run MedGemma inference on a sequence of frames plus a text instruction.

    Args:
        medgemma_model: Model returned by ``load_medgemma_model``.
        processor: Matching processor (chat templating + tensorization).
        frames: PIL images forming the visual context (the UI supplies three,
            but any non-empty sequence is handled).
        instruction: Free-text prompt appended after the image placeholders.

    Returns:
        The generated text, captured from the ``TextStreamer`` stdout stream.
    """
    st.write("Preparing inputs for MedGemma...")
    images = [f.convert("RGB") for f in frames]
    # One {"type": "image"} placeholder per frame (generalized from the
    # original hard-coded three), followed by the instruction text.
    content = [{"type": "image"} for _ in images]
    content.append({"type": "text", "text": instruction})
    messages = [{"role": "user", "content": content}]
    input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    inputs = processor(
        images, input_text, add_special_tokens=False, return_tensors="pt",
    ).to(device)

    # TextStreamer prints tokens to stdout as they are generated; capture
    # that stream so the result can be shown in the Streamlit UI.
    text_streamer = TextStreamer(processor, skip_prompt=True)
    captured_output = StringIO()

    st.write("Running MedGemma Analysis...")
    # NOTE(review): a bare torch._dynamo.disable() call returns a decorator
    # and may have no global effect — kept for parity with the original.
    torch._dynamo.disable()
    # BUGFIX: the original swapped sys.stdout manually and never restored it
    # when generate() raised, leaving stdout broken for the whole process.
    # redirect_stdout guarantees restoration on every exit path.
    with redirect_stdout(captured_output):
        medgemma_model.generate(
            **inputs, streamer=text_streamer, max_new_tokens=768,
            use_cache=True, temperature=1.0, top_p=0.95, top_k=64
        )
    return captured_output.getvalue()
66
+
67
def show():
    """Render the Streamlit UI for the MedGemma scene-analysis demo."""
    st.title("MedGemma Scene Analysis System")
    st.write("A system to test MedGemma vision-language captioning model.")

    # --- Section 1: model loading, kept in session state across reruns ---
    st.header("1. Load MedGemma Model")
    if "medgemma_model" not in st.session_state:
        st.session_state.medgemma_model = None
        st.session_state.processor = None
    if st.button("Load MedGemma Model"):
        with st.spinner("Loading MedGemma... This can take several minutes."):
            model, proc = load_medgemma_model()
            st.session_state.medgemma_model = model
            st.session_state.processor = proc

    if st.session_state.get("medgemma_model") and st.session_state.get("processor"):
        st.success("MedGemma model is loaded.")
    else:
        st.warning("MedGemma model is not loaded.")

    # --- Section 2: upload the three sequential frames ---
    st.header("2. Upload Data")
    st.subheader("Upload Three Sequential Surgical Video Frames")
    upload_cols = st.columns(3)
    uploads = [
        col.file_uploader(
            f"Upload Frame {idx}", type=["png", "jpg", "jpeg"], key=f"frame{idx}"
        )
        for idx, col in enumerate(upload_cols, start=1)
    ]
    frames = [Image.open(upload) for upload in uploads if upload is not None]

    display_size = (256, 256)
    if len(frames) == 3:
        st.success("All three frames have been uploaded successfully.")
        preview_cols = st.columns(3)
        for col, (idx, frame) in zip(preview_cols, enumerate(frames, start=1)):
            col.image(
                frame.resize(display_size),
                caption=f"Frame {idx}",
                use_container_width=True,
            )
    else:
        st.info("Please upload all three frames to proceed.")

    # --- Section 3: run the captioning model ---
    st.header("3. Generate Scene Analysis")
    instruction_prompt = st.text_area(
        "Enter your custom instruction prompt:",
        "Provide a detailed summary of the surgical action, noting the instruments used and their interactions."
    )

    # Gate the Run button on: model loaded, all frames present, non-empty prompt.
    can_run_analysis = (
        st.session_state.get("medgemma_model") is not None
        and len(frames) == 3
        and bool(instruction_prompt)
    )

    if st.button("Run Analysis", disabled=not can_run_analysis):
        with st.spinner("Running MedGemma analysis... This may take a moment."):
            result = run_captioning(
                st.session_state.medgemma_model,
                st.session_state.processor,
                frames,
                instruction_prompt,
            )
            st.subheader("Analysis Result")
            st.write(result)

    if not can_run_analysis:
        st.warning("Please ensure the MedGemma model is loaded, three frames are uploaded, and a prompt is provided.")


if __name__ == "__main__":
    show()
requirements.txt CHANGED
@@ -1,3 +1,6 @@
1
- altair
2
- pandas
3
- streamlit
 
 
 
 
1
+ # Only MedGemma dependencies required
2
+ streamlit
3
+ Pillow
4
+ torch
5
+ unsloth
6
+ transformers