import sys
import os
import time

# Make the sibling packages imported below resolvable when this file is run directly.
sys.path.insert(0, os.path.dirname(__file__))

from session_management.session_manager import SessionManager
from intent_classification.intent_router_ml import route_intent
from imaging.image_processing import cnn_model
from symptom_extraction.symptom_extractor import extract_symptoms
from semantic_search.retrieve import calculate_risk, get_final_symptom_list, context_and_name, get_sources, get_dynamic_threshold
from session_management.ambiguity import check_for_vague, VAGUE_SYMPTOMS
from explanation_generation.generate_explanation import gen_explanation
from explanation_generation.follow_up_handler import handle_follow_up_question

IMAGE_PATH = "demo_images/img6.jpg"
MIN_EVIDENCE = 3

################################################################################################################################################
### INTENT OPTIONS

# Maps the button labels offered in an INTENT_CLARIFICATION response to the
# intent identifiers dispatched on by resolve_intent_clarification.
OPTION_TO_INTENT = {
    "Add new evidence": "PATIENT_EVIDENCE_QUERY",
    "Ask for explanation": "FOLLOW_UP_EXPLANATION",
    "Request source": "SOURCE_REQUEST",
    "General help": "HELP_OR_OTHER",
}

### HELPER FUNCTIONS

def vague_clarification_response(vague_symptom: str, current_evidence: list) -> dict:
    """
    Build a VAGUE_CLARIFICATION response for one vague symptom.

    Looks up the clarifying question and its options in VAGUE_SYMPTOMS (keyed by
    the lower-cased symptom). Each option becomes a {"key", "label"} dict. A copy
    of the evidence extracted so far is attached as "pending_evidence" so the
    caller can resume once the user picks an option.
    """
    entry = VAGUE_SYMPTOMS[vague_symptom.lower()]
    choice_list = []
    for option_key, option_label in entry["options"].items():
        choice_list.append({"key": option_key, "label": option_label})
    response = {
        "type": "VAGUE_CLARIFICATION",
        "vague_symptom": vague_symptom,
        "question": entry["question"],
        "options": choice_list,
        "pending_evidence": list(current_evidence),
    }
    return response


def calculate_risk_with_formula(vector_search_output: list) -> tuple[str, str]:
    """Score the vector-search hits with calculate_risk and pair the result with the formula text."""
    return calculate_risk(vector_search_output), write_risk_formula()


def help(text: str) -> str:
    """
    Return the canned help message.

    `text` is accepted so all intent handlers share a call shape; it is ignored.
    NOTE(review): this name shadows the builtin `help` at module level.
    """
    return "Go to the app's help page for more info."
def check_sufficient_evidence(session_manager: SessionManager, sid: int) -> dict | None:
    """
    Warn when session `sid` holds fewer than MIN_EVIDENCE validated evidence items.

    Sums the validated symptoms, risk factors and imaging detections recorded by
    the session manager. Returns a LOW_EVIDENCE_WARNING dict when the total is
    below MIN_EVIDENCE; otherwise returns None.
    """
    validated = session_manager.get_validated_evidence(sid)
    total = (
        len(validated.get("symptoms", []))
        + len(validated.get("risk_factors", []))
        + len(validated.get("imaging_detections", []))
    )

    if total < MIN_EVIDENCE:
        return {
            "type": "LOW_EVIDENCE_WARNING",
            "total_evidence": total,
            "min_required": MIN_EVIDENCE,
            "message": (
                f"Only {total} piece(s) of evidence recorded, which is below the "
                f"recommended minimum of {MIN_EVIDENCE}. With limited information, system results "
                f"cannot be guaranteed to be reliable. If the patient has additional symptoms, "
                f"risk factors, or a chest X-ray available, consider adding these."
            )
        }
    return None


def risk_label(risk_level) -> str:
    """
    Convert a risk level into a display label.

    Numeric levels map to "High" (>= 7), "Medium" (>= 4) or "Low".
    Non-numeric truthy values are returned as str(). Falsy non-numeric values
    (None, "") become "Unknown".
    """
    if isinstance(risk_level, (int, float)):
        if risk_level >= 7:
            return "High"
        elif risk_level >= 4:
            return "Medium"
        else:
            return "Low"
    return str(risk_level) if risk_level else "Unknown"


