# backend/api.py — last commit cdc8dfa ("Update backend/api.py") by fatimaxa (verified).
import sys
import os
import time
sys.path.insert(0, os.path.dirname(__file__))
from session_management.session_manager import SessionManager
from intent_classification.intent_router_ml import route_intent
from imaging.image_processing import cnn_model
from symptom_extraction.symptom_extractor import extract_symptoms
from semantic_search.retrieve import calculate_risk, get_final_symptom_list, context_and_name, get_sources, get_dynamic_threshold
from session_management.ambiguity import check_for_vague, VAGUE_SYMPTOMS
from explanation_generation.generate_explanation import gen_explanation
from explanation_generation.follow_up_handler import handle_follow_up_question
# X-ray used by the console harnesses when the user types "upload".
IMAGE_PATH = "demo_images/img6.jpg"
# Fewer validated evidence items than this triggers a LOW_EVIDENCE_WARNING.
MIN_EVIDENCE = 3
################################################################################################################################################
### INTENT OPTIONS
# Button labels offered during INTENT_CLARIFICATION, mapped to the intent names
# dispatched on by resolve_intent_clarification.
OPTION_TO_INTENT = {
    "Add new evidence": "PATIENT_EVIDENCE_QUERY",
    "Ask for explanation": "FOLLOW_UP_EXPLANATION",
    "Request source": "SOURCE_REQUEST",
    "General help": "HELP_OR_OTHER",
}
### HELPER FUNCTIONS
def vague_clarification_response(vague_symptom: str, current_evidence: list) -> dict:
    """
    Build a VAGUE_CLARIFICATION response asking the user to make a vague symptom specific.

    Args:
        vague_symptom: The symptom flagged as vague; looked up in VAGUE_SYMPTOMS by its
            lower-cased form.
        current_evidence: Evidence extracted so far; copied into "pending_evidence" so the
            caller can hand it back to resolve_vague_symptom.

    Returns:
        A dict carrying the clarification question and a list of {"key", "label"} options.

    Raises:
        KeyError: if the lower-cased symptom is not a VAGUE_SYMPTOMS entry.
    """
    entry = VAGUE_SYMPTOMS[vague_symptom.lower()]
    choices = []
    for option_key, option_label in entry["options"].items():
        choices.append({"key": option_key, "label": option_label})
    return {
        "type": "VAGUE_CLARIFICATION",
        "vague_symptom": vague_symptom,
        "question": entry["question"],
        "options": choices,
        "pending_evidence": list(current_evidence),
    }
def calculate_risk_with_formula(vector_search_output: list) -> tuple[str, str]:
    """
    Score the vector-search matches and pair the score with the formula text.

    Returns:
        (risk_level, risk_formula): risk_level from calculate_risk, risk_formula from
        write_risk_formula.

    NOTE(review): risk_level is handled as a number elsewhere in this module (risk_label,
    run_test_main's int() comparison), so the `str` annotation may be inaccurate — confirm.
    """
    # Tuple elements evaluate left to right: the risk is computed before the formula text.
    return calculate_risk(vector_search_output), write_risk_formula()
def help(text: str) -> str:
    """
    Return the generic reply used for HELP_OR_OTHER and unrecognised intents.

    `text` (the user's message) is accepted but currently ignored.
    NOTE: this name shadows the built-in help() inside this module.
    """
    reply = "Go to the app's help page for more info."
    return reply
def check_sufficient_evidence(session_manager: SessionManager, sid: int) -> dict | None:
    """
    Warn when session `sid` holds fewer validated evidence items than MIN_EVIDENCE.

    Validated symptoms, risk factors and imaging detections are counted together
    (missing buckets count as empty).

    Returns:
        None when the total reaches MIN_EVIDENCE; otherwise a LOW_EVIDENCE_WARNING dict
        with the total, the threshold and a user-facing message.
    """
    validated = session_manager.get_validated_evidence(sid)
    total = sum(
        len(validated.get(bucket, []))
        for bucket in ("symptoms", "risk_factors", "imaging_detections")
    )
    if total >= MIN_EVIDENCE:
        return None
    warning_text = (
        f"Only {total} piece(s) of evidence recorded, which is below the "
        f"recommended minimum of {MIN_EVIDENCE}. With limited information, system results "
        "cannot be guaranteed to be reliable. If the patient has additional symptoms, "
        "risk factors, or a chest X-ray available, consider adding these."
    )
    return {
        "type": "LOW_EVIDENCE_WARNING",
        "total_evidence": total,
        "min_required": MIN_EVIDENCE,
        "message": warning_text,
    }
