model / app.py
swarit222's picture
Update app.py
de55d66 verified
import gradio as gr
import pandas as pd
import re
from sklearn.feature_extraction.text import TfidfVectorizer
import numpy as np
from main2 import search_trials # Import your updated search_trials
# Number of trials rendered on one results page.
PAGE_SIZE = 5

# Number of words shown in the collapsed preview of a long table cell.
PREVIEW_WORDS = 100

# Choices offered in the "State" dropdown: the fifty states in alphabetical
# order, with the District of Columbia appended last.
US_STATES = [
    # A
    "Alabama", "Alaska", "Arizona", "Arkansas",
    # C - D
    "California", "Colorado", "Connecticut", "Delaware",
    # F - I
    "Florida", "Georgia", "Hawaii", "Idaho", "Illinois", "Indiana", "Iowa",
    # K - L
    "Kansas", "Kentucky", "Louisiana",
    # M
    "Maine", "Maryland", "Massachusetts", "Michigan", "Minnesota",
    "Mississippi", "Missouri", "Montana",
    # N
    "Nebraska", "Nevada", "New Hampshire", "New Jersey", "New Mexico",
    "New York", "North Carolina", "North Dakota",
    # O - R
    "Ohio", "Oklahoma", "Oregon", "Pennsylvania", "Rhode Island",
    # S - T
    "South Carolina", "South Dakota", "Tennessee", "Texas",
    # U - W
    "Utah", "Vermont", "Virginia", "Washington", "West Virginia",
    "Wisconsin", "Wyoming",
    # Federal district
    "District of Columbia",
]
def split_sentences(text):
    """Split *text* into a list of trimmed, non-empty sentences.

    A sentence boundary is a single whitespace character that directly
    follows ``.``, ``?`` or ``!``. Two negative lookbehinds suppress the
    split after dotted abbreviations of the form ``x.y.`` (for example
    ``e.g.``) and after a capital-plus-lowercase title such as ``Dr.``.
    """
    boundary = r'(?<!\w\.\w.)(?<![A-Z][a-z]\.)(?<=\.|\?|\!)\s'
    sentences = []
    for fragment in re.split(boundary, text):
        fragment = fragment.strip()
        if fragment:
            sentences.append(fragment)
    return sentences
def build_input_text(row):
    """Concatenate the summarisable fields of one trial record into a string.

    Every populated field is rendered as ``"<Label>: <value>"`` and the
    pieces are joined with single spaces, always in the order Brief Summary,
    Primary Outcome Measure, Primary Outcome Description, Primary Completion
    Date.

    Fields that are absent, ``None``, NaN/NA or whitespace-only are left out
    entirely. Filtering has to look at the field *value*: once the label is
    prepended the part is never blank, so a check on the labelled string can
    never drop anything and missing values would surface as ``"nan"`` or
    ``"None"`` in the text.

    Args:
        row: Record exposing ``.get(key, default)``, such as a ``dict`` or a
            pandas ``Series``.

    Returns:
        The joined text, or ``""`` when no field has any content.
    """
    labelled_fields = (
        ("Brief Summary", "BriefSummary"),
        ("Primary Outcome Measure", "PrimaryOutcomeMeasure"),
        ("Primary Outcome Description", "PrimaryOutcomeDescription"),
        ("Primary Completion Date", "PrimaryCompletionDate"),
    )

    parts = []
    for label, key in labelled_fields:
        value = row.get(key, "")
        # is_scalar guards pd.isna, which returns an array for list-likes.
        if pd.api.types.is_scalar(value) and pd.isna(value):
            continue
        text = str(value)
        if not text.strip():
            continue
        parts.append(f"{label}: {text}")
    return " ".join(parts)