def write_risk_formula() -> str:
    """Return the human-readable risk-scoring formula shown in the UI transparency box."""
    risk_formula = """
    similarity = 1 - normalised_distance
    total_weight += risk_score / 10
    weighted_score += ((0.25 * similarity) + (0.3 * (risk/10)) + (0.45 * (rarity/10)))
    final_score = (weighted_score / total_weight) * 10
    """
    return risk_formula

################################################################################################################################################
### PATIENT SELECTION/CREATION

def create_patient(db, name: str, dob: str, sex: str) -> dict:
    """
    Create a new patient record.

    Returns PATIENT_CREATED with the new pid. Returns an ERROR dict when
    db.create_patient returns None (treated as "patient already exists").
    """
    pid = db.create_patient(name, dob, sex)
    if pid is None:
        return {
            "type": "ERROR",
            "message": "Patient already exists."
        }
    return {
        "type": "PATIENT_CREATED",
        "pid": pid
    }


def _row_pid(row):
    """
    Extract the patient id from one entry of db.list_patients().

    Rows are indexable with the id at position 0 (run_test_main reads p[0]/p[1]).
    Bare scalar ids are also accepted in case list_patients ever returns them.
    """
    if isinstance(row, (str, bytes, int)):
        return row
    return row[0]


def select_patient(db, pid: int) -> dict:
    """
    Validate that patient `pid` exists; if so, select them.

    Returns PATIENT_SELECTED with the given pid, or an ERROR dict when no
    listed patient has a matching id. Ids are compared as strings so an int
    pid matches string ids and vice versa.
    """
    wanted = str(pid)
    # Compare against the row's id column, not the whole row: str(row) of a
    # tuple like (1, "name", ...) would never equal "1".
    if any(str(_row_pid(p)) == wanted for p in db.list_patients()):
        return {
            "type": "PATIENT_SELECTED",
            "pid": pid
        }
    return {
        "type": "ERROR",
        "message": f"Patient ID {pid} not found."
    }


def list_all_patients(db) -> dict:
    """Return a PATIENT_LIST dict containing every patient row from the db."""
    return {
        "type": "PATIENT_LIST",
        "patients": db.list_patients()
    }

################################################################################################################################################
### SESSION MANAGEMENT

def start_session(session_manager: SessionManager, pid: int) -> dict:
    """
    Start a new session for patient `pid`.

    The patient's stored evidence (symptoms, risk factors, imaging) is
    concatenated and seeded into the new session via
    add_patient_records_evidence. The structured history is returned to the
    caller for display.
    """
    sid = session_manager.start_session(pid)

    existing_evidence = session_manager.get_patient_evidence(pid)
    existing_evidence_list = (
        existing_evidence["symptoms"]
        + existing_evidence["risk_factors"]
        + existing_evidence["imaging"]
    )
    session_manager.add_patient_records_evidence(sid, existing_evidence_list)

    return {
        "type": "SESSION_STARTED",
        "sid": sid,
        "existing_evidence": existing_evidence
    }


def end_session(session_manager: SessionManager, sid: int) -> dict:
    """End session `sid` through SessionManager.end_session (main() reports this as persisting it)."""
    session_manager.end_session(sid)
    return {
        "type": "SESSION_ENDED"
    }

################################################################################################################################################
### IMAGE PIPELINE

def process_image_upload(session_manager: SessionManager, sid: int, pid: int, image_path: str) -> dict:
    """
    Run the imaging pipeline on the chest X-ray at `image_path`.

    The path is stored on the session and `cnn_model` is run on it. If any
    detections come back, they are recorded as imaging evidence and the full
    risk assessment is re-run. Any exception is returned as an ERROR dict.
    """
    try:
        session_manager.add_xray(sid, image_path)
        imaging_detections, cnn_outputs = cnn_model(image_path)
        print(f"\n\nIMAGING OUTPUT:\n{imaging_detections}\n{cnn_outputs}")

        if not imaging_detections:
            return {
                "type": "TEXT",
                "message": "X-ray processed but no significant findings detected."
            }

        session_manager.add_imaging_evidence(sid, imaging_detections)
        return risk_assessment_pipeline(session_manager, sid, pid, imaging_input=True)
    except Exception as e:
        return {"type": "ERROR", "message": str(e)}