def risk_label(risk_level) -> str:
    """
    Map a risk level to a display label.

    Numeric levels (int/float; presumably on the 0-10 scale produced by the risk formula)
    map to "High" (>= 7), "Medium" (>= 4) or "Low". Any other truthy value is returned
    as its string form, and falsy non-numeric values (None, "") become "Unknown".
    """
    if not isinstance(risk_level, (int, float)):
        return str(risk_level) if risk_level else "Unknown"
    if risk_level >= 7:
        return "High"
    if risk_level >= 4:
        return "Medium"
    return "Low"
def write_risk_formula() -> str:
    """
    Return the plain-text description of the risk-scoring formula.

    The text is returned in the "risk_formula" field of RISK_ASSESSMENT_OUTPUT
    (the transparency box) and is surrounded by a leading and trailing newline.
    """
    formula_lines = (
        "similarity = 1 - normalised_distance",
        "total_weight += risk_score / 10",
        "weighted_score += ((0.25 * similarity) + (0.3 * (risk/10)) + (0.45 * (rarity/10)))",
        "final_score = (weighted_score / total_weight) * 10",
    )
    return "\n" + "\n".join(formula_lines) + "\n"
################################################################################################################################################
### PATIENT SELECTION/CREATION
def create_patient(db, name: str, dob: str, sex: str) -> dict:
    """
    Create a new patient record through db.create_patient.

    Returns:
        {"type": "PATIENT_CREATED", "pid": <new id>} on success, or an ERROR dict when
        db.create_patient returns None (reported as "Patient already exists.").
    """
    new_pid = db.create_patient(name, dob, sex)
    if new_pid is not None:
        return {"type": "PATIENT_CREATED", "pid": new_pid}
    return {"type": "ERROR", "message": "Patient already exists."}
def _patient_id_of(record):
    """Return the patient ID held in one db.list_patients() entry (a row or a bare ID)."""
    # Scalars are already the ID; strings must be caught here because they are indexable.
    if isinstance(record, (str, bytes, int)):
        return record
    try:
        # Row-like records (tuple, list, DB row objects, ...) keep the ID in column 0.
        return record[0]
    except (TypeError, IndexError, KeyError):
        return record


def select_patient(db, pid: int) -> dict:
    """
    Validate that a patient ID exists and, if so, select that patient.

    db.list_patients() entries may be bare IDs or row-like records whose first column is
    the patient ID (run_test_main in this module reads rows as p[0] = id, p[1] = name).
    IDs are compared as strings, so 5 and "5" match.

    Returns:
        {"type": "PATIENT_SELECTED", "pid": pid} when found, otherwise an ERROR dict.
    """
    wanted = str(pid)
    for record in db.list_patients():
        # Compare against the ID column, not the whole row: str((5, "Ann", ...)) never
        # equals "5", which made every lookup against row results fail.
        if str(_patient_id_of(record)) == wanted:
            return {
                "type": "PATIENT_SELECTED",
                "pid": pid
            }
    return {
        "type": "ERROR",
        "message": f"Patient ID {pid} not found."
    }
def list_all_patients(db) -> dict:
    """Wrap every record returned by db.list_patients() in a PATIENT_LIST response."""
    patient_rows = db.list_patients()
    response = {"type": "PATIENT_LIST", "patients": patient_rows}
    return response
################################################################################################################################################
### SESSION MANAGEMENT
def start_session(session_manager: SessionManager, pid:int) -> dict:
    """
    Start a new session for patient `pid` and seed it with the patient's stored evidence.

    The stored symptoms, risk factors and imaging findings are concatenated in that order
    and registered on the new session via add_patient_records_evidence.

    Returns:
        {"type": "SESSION_STARTED", "sid": <session id>, "existing_evidence": <evidence dict>}

    Raises:
        KeyError: if the patient evidence lacks a "symptoms", "risk_factors" or "imaging" key.
    """
    new_sid = session_manager.start_session(pid)
    history = session_manager.get_patient_evidence(pid)
    seed = history["symptoms"] + history["risk_factors"] + history["imaging"]
    session_manager.add_patient_records_evidence(new_sid, seed)
    response = {"type": "SESSION_STARTED", "sid": new_sid}
    response["existing_evidence"] = history
    return response
def end_session(session_manager: SessionManager, sid: int) -> dict:
    """
    End session `sid` via session_manager.end_session and acknowledge with SESSION_ENDED.

    Persisting the session is presumably handled by session_manager.end_session — confirm.
    """
    session_manager.end_session(sid)
    return dict(type="SESSION_ENDED")