def generate_summary(row, max_sentences=7, min_sentence_length=5):
    """Build an extractive summary of one trial record.

    The record is flattened with ``build_input_text`` and split with
    ``split_sentences``; sentences shorter than ``min_sentence_length`` words
    are discarded. If no more than ``max_sentences`` remain they are all
    returned as-is. Otherwise each sentence is scored by the sum of its
    TF-IDF weights multiplied by a position weight that falls linearly from
    1.5 (first sentence) to 1.0 (last sentence), and the best-scoring
    sentences are selected.

    Sentences that begin with a metadata label (``Start Date:``,
    ``Primary Outcome Measure:`` ...) are only used when there are not enough
    other sentences to fill the summary. Backfilling from the same top-ranked
    set that was just filtered would re-add exactly those label sentences and
    make the filter a no-op, so the selection is drawn from the full ranking.

    Args:
        row: Record accepted by ``build_input_text``.
        max_sentences: Maximum number of sentences in the summary.
        min_sentence_length: Minimum word count for a sentence to be kept.

    Returns:
        The selected sentences joined by single spaces: non-metadata
        sentences in document order, followed by any backfilled metadata
        sentences in document order. ``""`` when there is nothing to
        summarise.
    """
    text = build_input_text(row)
    if not text.strip():
        return ""

    sentences = [
        s for s in split_sentences(text)
        if len(s.split()) >= min_sentence_length
    ]
    if not sentences:
        return ""
    if len(sentences) <= max_sentences:
        return " ".join(sentences)

    try:
        tfidf_matrix = TfidfVectorizer(stop_words="english").fit_transform(sentences)
        scores = np.asarray(tfidf_matrix.sum(axis=1)).flatten()
    except ValueError:
        # Raised when the vocabulary is empty (e.g. only stop words). One
        # unscorable record must not abort the whole page, so fall back to
        # uniform scores and let the position weights decide.
        scores = np.ones(len(sentences))

    position_weights = np.linspace(1.5, 1.0, num=len(sentences))
    ranked = (scores * position_weights).argsort()[::-1]  # best first

    metadata_label = re.compile(
        r"^(Start Date|Primary Completion Date|Intervention Name|Primary Outcome Measure|Primary Outcome Description):"
    )
    preferred = []
    metadata = []
    for idx in ranked:
        bucket = metadata if metadata_label.match(sentences[idx]) else preferred
        bucket.append(idx)

    chosen = sorted(preferred[:max_sentences])
    shortfall = max_sentences - len(chosen)
    if shortfall > 0:
        chosen.extend(sorted(metadata[:shortfall]))

    return " ".join(sentences[i] for i in chosen)
def run_search(age, sex, state, keywords):
    """Run the trial search and prepare the first page of results.

    The user's inputs are forwarded to ``search_trials`` with
    ``generate_summaries=False``.

    Returns:
        A ``(first_page, total_pages, full_results)`` tuple. ``first_page``
        is a copy of the first ``PAGE_SIZE`` rows with an empty
        ``LaymanSummary`` column added. When the search returns no rows the
        tuple is ``(empty DataFrame, 0, None)``.
    """
    results = search_trials(
        user_age=age,
        user_sex=sex,
        user_state=state,
        user_keywords=keywords,
        generate_summaries=False,
    )
    if results.empty:
        return pd.DataFrame(), 0, None

    # Ceiling division: a partially filled last page still counts as a page.
    whole_pages, leftover = divmod(len(results), PAGE_SIZE)
    total_pages = whole_pages + (1 if leftover else 0)

    first_page = results.head(PAGE_SIZE).copy()
    first_page['LaymanSummary'] = ""
    return first_page, total_pages, results
def load_page(page_num, full_df):
    """Return one page of results with a summary generated for each row.

    Args:
        page_num: Zero-based page index.
        full_df: Complete search results, or ``None`` when there are none.

    Returns:
        A copy of the ``PAGE_SIZE`` rows belonging to ``page_num`` with a
        ``LaymanSummary`` column produced by ``generate_summary``; an empty
        DataFrame when ``full_df`` is ``None`` or empty.
    """
    if full_df is None or full_df.empty:
        return pd.DataFrame()

    first_row = page_num * PAGE_SIZE
    window = slice(first_row, first_row + PAGE_SIZE)
    page_df = full_df.iloc[window].copy()
    # Summaries are computed only for the rows actually being displayed.
    page_df['LaymanSummary'] = page_df.apply(generate_summary, axis=1)
    return page_df
