Update app.py
app.py CHANGED
@@ -3,14 +3,13 @@ import pandas as pd
 from datetime import datetime
 from typing import List, Tuple, Dict, Union
 import gradio as gr
-from concurrent.futures import ThreadPoolExecutor
 
 # Constants
 MAX_MODEL_TOKENS = 131072
 MAX_NEW_TOKENS = 4096
 MAX_CHUNK_TOKENS = 8192
 PROMPT_OVERHEAD = 300
-BATCH_SIZE = 3 #
+BATCH_SIZE = 3 # group 3 chunks together
 
 # Paths
 persistent_dir = "/data/hf_cache"
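For context, BATCH_SIZE drives the chunk grouping used further down in process_report. A minimal sketch of that grouping, using dummy strings in place of the real report chunks (the sample data is assumed; only the slicing pattern comes from the diff):

BATCH_SIZE = 3  # group 3 chunks together

# Dummy chunks standing in for the extracted report text.
chunks = [f"chunk {i}" for i in range(7)]

# Same slicing pattern as process_report: consecutive groups of BATCH_SIZE chunks.
batch_chunks = [chunks[i:i + BATCH_SIZE] for i in range(0, len(chunks), BATCH_SIZE)]
print(batch_chunks)
# [['chunk 0', 'chunk 1', 'chunk 2'], ['chunk 3', 'chunk 4', 'chunk 5'], ['chunk 6']]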
@@ -84,14 +83,14 @@ def init_agent() -> TxAgent:
     agent.init_model()
     return agent
 
-#
-def analyze_parallel(agent, batch_chunks: List[List[str]], max_workers: int = 3)
-    results = [
-
-    def worker(index, batch):
+# Serial processing (safe for vLLM)
+def analyze_serial(agent, batch_chunks: List[List[str]]) -> List[str]:
+    results = []
+    for idx, batch in enumerate(batch_chunks):
         prompt = "\n\n".join(build_prompt(chunk) for chunk in batch)
         if estimate_tokens(prompt) > MAX_MODEL_TOKENS:
-
+            results.append(f"β Batch {idx+1} too long. Skipped.")
+            continue
         response = ""
         try:
             for r in agent.run_gradio_chat(
@@ -111,19 +110,9 @@ def analyze_parallel(agent, batch_chunks: List[List[str]], max_workers: int = 3)
                         response += m.content
                 elif hasattr(r, "content"):
                     response += r.content
-
+            results.append(clean_response(response))
         except Exception as e:
-
-
-    with ThreadPoolExecutor(max_workers=max_workers) as executor:
-        futures = {executor.submit(worker, idx, batch): idx for idx, batch in enumerate(batch_chunks)}
-        for future in futures:
-            idx = futures[future]
-            try:
-                results[idx] = future.result()
-            except Exception as e:
-                results[idx] = f"β Error in batch {idx+1}: {str(e)}"
-
+            results.append(f"β Error in batch {idx+1}: {str(e)}")
     gc.collect()
     return results
 
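The substantive change in this commit: the ThreadPoolExecutor worker pool is replaced by a plain loop that handles one batch at a time, which the new comment flags as safe for vLLM. Below is a self-contained sketch of that control flow, with a stubbed token estimator and a stubbed model call standing in for the real agent (both stubs are assumptions; the loop, skip, and error handling mirror the new analyze_serial):

MAX_MODEL_TOKENS = 131072

def estimate_tokens(text: str) -> int:
    # Stub: the app has its own token estimator.
    return len(text) // 4

def run_model(prompt: str) -> str:
    # Stub standing in for the streaming agent.run_gradio_chat(...) call.
    return f"(summary of {len(prompt)} characters)"

def analyze_serial_sketch(batch_chunks):
    results = []
    for idx, batch in enumerate(batch_chunks):
        prompt = "\n\n".join(batch)
        if estimate_tokens(prompt) > MAX_MODEL_TOKENS:
            results.append(f"β Batch {idx+1} too long. Skipped.")
            continue
        try:
            results.append(run_model(prompt))
        except Exception as e:
            results.append(f"β Error in batch {idx+1}: {str(e)}")
    return results

print(analyze_serial_sketch([["chunk 1", "chunk 2"], ["chunk 3"]]))

Because the loop appends as it goes, results stay in batch order, so the downstream handling in process_report is unchanged.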
@@ -161,7 +150,7 @@ def process_report(agent, file, messages: List[Dict[str, str]]) -> Tuple[List[Di
         batch_chunks = [chunks[i:i+BATCH_SIZE] for i in range(0, len(chunks), BATCH_SIZE)]
         messages.append({"role": "assistant", "content": f"π Split into {len(batch_chunks)} batches. Analyzing..."})
 
-        chunk_results =
+        chunk_results = analyze_serial(agent, batch_chunks)
         valid = [res for res in chunk_results if not res.startswith("β")]
 
         if not valid:
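Since analyze_serial reports failures as strings prefixed with "β", process_report can filter them out with the one-liner shown in this hunk. For illustration, with made-up results:

chunk_results = ["Summary of batch 1", "β Batch 2 too long. Skipped.", "Summary of batch 3"]
valid = [res for res in chunk_results if not res.startswith("β")]
print(valid)  # ['Summary of batch 1', 'Summary of batch 3']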
@@ -181,15 +170,14 @@ def process_report(agent, file, messages: List[Dict[str, str]]) -> Tuple[List[Di
         messages.append({"role": "assistant", "content": f"β Error: {str(e)}"})
         return messages, None
 
-
 def create_ui(agent):
     with gr.Blocks(css="""
     html, body, .gradio-container {
         background-color: #0e1621;
         color: #e0e0e0;
         font-family: 'Inter', sans-serif;
-        margin: 0;
         padding: 0;
+        margin: 0;
     }
     h2, h3, h4 {
         color: #89b4fa;
@@ -215,7 +203,6 @@ def create_ui(agent):
     }
     .gr-chatbot .message {
         font-size: 16px;
-        line-height: 1.6;
         padding: 12px 16px;
         border-radius: 18px;
         margin: 8px 0;
@@ -242,10 +229,8 @@ def create_ui(agent):
             <h2>π CPS: Clinical Patient Support System</h2>
             <p>CPS Assistant helps you analyze and summarize unstructured medical files using AI.</p>
         """)
-
         with gr.Column():
             chatbot = gr.Chatbot(label="CPS Assistant", height=700, type="messages")
-
             upload = gr.File(label="Upload Medical File", file_types=[".xlsx"])
             analyze = gr.Button("π§ Analyze", variant="primary")
             download = gr.File(label="Download Report", visible=False, interactive=False)
@@ -268,4 +253,3 @@ if __name__ == "__main__":
     except Exception as err:
         print(f"Startup failed: {err}")
         sys.exit(1)
-