################################################################################################################################################
### IMAGE PIPELINE
def process_image_upload(session_manager: SessionManager, sid: int, pid: int, image_path: str) -> dict:
    """
    Run the imaging pipeline for an uploaded chest X-ray.

    Records the X-ray path on the session, runs the CNN on it and, if the model reports
    any detections, stores them as imaging evidence and re-runs the full risk assessment.

    Args:
        session_manager: Session state store.
        sid: Current session ID as returned by start_session (annotated int like every
            other session-scoped function in this module).
        pid: Patient ID, forwarded to risk_assessment_pipeline.
        image_path: Path to the X-ray image file.

    Returns:
        RISK_ASSESSMENT_OUTPUT when detections were found, a TEXT notice when none were,
        or {"type": "ERROR", "message": str(exc)} if any step raises. Note the X-ray is
        recorded on the session before the CNN runs, so it stays recorded even on error.
    """
    try:
        session_manager.add_xray(sid, image_path)
        imaging_detections, cnn_outputs = cnn_model(image_path)
        print(f"\n\nIMAGING OUTPUT:\n{imaging_detections}\n{cnn_outputs}")
        if not imaging_detections:
            return {
                "type": "TEXT",
                "message": "X-ray processed but no significant findings detected."
            }
        session_manager.add_imaging_evidence(sid, imaging_detections)
        return risk_assessment_pipeline(session_manager, sid, pid, imaging_input=True)
    except Exception as e:
        # Error boundary: callers dispatch on the returned "type", so failures are
        # reported as an ERROR response rather than raised.
        return {"type": "ERROR", "message": str(e)}
################################################################################################################################################
### TEXT PIPELINE
def process_text_input(session_manager: SessionManager, sid: int, pid: int, user_input: str, last_vector_search_output: list = None, last_info_array=None, last_risk_level=None) -> dict:
    """
    Process one free-text message from the user.

    The message is routed with route_intent. Low-confidence routing returns an
    INTENT_CLARIFICATION response (resolved later by resolve_intent_clarification);
    otherwise the message is handled by intent:

    - FOLLOW_UP_EXPLANATION: answered from the previous turn's info array and risk level.
    - SOURCE_REQUEST: sources for the previous turn's vector-search output, if there is one.
    - PATIENT_EVIDENCE_QUERY: evidence is extracted; vague evidence yields a
      VAGUE_CLARIFICATION, otherwise it is stored and the risk pipeline re-runs.
    - HELP_OR_OTHER or any other intent: the generic help reply.

    The last_* arguments carry state from the previous RISK_ASSESSMENT_OUTPUT.
    Any exception is returned as {"type": "ERROR", "message": str(exc)}.
    """
    try:
        routing = route_intent(user_input, pid, sid)
        if routing["status"] == "NEEDS_CLARIFICATION":
            return {
                "type": "INTENT_CLARIFICATION",
                "confidence": routing["confidence"],
                "options": routing["options"],
                "pending_input": user_input
            }

        intent = routing["intent"]

        if intent == "FOLLOW_UP_EXPLANATION":
            answer = handle_follow_up_question(user_input, last_info_array, last_risk_level)
            return {"type": "TEXT", "message": answer}

        if intent == "SOURCE_REQUEST":
            if not last_vector_search_output:
                return {
                    "type": "TEXT",
                    "message": "There has been no activity in this session to get sources for. Please provide symptoms, risk factors, or upload an X-ray first before requesting for sources."
                }
            return {"type": "TEXT", "message": get_sources(last_vector_search_output)}

        if intent == "PATIENT_EVIDENCE_QUERY":
            evidence = extract_symptoms(user_input)
            print(f"\n\nEXTRACTED EVIDENCE: {evidence}")
            flagged = check_for_vague(evidence)
            if flagged:
                return vague_clarification_response(flagged[0], evidence)
            session_manager.add_text_evidence(sid, evidence)
            return risk_assessment_pipeline(session_manager, sid, pid, False)

        # HELP_OR_OTHER and any unrecognised intent share the generic help reply.
        return {"type": "TEXT", "message": help(user_input)}
    except Exception as e:
        return {"type": "ERROR", "message": str(e)}