def update_page_controls(page_num, total_pages):
    """Compute the pagination widgets' state for the given page.

    Returns:
        ``(prev_update, next_update, page_text)``: visibility updates for the
        previous and next buttons, and a ``"Page X of Y"`` label (one-based)
        that is empty when there are no pages.
    """
    has_previous = page_num > 0
    has_next = page_num < total_pages - 1

    if total_pages > 0:
        page_text = f"Page {page_num + 1} of {total_pages}"
    else:
        page_text = ""

    return gr.update(visible=has_previous), gr.update(visible=has_next), page_text
def hide_empty_columns(df):
    """Return *df* restricted to the columns that contain some content.

    A column is kept when at least one of its non-null values, converted to
    a string and stripped of surrounding whitespace, is not the empty
    string. Column order is preserved.
    """
    def has_content(column):
        stripped = column.dropna().astype(str).str.strip()
        # .any() on an empty Series is False, which covers all-null columns.
        return bool((stripped != "").any())

    visible = [name for name in df.columns if has_content(df[name])]
    return df[visible]
def df_to_html_with_readmore(df: pd.DataFrame) -> str:
    """Render a page of results as an HTML table with expandable long cells.

    ``LaymanSummary`` (when present) is moved to the first column and columns
    without any content are dropped via ``hide_empty_columns``. Cells longer
    than ``PREVIEW_WORDS`` words are wrapped in a ``<details>`` element whose
    summary shows the first ``PREVIEW_WORDS`` words followed by ``...``. All
    headers and cell text are HTML-escaped.

    Missing cell values (``None``, NaN, NA) are rendered as empty cells
    rather than the literal text ``nan``/``None``, matching how
    ``hide_empty_columns`` treats them. Column labels are converted with
    ``str`` before escaping so non-string labels do not raise.

    Args:
        df: The rows to display.

    Returns:
        An HTML fragment (inline ``<style>`` plus ``<table>``), or a
        "no matching trials" paragraph when ``df`` is empty.
    """
    if df.empty:
        return "<p>No matching trials found.</p>"
    from html import escape

    if "LaymanSummary" in df.columns:
        cols = list(df.columns)
        cols.insert(0, cols.pop(cols.index("LaymanSummary")))
        df = df[cols]
    df = hide_empty_columns(df)

    parts = ['''
<style>
table {
width: 100%;
border-collapse: collapse;
font-family: Arial, sans-serif;
}
th {
background-color: #007bff;
color: white;
padding: 12px;
text-align: left;
border: 1px solid #ddd;
}
td {
border: 1px solid #ddd;
padding: 12px;
vertical-align: top;
white-space: normal;
max-width: 1000px; /* 2.5x original 400px */
min-width: 1000px; /* force width */
word-wrap: break-word;
}
details summary {
cursor: pointer;
color: #007bff;
font-weight: bold;
}
details summary:after {
content: " (Read More)";
color: #0056b3;
font-weight: normal;
}
details[open] summary {
display: none; /* hide preview when expanded */
}
details div.full-text {
display: none;
}
details[open] div.full-text {
display: block;
margin-top: 8px;
}
</style>
''']

    parts.append('<table><thead><tr>')
    # Column labels are not guaranteed to be strings; escape() needs one.
    parts.extend(f'<th>{escape(str(col))}</th>' for col in df.columns)
    parts.append('</tr></thead><tbody>')

    for _, row in df.iterrows():
        parts.append('<tr>')
        for col in df.columns:
            cell = row[col]
            # is_scalar guards pd.isna, which returns an array for list-likes.
            if pd.api.types.is_scalar(cell) and pd.isna(cell):
                val = ""
            else:
                val = str(cell)
            words = val.split()
            if len(words) > PREVIEW_WORDS:
                short_text = escape(" ".join(words[:PREVIEW_WORDS]) + "...")
                full_text = escape(val)
                cell_html = f'''
<div>
<details>
<summary>{short_text}</summary>
<div class="full-text">{full_text}</div>
</details>
</div>
'''
            else:
                cell_html = f'<div>{escape(val)}</div>'
            parts.append(f'<td>{cell_html}</td>')
        parts.append('</tr>')

    parts.append('</tbody></table>')
    return "".join(parts)
