# model / main2.py
# swarit222's picture
# Update main2.py
# 02e2bed verified
import pandas as pd
import re
from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np
# Read the trial dataset a single time, at import, into a module-level
# DataFrame that the rest of the module works from.
print("Loading and preprocessing dataset...")
df_full = pd.read_csv(filepath_or_buffer="clinical_trials_cleaned_merged.csv")
# How many of each non-year unit make up one year; used to express every
# age bound in years so it can be compared against a user's age in years.
_AGE_UNITS_PER_YEAR = {
    "month": 12,
    "week": 365.25 / 7,
    "day": 365.25,
    "hour": 365.25 * 24,
    "minute": 365.25 * 24 * 60,
}


def parse_age(age_str):
    """Parse an age string such as ``"18 Years"`` into a number of years.

    The first whitespace-separated token must be an integer. If a second
    token names a known non-year unit (month, week, day, hour, minute,
    singular or plural, any case) the value is converted to years so that
    e.g. ``"6 Months"`` becomes ``0.5`` instead of being read as 6 years.
    NOTE(review): assumes values look like "<number> <unit>" -- confirm
    against the dataset.

    Args:
        age_str: Raw cell value; may be a string, a number, or missing.

    Returns:
        ``None`` when the value is missing or does not start with an
        integer; the integer itself when the unit is absent, "year(s)" or
        unrecognised; otherwise a float number of years.
    """
    if pd.isnull(age_str):
        return None
    parts = str(age_str).split()
    try:
        value = int(parts[0])
    except (IndexError, ValueError):
        # IndexError: blank string; ValueError: non-integer text like "N/A".
        return None
    if len(parts) > 1:
        unit = parts[1].lower().rstrip("s")
        per_year = _AGE_UNITS_PER_YEAR.get(unit)
        if per_year is not None:
            return value / per_year
    return value