################################################################################################################################################
### RISK ASSESSMENT PIPELINE
def _format_vector_results(vector_search_output: list) -> tuple[list, list, list, list]:
    """
    Convert raw vector-search hits into display rows and bucket their names by category.

    Each hit is read as (name, normalised_distance, metadata_dict) and scored as
    similarity = 1 - distance. A name goes to the symptom bucket when its category is
    exactly "symptom", else to risk factors when the category contains "risk", else to
    imaging when it contains "imaging"; other categories are displayed but not bucketed.

    Returns:
        (rows sorted by similarity descending, symptoms, risk_factors, imaging)
    """
    formatted = []
    symptoms, risk_factors, imaging = [], [], []
    for item in vector_search_output:
        factor = item[0]
        distance = item[1]  # normalised distance
        metadata = item[2]
        # `or ""` also covers metadata that stores an explicit None category.
        category = (metadata.get("category") or "").lower()
        similarity_score = float(1 - distance)
        print(f"similarity score: {similarity_score}")
        risk_val = metadata.get("risk score", "N/A")
        rarity_val = metadata.get("rarity", "N/A")
        formatted.append({
            "name": factor,
            "category": category,
            "score": similarity_score,
            "context": metadata.get("context", ""),
            "source": metadata.get("source", ""),
            "risk_score": risk_val,
            "rarity": rarity_val,
            "metadata": {
                "risk_score": risk_val,
                "rarity": rarity_val,
            },
        })
        if category == "symptom":
            symptoms.append(factor)
        elif "risk" in category:
            risk_factors.append(factor)
        elif "imaging" in category:
            imaging.append(factor)
    # Sort by score descending so the UI can slice the top-k directly.
    formatted.sort(key=lambda row: row["score"], reverse=True)
    return formatted, symptoms, risk_factors, imaging


def risk_assessment_pipeline(session_manager: SessionManager, sid: int, pid: int, imaging_input: bool) -> dict:
    """
    Run vector search, risk scoring and explanation generation over all evidence
    accumulated in the current session so far.

    Steps, in order:
    1. Merge this turn's evidence into the session's evidence.
    2. Vector-search it with a threshold that depends on how much evidence there is.
    3. Record the validated matches on the session.
    4. Score and store the risk.
    5. Generate the explanation and sources.
    6. Check for low evidence.
    7. Clear this turn's evidence.

    Args:
        session_manager: Session state store (read and updated).
        sid: Current session ID.
        pid: Patient ID, used to fetch the patient's stored evidence for the response.
        imaging_input: True when triggered by an X-ray upload. Currently unused (an
            imaging-only result filter was tried and disabled); kept for callers.

    Returns:
        A RISK_ASSESSMENT_OUTPUT dict; see the inline comments on each field.
    """
    ### COMBINE EVIDENCE
    combined_evidence = list(session_manager.get_combined_current_loop_evidence(sid))
    print(f"\n\nCOMBINED EVIDENCE: {combined_evidence}")
    combined_evidence = session_manager.get_combined_session_evidence(sid, combined_evidence)
    print(f"\n\nCOMBINED EVIDENCE: {combined_evidence}")

    ### VECTOR SEARCH
    dynamic_threshold = get_dynamic_threshold(len(combined_evidence))
    vector_search_output = get_final_symptom_list(dynamic_threshold, combined_evidence)
    info_array = context_and_name(vector_search_output)

    ### LOG VALIDATED EVIDENCE TO SESSION STATE
    (formatted_vector_results, validated_symptoms,
     validated_risk_factors, validated_imaging) = _format_vector_results(vector_search_output)
    session_manager.add_validated_evidence(
        sid,
        symptoms=validated_symptoms,
        risk_factors=validated_risk_factors,
        imaging_detections=validated_imaging,
    )

    ### RISK SCORING
    risk_level, risk_formula = calculate_risk_with_formula(vector_search_output)
    session_manager.assign_risk(sid, risk_level)

    ### GENERATE EXPLANATION
    generated_explanation = gen_explanation(info_array, risk_level)

    ### SOURCES
    sources = get_sources(vector_search_output)

    ### MANAGE SESSION EVIDENCE
    existing_evidence = session_manager.get_patient_evidence(pid)
    session_only_evidence = session_manager.get_session_only_evidence(sid, pid)
    print(f"existing evidence:\n{existing_evidence}")
    print(f"validated evidence:\n{session_manager.get_validated_evidence(sid)}")
    print(f"session only evidence:\n{session_only_evidence}")

    ### MIN EVIDENCE CHECK
    # Runs after add_validated_evidence so this turn's matches are counted.
    evidence_warning = check_sufficient_evidence(session_manager, sid)
    # This turn has been assessed; reset the per-turn evidence for the next one.
    session_manager.clear_current_loop_evidence(sid)

    ### FINAL OUTPUT - RISK_ASSESSMENT_OUTPUT
    return {
        "type": "RISK_ASSESSMENT_OUTPUT",
        "risk_level": risk_level,
        "risk_label": risk_label(risk_level),
        # Table 1 — all evidence split by source (history vs this session)
        "existing_evidence": existing_evidence,
        "session_evidence": {
            "symptoms": session_only_evidence.get("symptoms", []),
            "risk_factors": session_only_evidence.get("risk_factors", []),
            "imaging": session_only_evidence.get("imaging", []),
        },
        "info_array": info_array,
        "last_risk_level": risk_level,
        # Table 2 — vector search results (sorted by score, UI slices top-k)
        "vector_results": formatted_vector_results,
        # Transparency box
        "risk_formula": risk_formula,
        # Narrative explanation
        "explanation": generated_explanation,
        # Sources list
        # NOTE(review): list() assumes get_sources returns a non-string iterable; its result
        # is also used directly as a TEXT message elsewhere, so confirm it is not a str
        # (list() would split a str into characters).
        "sources": list(sources),
        # Raw output, passed back by callers so SOURCE_REQUEST can use it next turn
        "vector_search_output": vector_search_output,
        # None if sufficient, dict if not — lets the UI warn that the assessment used scarce evidence
        "evidence_warning": evidence_warning,
    }
