# model/main2.py
# Uploaded by swarit222 ("Upload 3 files"), commit 34e8e70 (verified), 2.2 kB.
import pandas as pd
# Multipliers that convert an age unit (lower-case, singular form) to years.
# NOTE(review): assumes ages are stored ClinicalTrials.gov-style, e.g.
# "18 Years" or "6 Months" -- confirm against the dataset.
_AGE_UNIT_TO_YEARS = {
    "year": 1.0,
    "month": 1.0 / 12.0,
    "week": 7.0 / 365.25,
    "day": 1.0 / 365.25,
    "hour": 1.0 / (365.25 * 24.0),
    "minute": 1.0 / (365.25 * 24.0 * 60.0),
}


def _parse_age_years(age_value):
    """Convert an age cell such as ``"18 Years"`` or ``"6 Months"`` to years.

    The first whitespace-separated token is parsed as a number, so ``"18"``
    and ``"18.0"`` both work. An optional second token naming a unit scales
    the number to years. A missing or unrecognised unit is treated as years.

    Returns:
        The age in (possibly fractional) years, or ``None`` when the cell is
        null/blank or its leading token is not a number (e.g. ``"N/A"``).
    """
    if pd.isnull(age_value):
        return None
    parts = str(age_value).split()
    if not parts:
        return None
    try:
        amount = float(parts[0])
    except ValueError:
        return None
    if len(parts) < 2:
        return amount
    # "Years" -> "year", "Months" -> "month", ...; unknown units fall back to
    # years, which is how the leading number used to be interpreted.
    factor = _AGE_UNIT_TO_YEARS.get(parts[1].lower().rstrip("s"), 1.0)
    return amount * factor


def _normalize_keywords(user_keywords):
    """Return stripped, lower-cased, non-empty keywords from the user input.

    A string is split on commas. A list/tuple/set is used item by item, with
    each item converted with ``str``. Any other input yields no keywords.
    """
    if isinstance(user_keywords, str):
        raw_items = user_keywords.split(",")
    elif isinstance(user_keywords, (list, tuple, set, frozenset)):
        raw_items = user_keywords
    else:
        return []
    return [str(item).strip().lower() for item in raw_items if str(item).strip()]


def search_trials(user_age, user_sex, user_state, user_keywords, csv_path="clinical_trials_cleaned_merged.csv"):
    """Search recruiting US clinical trials matching a user's demographics.

    Args:
        user_age: The user's age in whole years. Anything ``int(float(x))``
            accepts works, e.g. ``45``, ``"45"`` or ``"45.0"``. A trial matches
            when MinimumAge <= age <= MaximumAge (both ends inclusive).
        user_sex: Compared case-insensitively with the ``Sex`` column. Trials
            whose ``Sex`` is ``"all"`` always match.
        user_state: Compared case-insensitively, ignoring surrounding
            whitespace, with the ``LocationState`` column.
        user_keywords: Optional keywords, given as a comma-separated string or
            a list/tuple/set. A trial is kept if ANY keyword occurs as a
            substring of the lower-cased text of its non-null cells, joined
            by spaces. Empty input or any other type (e.g. ``None``) disables
            keyword filtering.
        csv_path: Path of the CSV dataset to load.

    Returns:
        A DataFrame with ALL of the dataset's original columns and only the
        matching rows, re-indexed from 0.
    """
    df = pd.read_csv(csv_path)

    # Rows missing any of these cannot be evaluated by the filters below.
    # NOTE(review): a blank MaximumAge may mean "no upper limit" in the source
    # data; such trials are dropped here -- confirm this is intended.
    df = df.dropna(subset=["MinimumAge", "MaximumAge", "Sex", "OverallStatus"])

    # Keep only US & recruiting trials.
    df = df[df["LocationCountry"] == "United States"]
    df = df[df["OverallStatus"].str.lower() == "recruiting"]

    age = int(float(user_age))
    # Ages are kept as local Series rather than new columns: this avoids
    # mutating a filtered slice and keeps them out of the keyword search.
    # Unparseable ages become NaN, so those trials never pass the age check.
    min_age = df["MinimumAge"].map(_parse_age_years).astype(float)
    max_age = df["MaximumAge"].map(_parse_age_years).astype(float)

    sex_mask = df["Sex"].str.strip().str.lower().isin([str(user_sex).strip().lower(), "all"])
    age_mask = (min_age <= age) & (max_age >= age)
    state_mask = df["LocationState"].str.strip().str.lower() == str(user_state).strip().lower()
    df = df[sex_mask & age_mask & state_mask]

    # The row-wise text scan is the expensive step, so run it last, only on
    # rows that survived the cheap vectorised filters.
    keywords = _normalize_keywords(user_keywords)
    if keywords and not df.empty:
        def _matches_any_keyword(row):
            text = " ".join(str(value).lower() for value in row.values if pd.notnull(value))
            return any(keyword in text for keyword in keywords)

        df = df[df.apply(_matches_any_keyword, axis=1)]

    return df.reset_index(drop=True)