# Numeric age bounds used by the age filter in search_trials. pd.to_numeric
# turns unparseable entries (None) into NaN and guarantees a numeric dtype,
# so the later <= / >= comparisons cannot hit an object-dtype column.
df_full["MinAgeNum"] = pd.to_numeric(df_full["MinimumAge"].apply(parse_age), errors="coerce")
df_full["MaxAgeNum"] = pd.to_numeric(df_full["MaximumAge"].apply(parse_age), errors="coerce")
# Lower-cased concatenation of every original column, used for keyword
# matching. The derived numeric helper columns are excluded and missing
# cells contribute an empty string, so neither "18.0"-style helper values
# nor the literal text "nan" end up in the searchable text.
df_full["combined_text"] = (
    df_full.drop(columns=["MinAgeNum", "MaxAgeNum"])
    .fillna("")
    .astype(str)
    .agg(" ".join, axis=1)
    .str.lower()
)
print(f"Preprocessed {len(df_full)} US recruiting trials.")
# Splits on whitespace that follows ".", "?" or "!", except after
# abbreviation-like tokens (e.g. "e.g." or "Dr.").
_SENTENCE_BOUNDARY_RE = re.compile(r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s')

# Sentences that start with one of these labels are treated as metadata
# and are only used in a summary when nothing better is available.
_METADATA_LABEL_RE = re.compile(
    r"^(Start Date|Primary Completion Date|Intervention Name|"
    r"Primary Outcome Measure|Primary Outcome Description):"
)

# (label, column) pairs, in the order they are concatenated for summarising.
_SUMMARY_FIELDS = (
    ("Intervention Name", "InterventionName"),
    ("Intervention Description", "InterventionDescription"),
    ("Brief Summary", "BriefSummary"),
    ("Primary Outcome Measure", "PrimaryOutcomeMeasure"),
    ("Primary Outcome Description", "PrimaryOutcomeDescription"),
    ("Start Date", "StartDate"),
    ("Detailed Description", "DetailedDescription"),
    ("Eligibility Criteria", "EligibilityCriteria"),
)


def _split_sentences(text):
    """Split *text* into stripped, non-empty sentences."""
    return [s.strip() for s in _SENTENCE_BOUNDARY_RE.split(text) if s.strip()]


def _build_input_text(row):
    """Join the row's labelled text fields, skipping missing or blank ones."""
    parts = []
    for label, column in _SUMMARY_FIELDS:
        value = row.get(column, "")
        # A missing cell is NaN, which would otherwise render as "nan".
        if pd.isnull(value):
            continue
        value = str(value).strip()
        if value:
            parts.append(f"{label}: {value}")
    return " ".join(parts)


def _generate_summary(row, max_sentences=7, min_sentence_length=5):
    """Build an extractive summary of one trial row.

    Sentences with fewer than *min_sentence_length* words are dropped. If
    at most *max_sentences* remain they are all returned; otherwise the
    sentences are ranked by summed TF-IDF weight times a position weight
    (1.5 for the first sentence down to 1.0 for the last) and the best
    *max_sentences* are returned in their original order.
    """
    text = _build_input_text(row)
    if not text:
        return ""
    sentences = [
        s for s in _split_sentences(text) if len(s.split()) >= min_sentence_length
    ]
    if not sentences:
        return ""
    if len(sentences) <= max_sentences:
        return " ".join(sentences)

    try:
        tfidf_matrix = TfidfVectorizer(stop_words="english").fit_transform(sentences)
    except ValueError:
        # Raised when no usable term is left (e.g. only stop words); there
        # is nothing to rank on, so keep the leading sentences.
        return " ".join(sentences[:max_sentences])

    scores = np.asarray(tfidf_matrix.sum(axis=1)).ravel()
    # Position weighting: earlier sentences are weighted higher.
    weighted = scores * np.linspace(1.5, 1.0, num=len(sentences))
    ranked = np.argsort(weighted)[::-1].tolist()  # best first

    # Prefer real content; metadata-label sentences only backfill when there
    # are not enough other sentences, instead of being re-added right after
    # having been skipped.
    is_label = [bool(_METADATA_LABEL_RE.match(s)) for s in sentences]
    preferred = [i for i in ranked if not is_label[i]]
    backfill = [i for i in ranked if is_label[i]]
    chosen = sorted((preferred + backfill)[:max_sentences])  # original order
    return " ".join(sentences[i] for i in chosen)


def search_trials(user_age, user_sex, user_state, user_keywords, generate_summaries=True):
    """Filter the preloaded trials for a user and optionally summarise them.

    Reads the module-level ``df_full`` and relies on its ``Sex``,
    ``LocationState``, ``MinAgeNum``, ``MaxAgeNum`` and ``combined_text``
    columns.

    Args:
        user_age: Age in years; converted with ``int()``.
        user_sex: Compared case-insensitively with ``Sex``; rows whose
            ``Sex`` is "all" always pass.
        user_state: Compared case-insensitively with ``LocationState``.
        user_keywords: Comma-separated string, or a list/tuple/set of
            keywords. A row passes when any keyword is a substring of its
            ``combined_text``. Anything else, or no non-blank keyword,
            disables keyword filtering.
        generate_summaries: When true and at least one row matches, fill
            ``LaymanSummary`` with an extractive summary per row.

    Returns:
        A new DataFrame of matching rows (index reset, helper columns
        removed) with a ``LaymanSummary`` column, which is ``""`` when
        summaries are not generated.

    Raises:
        ValueError, TypeError: If *user_age* cannot be converted to int.
    """
    if isinstance(user_keywords, str):
        raw_keywords = user_keywords.split(",")
    elif isinstance(user_keywords, (list, tuple, set, frozenset)):
        raw_keywords = [str(k) for k in user_keywords]
    else:
        raw_keywords = []
    keywords = [k.strip().lower() for k in raw_keywords if k.strip()]

    age = int(user_age)
    # Boolean indexing below already yields a new frame, so the shared
    # dataset does not need to be copied on every call.
    df = df_full

    sex_mask = df["Sex"].str.lower().isin([str(user_sex).lower(), "all"])
    # NOTE(review): a row whose MinAgeNum or MaxAgeNum is NaN fails these
    # comparisons and is excluded -- confirm a missing bound should not
    # mean "no limit".
    age_mask = (df["MinAgeNum"] <= age) & (df["MaxAgeNum"] >= age)
    state_mask = df["LocationState"].str.lower() == str(user_state).lower()
    mask = sex_mask & age_mask & state_mask

    if keywords:
        # Escaped alternation == "any keyword is a substring", done in one
        # vectorised pass instead of a Python call per row.
        pattern = "|".join(re.escape(k) for k in keywords)
        mask = mask & df["combined_text"].str.contains(pattern, regex=True, na=False)

    filtered_df = df[mask].reset_index(drop=True)
    filtered_df = filtered_df.drop(
        columns=["MinAgeNum", "MaxAgeNum", "combined_text"], errors="ignore"
    )

    if generate_summaries and len(filtered_df) > 0:
        print(f"Generating improved fast extractive summaries for {len(filtered_df)} filtered trials...")
        filtered_df["LaymanSummary"] = filtered_df.apply(_generate_summary, axis=1)
    else:
        filtered_df["LaymanSummary"] = ""
    return filtered_df