################################################################################################################################################
### CLARIFICATIONS
def resolve_intent_clarification(session_manager: SessionManager, sid: int, pid: int, selected_option: str, text_input: str, last_vector_search_output: list = None, last_info_array=None, last_risk_level=None) -> dict:
    """
    Resolve an intent clarification.

    Called after the user picks their intended action from the INTENT_CLARIFICATION
    buttons. `selected_option` is a button label looked up in OPTION_TO_INTENT, and
    `text_input` is the original message that could not be routed confidently; the
    message is then dispatched on the chosen intent.

    Returns:
        A TEXT, VAGUE_CLARIFICATION, RISK_ASSESSMENT_OUTPUT or ERROR response dict.
        An unknown option, or any exception raised while handling the message, is
        reported as an ERROR response (matching process_text_input and
        resolve_vague_symptom) rather than propagated.
    """
    try:
        intent = OPTION_TO_INTENT.get(selected_option)
        if intent is None:
            return {
                "type": "ERROR",
                "message": f"Unknown option: {selected_option}"
            }
        if intent == "FOLLOW_UP_EXPLANATION":
            return {
                "type": "TEXT",
                "message": handle_follow_up_question(text_input, last_info_array, last_risk_level)
            }
        if intent == "SOURCE_REQUEST":
            if last_vector_search_output:
                return {
                    "type": "TEXT",
                    "message": get_sources(last_vector_search_output)
                }
            return {
                "type": "TEXT",
                "message": "There has been no activity in this session to get sources for."
            }
        if intent == "HELP_OR_OTHER":
            return {
                "type": "TEXT",
                "message": help(text_input)
            }
        if intent == "PATIENT_EVIDENCE_QUERY":
            extracted_evidence = extract_symptoms(text_input)
            vague = check_for_vague(extracted_evidence)
            if vague:
                return vague_clarification_response(vague[0], extracted_evidence)
            session_manager.add_text_evidence(sid, extracted_evidence)
            return risk_assessment_pipeline(session_manager, sid, pid, False)
        return {
            "type": "TEXT",
            "message": help(text_input)
        }
    except Exception as e:
        # Callers loop on the response "type" without their own try/except, so an
        # uncaught error here would abort the whole session loop.
        return {
            "type": "ERROR",
            "message": str(e)
        }
def resolve_vague_symptom(session_manager: SessionManager, sid: int, pid:int, vague_symptom: str, selected_option: str, extracted_evidence: list) -> dict:
    """
    Resolve a vague-symptom clarification.

    Called after the user picks an option from a VAGUE_CLARIFICATION prompt. The first
    occurrence of the vague symptom is removed from a copy of the evidence, and the chosen
    option key is appended either way. If vague evidence remains, another
    VAGUE_CLARIFICATION is returned; otherwise the evidence is stored and the risk
    pipeline re-runs. Any exception is returned as an ERROR response.
    """
    try:
        evidence = list(extracted_evidence)
        try:
            evidence.remove(vague_symptom)
        except ValueError:
            pass  # vague symptom not present; the chosen option is still added
        evidence.append(selected_option)

        still_vague = check_for_vague(evidence)
        if not still_vague:
            # Every vague symptom has been clarified.
            session_manager.add_text_evidence(sid, evidence)
            return risk_assessment_pipeline(session_manager, sid, pid, False)
        return vague_clarification_response(still_vague[0], evidence)
    except Exception as e:
        return {"type": "ERROR", "message": str(e)}