def on_search(age, sex, state, keywords):
    """Handle the "Search Trials" click and switch to the results view.

    Returns:
        A 9-tuple in the order expected by the click handler's outputs:
        table HTML, page label, previous-button update, next-button update,
        current page number (always 0), total page count, the full result
        set, an update hiding the input page and an update showing the
        results page.
    """
    first_page, total_pages, full_df = run_search(age, sex, state, keywords)
    page_num = 0

    # run_search's page carries blank summaries; reload it so they are filled.
    if not first_page.empty:
        first_page = load_page(page_num, full_df)

    prev_vis, next_vis, page_text = update_page_controls(page_num, total_pages)
    html_output = df_to_html_with_readmore(first_page)

    hide_input_page = gr.update(visible=False)
    show_results_page = gr.update(visible=True)
    return (
        html_output,
        page_text,
        prev_vis,
        next_vis,
        page_num,
        total_pages,
        full_df,
        hide_input_page,
        show_results_page,
    )
def on_page_change(increment, page_num, total_pages, full_df):
    """Move ``increment`` pages forward or backward and re-render the table.

    The target page is clamped to ``[0, total_pages - 1]``.

    Returns:
        ``(html, page_text, prev_update, next_update, new_page)``. When there
        are no stored results, the "no matching trials" message is returned
        with both buttons hidden and the page reset to 0.
    """
    if full_df is None or full_df.empty:
        hide_prev = gr.update(visible=False)
        hide_next = gr.update(visible=False)
        return "<p>No matching trials found.</p>", "", hide_prev, hide_next, 0

    last_page = total_pages - 1
    new_page = page_num + increment
    if new_page > last_page:
        new_page = last_page
    if new_page < 0:
        new_page = 0

    page_df = load_page(new_page, full_df)
    prev_vis, next_vis, page_text = update_page_controls(new_page, total_pages)
    html_output = df_to_html_with_readmore(page_df)
    return html_output, page_text, prev_vis, next_vis, new_page
def show_input_page():
    """Handle the "Back" click: show the search form, hide the results.

    Returns:
        ``(input_page_update, results_page_update)``.
    """
    reveal_form = gr.update(visible=True)
    conceal_results = gr.update(visible=False)
    return reveal_form, conceal_results
# NOTE(review): the nesting below is reconstructed; confirm against the
# deployed layout that the Search and Back buttons belong inside their
# respective page columns.
with gr.Blocks() as demo:
    gr.Markdown("# Clinical Trials Search Tool with Pagination and Inline Read More")

    # --- Page 1: search form -------------------------------------------
    with gr.Column(visible=True) as input_page:
        gr.Markdown("Find **recruiting US clinical trials** that match your **age**, **sex**, **state**, and optional **keywords**.")
        with gr.Row():
            age_input = gr.Number(label="Your Age", value=30)
            sex_input = gr.Dropdown(["Male", "Female", "All"], label="Sex", value="All")
        with gr.Row():
            state_input = gr.Dropdown(US_STATES, label="State", value="California")
            keywords_input = gr.Textbox(label="Keywords", placeholder="e.g., Cancer, Diabetes")
        search_btn = gr.Button("Search Trials")

    # --- Page 2: paginated results (hidden until a search runs) ---------
    with gr.Column(visible=False) as results_page:
        output_html = gr.HTML()
        total_pages_text = gr.Textbox(value="", interactive=False)
        with gr.Row():
            prev_btn = gr.Button("Previous Page")
            next_btn = gr.Button("Next Page")
        back_btn = gr.Button("Back")

    # Per-session state: current page, page count and the full result set.
    page_num_state = gr.State(0)
    total_pages_state = gr.State(0)
    full_results_state = gr.State(None)

    # Components refreshed by every page render, in handler return order.
    page_outputs = [output_html, total_pages_text, prev_btn, next_btn, page_num_state]

    search_btn.click(
        fn=on_search,
        inputs=[age_input, sex_input, state_input, keywords_input],
        outputs=page_outputs + [total_pages_state, full_results_state, input_page, results_page],
    )

    # Next and Previous share one handler; the step is a constant State input.
    for button, step in ((next_btn, 1), (prev_btn, -1)):
        button.click(
            fn=on_page_change,
            inputs=[gr.State(step), page_num_state, total_pages_state, full_results_state],
            outputs=page_outputs,
        )

    back_btn.click(
        fn=show_input_page,
        outputs=[input_page, results_page],
    )

if __name__ == "__main__":
    demo.launch()