################################################################################################################################################
### TEXT PIPELINE

def process_text_input(session_manager: SessionManager, sid: int, pid: int, user_input: str,
                       last_vector_search_output: list | None = None, last_info_array=None,
                       last_risk_level=None) -> dict:
    """
    Process a free-text user query.

    The text is routed to an intent. A low-confidence route returns
    INTENT_CLARIFICATION. Otherwise the result is one of:
      - an answered follow-up question,
      - the sources from the previous assessment,
      - the help text,
      - new evidence extraction followed by a fresh risk assessment. A vague
        symptom instead yields VAGUE_CLARIFICATION.
    The last_* arguments carry state from the previous RISK_ASSESSMENT_OUTPUT.
    Any exception is returned as an ERROR dict.
    """
    try:
        ### INTENT ROUTING
        routing_result = route_intent(user_input, pid, sid)

        if routing_result["status"] == "NEEDS_CLARIFICATION":
            return {
                "type": "INTENT_CLARIFICATION",
                "confidence": routing_result["confidence"],
                "options": routing_result["options"],
                "pending_input": user_input
            }

        intent = routing_result["intent"]

        if intent == "FOLLOW_UP_EXPLANATION":
            return {
                "type": "TEXT",
                "message": handle_follow_up_question(user_input, last_info_array, last_risk_level)
            }

        if intent == "SOURCE_REQUEST":
            if last_vector_search_output:
                return {
                    "type": "TEXT",
                    "message": get_sources(last_vector_search_output)
                }
            return {
                "type": "TEXT",
                "message": "There has been no activity in this session to get sources for. Please provide symptoms, risk factors, or upload an X-ray first before requesting for sources."
            }

        if intent == "HELP_OR_OTHER":
            return {
                "type": "TEXT",
                "message": help(user_input)
            }

        if intent == "PATIENT_EVIDENCE_QUERY":
            extracted_evidence = extract_symptoms(user_input)
            print(f"\n\nEXTRACTED EVIDENCE: {extracted_evidence}")

            vague = check_for_vague(extracted_evidence)
            if vague:
                # Resolve one vague term at a time; the rest are re-checked later.
                return vague_clarification_response(vague[0], extracted_evidence)

            session_manager.add_text_evidence(sid, extracted_evidence)
            return risk_assessment_pipeline(session_manager, sid, pid, False)

        return {
            "type": "TEXT",
            "message": help(user_input)
        }
    except Exception as e:
        return {
            "type": "ERROR",
            "message": str(e)
        }

################################################################################################################################################
### RISK ASSESSMENT PIPLINE

def _summarise_vector_output(vector_search_output: list) -> tuple[list, list, list, list]:
    """
    Format vector-search hits for the UI and bucket factor names by category.

    Each item is indexed as (factor_name, normalised_distance, metadata_dict).
    Returns (formatted_results sorted by similarity descending, symptoms,
    risk_factors, imaging).
    """
    validated_symptoms = []
    validated_risk_factors = []
    validated_imaging = []
    formatted_vector_results = []

    for item in vector_search_output:
        factor = item[0]
        distance = item[1]  # normalised distance
        metadata = item[2]

        category = metadata.get("category", "").lower()
        context = metadata.get("context", "")
        source = metadata.get("source", "")
        similarity_score = float(1 - distance)
        print(f"similarity score: {similarity_score}")
        risk_val = metadata.get("risk score", "N/A")
        rarity_val = metadata.get("rarity", "N/A")

        formatted_vector_results.append({
            "name": factor,
            "category": category,
            "score": similarity_score,
            "context": context,
            "source": source,
            "risk_score": risk_val,
            "rarity": rarity_val,
            "metadata": {
                "risk_score": risk_val,
                "rarity": rarity_val,
            }
        })

        # Exact match for symptoms; substring match for risk/imaging categories.
        if category == "symptom":
            validated_symptoms.append(factor)
        elif "risk" in category:
            validated_risk_factors.append(factor)
        elif "imaging" in category:
            validated_imaging.append(factor)

    # Sort by score descending so the UI can easily slice the top-k.
    formatted_vector_results.sort(key=lambda x: x["score"], reverse=True)
    return formatted_vector_results, validated_symptoms, validated_risk_factors, validated_imaging