################################################################################################################################################
def main():
    """
    Interactive console harness (testing only — the Streamlit UI does not call this).

    Lets the user create/select a patient and starts a session. It then loops over
    free-text input, or "upload" (which processes IMAGE_PATH), until "end", walking
    through any intent/vague clarifications on the console. The session is always ended,
    even if the loop exits through an exception or Ctrl-C.
    """
    def _read_int(prompt: str) -> int:
        """Prompt until the user enters a whole number."""
        while True:
            raw = input(prompt).strip()
            try:
                return int(raw)
            except ValueError:
                print("Please enter a whole number.")

    def _read_choice(count: int) -> int:
        """Prompt for a 1-based menu number in [1, count] and return it 0-based."""
        if count < 1:
            raise ValueError("No options to choose from.")
        while True:
            number = _read_int("Enter number: ")
            # Reject 0/negatives explicitly: as list indices they would silently pick
            # options counted from the end.
            if 1 <= number <= count:
                return number - 1
            print(f"Please enter a number between 1 and {count}.")

    session_manager = SessionManager(test_mode=False)
    db = session_manager.db
    print("=== LUNG CANCER DIAGNOSIS ASSISTANT ===\n")

    ### PATIENT SELECTION/CREATION
    pid = None
    while pid is None:
        action = input("Create / select / list patients? ").lower().strip()
        if action == "create":
            name = input("Name: ")
            dob = input("DOB (YYYY-MM-DD): ")
            sex = input("Sex (M/F): ")
            result = create_patient(db, name, dob, sex)
            if result["type"] == "ERROR":
                print(result["message"])
            else:
                pid = result["pid"]
        elif action == "select":
            pid = _read_int("Patient ID: ")
        else:
            print(list_all_patients(db)["patients"])
    print(f"\nUsing patient {pid}")

    ### START SESSION
    session_result = start_session(session_manager, pid)
    sid = session_result["sid"]
    print(f"Session started: {sid}")
    ### GET EXISTING EVIDENCE - PATIENT HISTORY
    print(f"Existing Evidence: {session_result['existing_evidence']}")
    last_vector_search_output = None
    last_info_array = None
    last_risk_level = None

    try:
        ### MAIN LOOP
        while True:
            user_input = input("\nEnter text, 'upload' for X-ray, or 'end': ").strip()
            if user_input.lower() == "end":
                break
            ### IMAGING PIPELINE
            if user_input.lower() == "upload":
                response = process_image_upload(session_manager, sid, pid, IMAGE_PATH)
            ### TEXT PIPELINE
            else:
                response = process_text_input(
                    session_manager, sid, pid, user_input, last_vector_search_output, last_info_array, last_risk_level
                )
            ### HANDLE CLARIFICATIONS
            while response["type"] in ("INTENT_CLARIFICATION", "VAGUE_CLARIFICATION"):
                if response["type"] == "INTENT_CLARIFICATION":
                    print(f"\nConfidence: {response['confidence']:.2f}. Please choose:")
                    for i, opt in enumerate(response["options"]):
                        print(f" {i+1}. {opt}")
                    selected = response["options"][_read_choice(len(response["options"]))]
                    response = resolve_intent_clarification(session_manager, sid, pid, selected, response["pending_input"], last_vector_search_output, last_info_array, last_risk_level)
                elif response["type"] == "VAGUE_CLARIFICATION":
                    print(f"\n{response['question']}")
                    for i, opt in enumerate(response["options"]):
                        print(f" {i+1}. {opt['label']}")
                    selected_key = response["options"][_read_choice(len(response["options"]))]["key"]
                    response = resolve_vague_symptom(session_manager, sid, pid, response["vague_symptom"], selected_key, response["pending_evidence"])
            ### FINAL OUTPUT
            if response["type"] == "RISK_ASSESSMENT_OUTPUT":
                last_vector_search_output = response.get("vector_search_output")
                print("\n\n\n\n")
                print(f"\nRisk Level: {response['risk_level']}")
                print(f"Formula:\n{response['risk_formula']}")
                print(f"Explanation: {response['explanation']}")
                print("\nPATIENT HISTORY:")
                print(f"{response['existing_evidence']}")
                print("\n Current SESSION EVIDENCE:")
                print(f"Symptoms: {response['session_evidence']['symptoms']}")
                print(f"Risk Factors:{response['session_evidence']['risk_factors']}")
                print(f"Imaging: {response['session_evidence']['imaging']}")
                print(f"Sources: {response['sources']}")
                print(f"Formatted_vector_output: {response['vector_results']}")
                last_info_array = response.get("info_array")
                last_risk_level = response.get("last_risk_level")
            elif response["type"] == "TEXT":
                print(f"\n{response['message']}")
            elif response["type"] == "ERROR":
                print(f"\nERROR: {response['message']}")
            if response.get("evidence_warning"):
                print(f"\n{response['evidence_warning']['message']}")
    finally:
        ### END SESSION
        # Always end (and persist) the session, even on an exception or Ctrl-C.
        end_session(session_manager, sid)
        print("\nSession ended and persisted to database.")
