Anirudh Balaraman commited on
Commit
2a68513
·
1 Parent(s): 8846550

remove streamlit code

Browse files
Files changed (1) hide show
  1. app.py +0 -151
app.py DELETED
@@ -1,151 +0,0 @@
1
- import streamlit as st
2
- import subprocess
3
- import os
4
- import shutil
5
- from huggingface_hub import hf_hub_download
6
-
7
- REPO_ID = "anirudh0410/WSAttention-Prostate"
8
- FILENAMES = ["pirads.pt", "prostate_segmentation_model.pt", "cspca_model.pth"]
9
-
10
- @st.cache_resource
11
- def download_all_models():
12
- # 1. Ensure the 'models' directory exists
13
- models_dir = os.path.join(os.getcwd(), 'models')
14
- os.makedirs(models_dir, exist_ok=True) # <--- THIS FIXES YOUR ERROR
15
-
16
- for filename in FILENAMES:
17
- try:
18
- # 2. Download from Hugging Face (to cache)
19
- cached_path = hf_hub_download(repo_id=REPO_ID, filename=filename)
20
-
21
- # 3. Define where we want it to live locally
22
- destination_path = os.path.join(models_dir, filename)
23
-
24
- # 4. Copy only if it's not already there
25
- if not os.path.exists(destination_path):
26
- shutil.copy(cached_path, destination_path)
27
-
28
- except Exception as e:
29
- st.error(f"Failed to download {filename}: {e}")
30
- st.stop()
31
-
32
- # --- TRIGGER THE DOWNLOAD STARTUP ---
33
- with st.spinner("Downloading model weights..."):
34
- download_all_models()
35
- st.success("Models ready!")
36
-
37
- # --- CONFIGURATION ---
38
- # Base paths
39
- BASE_DIR = os.getcwd()
40
- INPUT_BASE = os.path.join(BASE_DIR, "temp_data" )
41
- OUTPUT_DIR = os.path.join(BASE_DIR, "temp_data", "processed")
42
-
43
- # Create specific sub-directories for each input type
44
- # This ensures we pass a clean directory path to your script
45
- T2_DIR = os.path.join(INPUT_BASE, "t2")
46
- ADC_DIR = os.path.join(INPUT_BASE, "adc")
47
- DWI_DIR = os.path.join(INPUT_BASE, "dwi")
48
-
49
- # Ensure all folders exist
50
- if os.path.exists(INPUT_BASE):
51
- shutil.rmtree(INPUT_BASE)
52
- for path in [T2_DIR, ADC_DIR, DWI_DIR, OUTPUT_DIR]:
53
- os.makedirs(path, exist_ok=True)
54
-
55
- st.title("Model Inference")
56
- st.markdown("### Upload your T2W, ADC, and DWI scans")
57
-
58
- # --- 1. UI: THREE UPLOADERS ---
59
- col1, col2, col3 = st.columns(3)
60
-
61
- with col1:
62
- t2_file = st.file_uploader("Upload T2W (NRRD)", type=["nrrd"])
63
- with col2:
64
- adc_file = st.file_uploader("Upload ADC (NRRD)", type=["nrrd"])
65
- with col3:
66
- dwi_file = st.file_uploader("Upload DWI (NRRD)", type=["nrrd"])
67
-
68
- # --- 2. EXECUTION LOGIC ---
69
- if t2_file and adc_file and dwi_file:
70
- st.success("Files ready.")
71
-
72
- if st.button("Run Inference"):
73
- # --- A. CLEANUP & SAVE ---
74
- # Clear old files to prevent mixing previous runs
75
- # (Optional but recommended for a clean state)
76
- for folder in [T2_DIR, ADC_DIR, DWI_DIR, OUTPUT_DIR]:
77
- for f in os.listdir(folder):
78
- os.remove(os.path.join(folder, f))
79
-
80
- # Save T2
81
- # We save it inside the T2_DIR folder
82
- with open(os.path.join(T2_DIR, t2_file.name), "wb") as f:
83
- shutil.copyfileobj(t2_file, f)
84
-
85
- # Save ADC
86
- with open(os.path.join(ADC_DIR, adc_file.name), "wb") as f:
87
- shutil.copyfileobj(adc_file, f)
88
-
89
- # Save DWI
90
- with open(os.path.join(DWI_DIR, dwi_file.name), "wb") as f:
91
- shutil.copyfileobj(dwi_file, f)
92
-
93
- st.write("Files saved. Starting pipeline...")
94
-
95
- # --- B. CONSTRUCT COMMAND ---
96
- # We pass the FOLDER paths, not file paths, matching your argument names
97
- command = [
98
- "python", "run_inference.py",
99
- "--t2_dir", T2_DIR,
100
- "--dwi_dir", DWI_DIR,
101
- "--adc_dir", ADC_DIR,
102
- "--output_dir", OUTPUT_DIR,
103
- "--project_dir", BASE_DIR
104
- ]
105
-
106
- # DEBUG: Show the exact command being run (helpful for troubleshooting)
107
- st.code(" ".join(command), language="bash")
108
-
109
- # --- C. RUN SCRIPT ---
110
- with st.spinner("Running Inference... (This may take a moment)"):
111
- try:
112
- # Run the script and capture output
113
- result = subprocess.run(
114
- command,
115
- capture_output=True,
116
- text=True,
117
- check=True
118
- )
119
-
120
- st.success("Pipeline Execution Successful!")
121
-
122
- # Show Logs
123
- with st.expander("View Execution Logs"):
124
- st.code(result.stdout)
125
-
126
- # --- D. SHOW OUTPUT FILES ---
127
- st.subheader("Results & Downloads")
128
-
129
- # List everything in the output directory
130
- if os.path.exists(OUTPUT_DIR):
131
- output_files = os.listdir(OUTPUT_DIR)
132
-
133
- if output_files:
134
- for file_name in output_files:
135
- file_path = os.path.join(OUTPUT_DIR, file_name)
136
-
137
- # Skip directories, show download buttons for files
138
- if os.path.isfile(file_path):
139
- with open(file_path, "rb") as f:
140
- st.download_button(
141
- label=f"Download {file_name}",
142
- data=f,
143
- file_name=file_name
144
- )
145
- else:
146
- st.warning("Script finished but no files were found in output_dir.")
147
-
148
- except subprocess.CalledProcessError as e:
149
- st.error("Script Execution Failed.")
150
- st.error("Error Output:")
151
- st.code(e.stderr)