def risk_assessment_pipeline(session_manager: SessionManager, sid: int, pid: int, imaging_input: bool) -> dict:
    """
    Run vector search, risk scoring and explanation generation over all evidence
    accumulated in the current session so far.

    Args:
        session_manager: Session state store.
        sid: Current session id.
        pid: Current patient id.
        imaging_input: True when triggered by an X-ray upload. Not currently used
            by the pipeline (an imaging-only filter was disabled); kept for
            interface compatibility.

    Returns:
        A RISK_ASSESSMENT_OUTPUT dict. "vector_search_output", "info_array" and
        "last_risk_level" are echoed back so follow-up and source requests can
        use them on the next turn.
    """
    ### COMBINE EVIDENCE
    combined_evidence = list(session_manager.get_combined_current_loop_evidence(sid))
    print(f"\n\nCOMBINED EVIDENCE: {combined_evidence}")
    combined_evidence = session_manager.get_combined_session_evidence(sid, combined_evidence)
    print(f"\n\nCOMBINED EVIDENCE: {combined_evidence}")

    ### VECTOR SEARCH
    dynamic_threshold = get_dynamic_threshold(len(combined_evidence))
    vector_search_output = get_final_symptom_list(dynamic_threshold, combined_evidence)
    info_array = context_and_name(vector_search_output)

    ### LOG VALIDATED EVIDENCE TO SESSION STATE
    (formatted_vector_results, validated_symptoms,
     validated_risk_factors, validated_imaging) = _summarise_vector_output(vector_search_output)
    session_manager.add_validated_evidence(sid,
                                           symptoms=validated_symptoms,
                                           risk_factors=validated_risk_factors,
                                           imaging_detections=validated_imaging)

    ### RISK SCORING
    risk_level, risk_formula = calculate_risk_with_formula(vector_search_output)
    session_manager.assign_risk(sid, risk_level)

    ### GENERATE EXPLANATION
    generated_explanation = gen_explanation(info_array, risk_level)

    ### SOURCES
    sources = get_sources(vector_search_output)

    ### MANAGE SESSION EVIDENCE
    existing_evidence = session_manager.get_patient_evidence(pid)
    session_only_evidence = session_manager.get_session_only_evidence(sid, pid)
    print(f"existing evidence:\n{existing_evidence}")
    print(f"validated evidence:\n{session_manager.get_validated_evidence(sid)}")
    print(f"session only evidence:\n{session_only_evidence}")

    ### MIN EVIDENCE CHECK
    evidence_warning = check_sufficient_evidence(session_manager, sid)

    # This turn's evidence has been merged into the session; reset the per-turn buffer.
    session_manager.clear_current_loop_evidence(sid)

    ### FINAL OUTPUT - RISK_ASSESSMENT_OUTPUT
    return {
        "type": "RISK_ASSESSMENT_OUTPUT",
        "risk_level": risk_level,
        "risk_label": risk_label(risk_level),

        # Table 1 — all evidence split by source (history vs this session)
        "existing_evidence": existing_evidence,
        "session_evidence": {
            "symptoms": session_only_evidence.get("symptoms", []),
            "risk_factors": session_only_evidence.get("risk_factors", []),
            "imaging": session_only_evidence.get("imaging", []),
        },
        "info_array": info_array,
        "last_risk_level": risk_level,

        # Table 2 — vector search results (sorted by score, UI slices top-k)
        "vector_results": formatted_vector_results,

        # Transparency box
        "risk_formula": risk_formula,

        # Narrative explanation
        "explanation": generated_explanation,

        # Sources list
        "sources": list(sources),

        # Pass back raw output so SOURCE_REQUEST can use it next turn
        "vector_search_output": vector_search_output,

        # None if sufficient, dict if not — lets the UI warn that the assessment
        # was done with scarce evidence.
        "evidence_warning": evidence_warning,
    }