################################################################################################################################################
def run_test_main(action_param, name_param, dob_param, sex_param, pid_param, image_path_param, query_param:list, intent_choice_param:list, vague_choice_param:list, expected_risk=None, expected_symptoms=None, expected_risk_factors=None, expected_imaging=None):
    """
    Scripted evaluation harness (testing only — the Streamlit UI does not call this).

    Runs one session non-interactively and scores the final RISK_ASSESSMENT_OUTPUT.

    Args:
        action_param: "create" (create the patient, or reuse an existing record with the
            same name) or "select" (use pid_param).
        name_param, dob_param, sex_param: Patient details for "create".
        pid_param: Patient ID for "select".
        image_path_param: Optional X-ray processed before the queries ("none" or empty
            to skip).
        query_param: User messages processed in order. "end" stops early; "upload"
            processes IMAGE_PATH.
        intent_choice_param, vague_choice_param: Per-query 1-based option numbers to pick
            when query i triggers an intent/vague clarification. A blank or missing
            entry picks option 1.
        expected_risk: Expected numeric risk; a match allows an integer difference of at most 1.
        expected_symptoms, expected_risk_factors, expected_imaging: Comma-separated
            names that must all appear in the final session evidence.

    Returns:
        A metrics dict with these keys: session_success, explanation_generated, the
        *_correct flags, total_latency (text queries only, in seconds) and query_count.

    Raises:
        ValueError: if no patient can be resolved from the action parameters.
    """
    def _choice_index(choices: list, query_index: int) -> int:
        """0-based option index scripted for query `query_index` (blank/missing -> 0)."""
        # Bounds-check the index actually used (the query index), not the leftover
        # option-enumeration variable.
        raw = choices[query_index] if query_index < len(choices) else ""
        return int(raw) - 1 if raw.strip() else 0

    session_manager = SessionManager(test_mode=True)
    db = session_manager.db
    print("=== LUNG CANCER DIAGNOSIS ASSISTANT ===\n")

    ### PATIENT SELECTION/CREATION
    # The inputs are fixed, so retrying cannot change the outcome: resolve once and fail
    # loudly instead of looping forever when no patient can be determined.
    pid = None
    action = action_param.lower().strip()
    if action == "create":
        result = create_patient(db, name_param, dob_param, sex_param)
        if result["type"] == "ERROR":
            print(result["message"])
            # Patient already exists: reuse the first record with this name.
            for p in db.list_patients():
                if p[1] == name_param:
                    pid = p[0]
                    break
        else:
            pid = result["pid"]
    elif action == "select":
        pid = int(pid_param)
    else:
        print(list_all_patients(db)["patients"])
    if pid is None:
        raise ValueError(f"Could not resolve a patient for action {action_param!r}.")
    print(f"\nUsing patient {pid}")

    ### START SESSION
    session_result = start_session(session_manager, pid)
    sid = session_result["sid"]
    print(f"Session started: {sid}")
    ### GET EXISTING EVIDENCE - PATIENT HISTORY
    print(f"Existing Evidence: {session_result['existing_evidence']}")
    last_vector_search_output = None
    final_output = None
    last_info_array = None
    last_risk_level = None
    results = {
        "session_success": False,
        "explanation_generated": False,
        "risk_correct": None,
        "symptoms_correct": None,
        "risk_factors_correct": None,
        "imaging_correct": None,
        "total_latency": 0.0,
    }

    # End (and persist) the session even if any step below raises.
    try:
        ### OPTIONAL IMAGE
        if image_path_param and image_path_param.lower() != "none":
            image_response = process_image_upload(session_manager, sid, pid, image_path_param)
            assert image_response["type"] in ("RISK_ASSESSMENT_OUTPUT", "TEXT")
            if image_response["type"] == "RISK_ASSESSMENT_OUTPUT":
                final_output = image_response
                last_vector_search_output = image_response.get("vector_search_output")
                # Keep follow-up context in sync, as the query loop does for its outputs.
                last_info_array = image_response.get("info_array")
                last_risk_level = image_response.get("last_risk_level")

        ### MAIN LOOP
        i = 0
        while i < len(query_param):
            user_input = query_param[i].strip()
            if user_input.lower() == "end":
                break
            ### IMAGING PIPELINE
            if user_input.lower() == "upload":
                response = process_image_upload(session_manager, sid, pid, IMAGE_PATH)
            ### TEXT PIPELINE
            else:
                # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
                start_time = time.perf_counter()
                response = process_text_input(
                    session_manager, sid, pid, user_input, last_vector_search_output, last_info_array, last_risk_level
                )
                latency = time.perf_counter() - start_time
                print(f"Query {i+1} latency: {latency:.3f}s")
                results["total_latency"] += latency
            ### HANDLE CLARIFICATIONS
            while response["type"] in ("INTENT_CLARIFICATION", "VAGUE_CLARIFICATION"):
                if response["type"] == "INTENT_CLARIFICATION":
                    print(f"\nConfidence: {response['confidence']:.2f}. Please choose:")
                    for j, opt in enumerate(response["options"]):
                        print(f" {j+1}. {opt}")
                    selected = response["options"][_choice_index(intent_choice_param, i)]
                    response = resolve_intent_clarification(session_manager, sid, pid, selected, response["pending_input"], last_vector_search_output, last_info_array, last_risk_level)
                elif response["type"] == "VAGUE_CLARIFICATION":
                    print(f"\n{response['question']}")
                    for j, opt in enumerate(response["options"]):
                        print(f" {j+1}. {opt['label']}")
                    selected_key = response["options"][_choice_index(vague_choice_param, i)]["key"]
                    response = resolve_vague_symptom(session_manager, sid, pid, response["vague_symptom"], selected_key, response["pending_evidence"])
            ### FINAL OUTPUT
            if response["type"] == "RISK_ASSESSMENT_OUTPUT":
                final_output = response
                last_vector_search_output = response.get("vector_search_output")
                print("\n\n\n\n")
                print(f"\nRisk Level: {response['risk_level']}")
                print(f"Formula:\n{response['risk_formula']}")
                print(f"Explanation: {response['explanation']}")
                print(f"Symptoms: {response['session_evidence']['symptoms']}")
                print(f"Risk Factors:{response['session_evidence']['risk_factors']}")
                print(f"Imaging: {response['session_evidence']['imaging']}")
                print(f"Sources: {response['sources']}")
                last_info_array = response.get("info_array")
                last_risk_level = response.get("last_risk_level")
            elif response["type"] == "TEXT":
                print(f"\n{response['message']}")
            elif response["type"] == "ERROR":
                print(f"\nERROR: {response['message']}")
            if response.get("evidence_warning"):
                print(f"\n{response['evidence_warning']['message']}")
            i += 1

        ### METRICS EVALUATION
        results["query_count"] = i
        results["session_success"] = final_output is not None
        if final_output:
            # bool() also tolerates a None explanation, which len() would reject.
            results["explanation_generated"] = bool(final_output.get("explanation"))
            # RISK DEBUG
            if expected_risk:
                print("\n=== RISK DEBUG ===")
                print(f"Expected: {expected_risk} (type: {type(expected_risk)})")
                print(f"Actual: {final_output['risk_level']} (type: {type(final_output['risk_level'])})")
                results["risk_correct"] = abs(int(final_output["risk_level"]) - int(expected_risk)) <= 1
                print(f"Match: {results['risk_correct']}")
            # SYMPTOMS DEBUG
            if expected_symptoms:
                print("\n=== SYMPTOMS DEBUG ===")
                predicted = set(s.strip() for s in final_output["session_evidence"]["symptoms"])
                expected = set(s.strip() for s in expected_symptoms.split(","))
                print(f"Expected: {expected}")
                print(f"Predicted: {predicted}")
                print(f"Missing: {expected - predicted}")
                print(f"Extra: {predicted - expected}")
                results["symptoms_correct"] = expected.issubset(predicted)
                print(f"Match: {results['symptoms_correct']}")
            # RISK FACTORS DEBUG
            if expected_risk_factors:
                print("\n=== RISK FACTORS DEBUG ===")
                predicted_rf = set(s.strip() for s in final_output["session_evidence"]["risk_factors"])
                expected_rf = set(s.strip() for s in expected_risk_factors.split(","))
                print(f"Expected: {expected_rf}")
                print(f"Predicted: {predicted_rf}")
                print(f"Missing: {expected_rf - predicted_rf}")
                print(f"Extra: {predicted_rf - expected_rf}")
                results["risk_factors_correct"] = expected_rf.issubset(predicted_rf)
                print(f"Match: {results['risk_factors_correct']}")
            # IMAGING DEBUG
            if expected_imaging:
                print("\n=== IMAGING DEBUG ===")
                predicted_img = set(s.strip() for s in final_output["session_evidence"]["imaging"])
                expected_img = set(s.strip() for s in expected_imaging.split(","))
                print(f"Expected: {expected_img}")
                print(f"Predicted: {predicted_img}")
                print(f"Missing: {expected_img - predicted_img}")
                print(f"Extra: {predicted_img - expected_img}")
                results["imaging_correct"] = expected_img.issubset(predicted_img)
                print(f"Match: {results['imaging_correct']}")
    finally:
        ### END SESSION
        end_session(session_manager, sid)
        print("\nSession ended and persisted to database.")
    print(f"Eval Results: {results}")
    return results
# Run the interactive console harness only when this file is executed directly.
if __name__ == "__main__":
    main()