################################################################################################################################################
### CLARIFICATIONS

def resolve_intent_clarification(session_manager: SessionManager, sid: int, pid: int, selected_option: str,
                                 text_input: str, last_vector_search_output: list | None = None,
                                 last_info_array=None, last_risk_level=None) -> dict:
    """
    Resolve an intent clarification.

    Called after the user picks an option from INTENT_CLARIFICATION buttons.
    `selected_option` must be a key of OPTION_TO_INTENT. `text_input` is the
    original (pending) user text. Like the other pipeline entry points, any
    exception raised while handling the intent is returned as an ERROR dict.
    """
    intent = OPTION_TO_INTENT.get(selected_option)
    if intent is None:
        return {
            "type": "ERROR",
            "message": f"Unknown option: {selected_option}"
        }

    try:
        if intent == "FOLLOW_UP_EXPLANATION":
            return {
                "type": "TEXT",
                "message": handle_follow_up_question(text_input, last_info_array, last_risk_level)
            }

        if intent == "SOURCE_REQUEST":
            if last_vector_search_output:
                return {
                    "type": "TEXT",
                    "message": get_sources(last_vector_search_output)
                }
            return {
                "type": "TEXT",
                "message": "There has been no activity in this session to get sources for."
            }

        if intent == "HELP_OR_OTHER":
            return {
                "type": "TEXT",
                "message": help(text_input)
            }

        if intent == "PATIENT_EVIDENCE_QUERY":
            extracted_evidence = extract_symptoms(text_input)
            vague = check_for_vague(extracted_evidence)
            if vague:
                return vague_clarification_response(vague[0], extracted_evidence)
            session_manager.add_text_evidence(sid, extracted_evidence)
            return risk_assessment_pipeline(session_manager, sid, pid, False)

        return {
            "type": "TEXT",
            "message": help(text_input)
        }
    except Exception as e:
        # Match process_text_input / resolve_vague_symptom: surface failures as
        # an ERROR response instead of crashing the caller's loop.
        return {
            "type": "ERROR",
            "message": str(e)
        }


def resolve_vague_symptom(session_manager: SessionManager, sid: int, pid: int, vague_symptom: str,
                          selected_option: str, extracted_evidence: list) -> dict:
    """
    Resolve a vague-evidence clarification.

    Called after the user selects an option from a VAGUE_CLARIFICATION prompt.
    The vague term is replaced by the selected option key. If another vague
    term remains, a new VAGUE_CLARIFICATION is returned. Otherwise the
    evidence is recorded and the risk assessment is run. Any exception is
    returned as an ERROR dict.
    """
    try:
        updated_evidence = list(extracted_evidence)
        if vague_symptom in updated_evidence:
            updated_evidence.remove(vague_symptom)
        updated_evidence.append(selected_option)

        vague = check_for_vague(updated_evidence)
        if vague:
            return vague_clarification_response(vague[0], updated_evidence)

        # All vague symptoms resolved
        session_manager.add_text_evidence(sid, updated_evidence)
        return risk_assessment_pipeline(session_manager, sid, pid, False)
    except Exception as e:
        return {
            "type": "ERROR",
            "message": str(e)
        }

################################################################################################################################################

def _prompt_choice(num_options: int) -> int:
    """
    Ask on stdin for a 1-based choice among `num_options` options; return it 0-based.

    Re-prompts on non-numeric or out-of-range input instead of crashing, or
    wrapping around via negative indexing.
    """
    if num_options <= 0:
        raise ValueError("No options available to choose from.")
    while True:
        raw = input("Enter number: ")
        try:
            choice = int(raw) - 1
        except ValueError:
            choice = -1
        if 0 <= choice < num_options:
            return choice
        print(f"Please enter a number between 1 and {num_options}.")


def main():
    """
    Interactive console driver for manual testing only — the Streamlit UI does not call this.
    """
    session_manager = SessionManager(test_mode=False)
    db = session_manager.db

    print("=== LUNG CANCER DIAGNOSIS ASSISTANT ===\n")

    ### PATIENT SELECTION/CREATION
    pid = None
    while pid is None:
        action = input("Create / select / list patients? ").lower().strip()

        if action == "create":
            name = input("Name: ")
            dob = input("DOB (YYYY-MM-DD): ")
            sex = input("Sex (M/F): ")
            result = create_patient(db, name, dob, sex)
            if result["type"] == "ERROR":
                print(result["message"])
            else:
                pid = result["pid"]

        elif action == "select":
            raw_pid = input("Patient ID: ").strip()
            try:
                candidate = int(raw_pid)
            except ValueError:
                print(f"Invalid patient ID: {raw_pid!r}")
                continue
            # Validate against the DB rather than trusting any integer typed in.
            selection = select_patient(db, candidate)
            if selection["type"] == "ERROR":
                print(selection["message"])
            else:
                pid = selection["pid"]

        else:
            print(list_all_patients(db)["patients"])

    print(f"\nUsing patient {pid}")

    ### START SESSION
    session_result = start_session(session_manager, pid)
    sid = session_result["sid"]
    print(f"Session started: {sid}")

    ### GET EXISTING EVIDENCE - PATIENT HISTORY
    print(f"Existing Evidence: {session_result['existing_evidence']}")

    last_vector_search_output = None
    last_info_array = None
    last_risk_level = None

    ### MAIN LOOP
    while True:
        user_input = input("\nEnter text, 'upload' for X-ray, or 'end': ").strip()

        if user_input.lower() == "end":
            break

        ### IMAGING PIPELINE
        if user_input.lower() == "upload":
            response = process_image_upload(session_manager, sid, pid, IMAGE_PATH)

        ### TEXT PIPELINE
        else:
            response = process_text_input(
                session_manager, sid, pid, user_input,
                last_vector_search_output, last_info_array, last_risk_level
            )

        ### HANDLE CLARIFICATIONS
        while response["type"] in ("INTENT_CLARIFICATION", "VAGUE_CLARIFICATION"):
            if response["type"] == "INTENT_CLARIFICATION":
                print(f"\nConfidence: {response['confidence']:.2f}. Please choose:")
                for i, opt in enumerate(response["options"]):
                    print(f"  {i+1}. {opt}")
                choice = _prompt_choice(len(response["options"]))
                selected = response["options"][choice]
                response = resolve_intent_clarification(session_manager, sid, pid, selected,
                                                        response["pending_input"],
                                                        last_vector_search_output,
                                                        last_info_array, last_risk_level)

            elif response["type"] == "VAGUE_CLARIFICATION":
                print(f"\n{response['question']}")
                for i, opt in enumerate(response["options"]):
                    print(f"  {i+1}. {opt['label']}")
                choice = _prompt_choice(len(response["options"]))
                selected_key = response["options"][choice]["key"]
                response = resolve_vague_symptom(session_manager, sid, pid, response["vague_symptom"],
                                                 selected_key, response["pending_evidence"])

        ### FINAL OUTPUT
        if response["type"] == "RISK_ASSESSMENT_OUTPUT":
            last_vector_search_output = response.get("vector_search_output")
            print("\n\n\n\n")
            print(f"\nRisk Level: {response['risk_level']}")
            print(f"Formula:\n{response['risk_formula']}")
            print(f"Explanation: {response['explanation']}")
            print(f"\nPATIENT HISTORY:")
            print(f"{response['existing_evidence']}")
            print(f"\n Current SESSION EVIDENCE:")
            print(f"Symptoms: {response['session_evidence']['symptoms']}")
            print(f"Risk Factors:{response['session_evidence']['risk_factors']}")
            print(f"Imaging: {response['session_evidence']['imaging']}")
            print(f"Sources: {response['sources']}")
            print(f"Formatted_vector_output: {response['vector_results']}")
            last_info_array = response.get("info_array")
            last_risk_level = response.get("last_risk_level")

        elif response["type"] == "TEXT":
            print(f"\n{response['message']}")

        elif response["type"] == "ERROR":
            print(f"\nERROR: {response['message']}")

        if response.get("evidence_warning"):
            print(f"\n{response['evidence_warning']['message']}")

    ### END SESSION
    end_session(session_manager, sid)
    print("\nSession ended and persisted to database.")

################################################################################################################################################

def _scripted_choice(choices: list, index: int) -> int:
    """
    Return the 0-based option index scripted for query number `index`.

    Missing, None or blank entries default to the first option (index 0).
    """
    choices = choices or []
    raw = choices[index] if index < len(choices) else ""
    raw = "" if raw is None else str(raw)
    return int(raw) - 1 if raw.strip() else 0


def _debug_expected_match(title: str, predicted_items, expected_csv: str) -> bool:
    """
    Compare predicted evidence names to a comma-separated expected list.

    Prints a debug report and returns True when every expected item was predicted.
    """
    print(f"\n=== {title} DEBUG ===")
    predicted = set(s.strip() for s in predicted_items)
    expected = set(s.strip() for s in expected_csv.split(","))
    print(f"Expected:  {expected}")
    print(f"Predicted: {predicted}")
    print(f"Missing:   {expected - predicted}")
    print(f"Extra:     {predicted - expected}")
    matched = expected.issubset(predicted)
    print(f"Match: {matched}")
    return matched


def run_test_main(action_param, name_param, dob_param, sex_param, pid_param, image_path_param,
                  query_param: list, intent_choice_param: list, vague_choice_param: list,
                  expected_risk=None, expected_symptoms=None, expected_risk_factors=None,
                  expected_imaging=None):
    """
    Scripted, non-interactive run of the pipeline for evaluation — the Streamlit UI does not call this.

    Clarification answers come from intent_choice_param / vague_choice_param,
    indexed by query position. Expected values are compared against the last
    RISK_ASSESSMENT_OUTPUT.

    Raises:
        ValueError: if no patient can be resolved from the action parameters
            (previously this looped forever).
    """
    session_manager = SessionManager(test_mode=True)
    db = session_manager.db

    print("=== LUNG CANCER DIAGNOSIS ASSISTANT ===\n")

    ### PATIENT SELECTION/CREATION
    action = action_param.lower().strip()
    pid = None
    if action == "create":
        result = create_patient(db, name_param, dob_param, sex_param)
        if result["type"] == "ERROR":
            print(result["message"])
            # Patient already exists: reuse the existing record with the same name.
            for p in db.list_patients():
                if p[1] == name_param:
                    pid = p[0]
                    break
        else:
            pid = result["pid"]
    elif action == "select":
        pid = int(pid_param)
    else:
        print(list_all_patients(db)["patients"])

    if pid is None:
        # The inputs never change between attempts, so retrying would spin forever.
        raise ValueError(f"Could not resolve a patient for action {action_param!r}.")

    print(f"\nUsing patient {pid}")

    ### START SESSION
    session_result = start_session(session_manager, pid)
    sid = session_result["sid"]
    print(f"Session started: {sid}")

    ### GET EXISTING EVIDENCE - PATIENT HISTORY
    print(f"Existing Evidence: {session_result['existing_evidence']}")

    last_vector_search_output = None
    final_output = None
    last_info_array = None
    last_risk_level = None

    results = {
        "session_success": False,
        "explanation_generated": False,
        "risk_correct": None,
        "symptoms_correct": None,
        "risk_factors_correct": None,
        "imaging_correct": None,
        "total_latency": 0.0,
    }

    ### OPTIONAL IMAGE
    if image_path_param and image_path_param.lower() != "none":
        image_response = process_image_upload(session_manager, sid, pid, image_path_param)
        assert image_response["type"] in ("RISK_ASSESSMENT_OUTPUT", "TEXT")
        if image_response["type"] == "RISK_ASSESSMENT_OUTPUT":
            final_output = image_response
            last_vector_search_output = image_response.get("vector_search_output")

    ### MAIN LOOP
    i = 0
    while i < len(query_param):
        user_input = query_param[i].strip()

        if user_input.lower() == "end":
            break

        ### IMAGING PIPELINE
        if user_input.lower() == "upload":
            response = process_image_upload(session_manager, sid, pid, IMAGE_PATH)

        ### TEXT PIPELINE
        else:
            start_time = time.time()
            response = process_text_input(
                session_manager, sid, pid, user_input,
                last_vector_search_output, last_info_array, last_risk_level
            )
            latency = time.time() - start_time
            print(f"Query {i+1} latency: {latency:.3f}s")
            results["total_latency"] += latency

        ### HANDLE CLARIFICATIONS
        while response["type"] in ("INTENT_CLARIFICATION", "VAGUE_CLARIFICATION"):
            if response["type"] == "INTENT_CLARIFICATION":
                print(f"\nConfidence: {response['confidence']:.2f}. Please choose:")
                for j, opt in enumerate(response["options"]):
                    print(f"  {j+1}. {opt}")
                # Index by query position `i` (the old code bounds-checked with
                # the leftover loop variable `j`).
                choice = _scripted_choice(intent_choice_param, i)
                selected = response["options"][choice]
                response = resolve_intent_clarification(session_manager, sid, pid, selected,
                                                        response["pending_input"],
                                                        last_vector_search_output,
                                                        last_info_array, last_risk_level)

            elif response["type"] == "VAGUE_CLARIFICATION":
                print(f"\n{response['question']}")
                for j, opt in enumerate(response["options"]):
                    print(f"  {j+1}. {opt['label']}")
                choice = _scripted_choice(vague_choice_param, i)
                selected_key = response["options"][choice]["key"]
                response = resolve_vague_symptom(session_manager, sid, pid, response["vague_symptom"],
                                                 selected_key, response["pending_evidence"])

        ### FINAL OUTPUT
        if response["type"] == "RISK_ASSESSMENT_OUTPUT":
            final_output = response
            last_vector_search_output = response.get("vector_search_output")
            print("\n\n\n\n")
            print(f"\nRisk Level: {response['risk_level']}")
            print(f"Formula:\n{response['risk_formula']}")
            print(f"Explanation: {response['explanation']}")
            print(f"Symptoms: {response['session_evidence']['symptoms']}")
            print(f"Risk Factors:{response['session_evidence']['risk_factors']}")
            print(f"Imaging: {response['session_evidence']['imaging']}")
            print(f"Sources: {response['sources']}")
            last_info_array = response.get("info_array")
            last_risk_level = response.get("last_risk_level")

        elif response["type"] == "TEXT":
            print(f"\n{response['message']}")

        elif response["type"] == "ERROR":
            print(f"\nERROR: {response['message']}")

        if response.get("evidence_warning"):
            print(f"\n{response['evidence_warning']['message']}")

        i += 1

    ### METRICS EVALUATION
    results["query_count"] = i
    results["session_success"] = final_output is not None

    if final_output:
        results["explanation_generated"] = len(final_output.get("explanation", "")) > 0

        # RISK DEBUG
        if expected_risk:
            print(f"\n=== RISK DEBUG ===")
            print(f"Expected: {expected_risk} (type: {type(expected_risk)})")
            print(f"Actual:   {final_output['risk_level']} (type: {type(final_output['risk_level'])})")
            try:
                results["risk_correct"] = abs(int(final_output["risk_level"]) - int(expected_risk)) <= 1
            except (TypeError, ValueError):
                # Non-numeric levels (e.g. "High") fall back to a case-insensitive label match.
                results["risk_correct"] = (
                    str(final_output["risk_level"]).strip().lower() == str(expected_risk).strip().lower()
                )
            print(f"Match: {results['risk_correct']}")

        # SYMPTOMS DEBUG
        if expected_symptoms:
            results["symptoms_correct"] = _debug_expected_match(
                "SYMPTOMS", final_output["session_evidence"]["symptoms"], expected_symptoms)

        # RISK FACTORS DEBUG
        if expected_risk_factors:
            results["risk_factors_correct"] = _debug_expected_match(
                "RISK FACTORS", final_output["session_evidence"]["risk_factors"], expected_risk_factors)

        # IMAGING DEBUG
        if expected_imaging:
            results["imaging_correct"] = _debug_expected_match(
                "IMAGING", final_output["session_evidence"]["imaging"], expected_imaging)

    ### END SESSION
    end_session(session_manager, sid)
    print("\nSession ended and persisted to database.")
    print(f"Eval Results: {results}")
    return results


if __name__ == "__main__":
    main()