# trialgpt-api-ui / app.py
# Author: tannu038
# Commit message: Create app.py
# Commit: 835b58a (verified)
import os

# Hugging Face cache directory and download timeout.
# NOTE: these are assigned *before* the third-party imports on purpose.
# `transformers` pulls in `huggingface_hub`, which resolves its settings from
# the environment when it is first imported; assigning them afterwards (as the
# file previously did) is too late for them to take effect.
os.environ["HF_HOME"] = "/app/huggingface"
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60"

import re  # noqa: E402

import faiss  # noqa: E402
import numpy as np  # noqa: E402
import pandas as pd  # noqa: E402
import torch  # noqa: E402
from fastapi import FastAPI  # noqa: E402
from pydantic import BaseModel  # noqa: E402
from sklearn.feature_extraction.text import TfidfVectorizer  # noqa: E402
from sklearn.metrics.pairwise import cosine_similarity  # noqa: E402
from transformers import AutoModel, AutoTokenizer  # noqa: E402

# ASGI application object; the route decorators below register on it.
app = FastAPI()
# --- Load Clinical Trials CSV ---
# The trial table backs every lookup in this module, so a missing file aborts
# start-up with an explicit error instead of serving an empty API.
csv_path = "ctg-studies-obesity.csv"
if not os.path.exists(csv_path):
    raise FileNotFoundError(f"❌ CSV File Not Found: {csv_path}. Upload it first.")
df_trials = pd.read_csv(csv_path)
print("✅ CSV File Loaded Successfully!")
# --- Rename columns to match the required fields ---
# Maps the CSV's column headers to the field names the endpoints read.
# Entries whose key equals their value leave that column name unchanged.
_COLUMN_ALIASES = {
    "NCT Number": "NCTID",
    "Interventions": "Intervention",
    "Phases": "Phase",
    "Study Status": "Status",
    "Completion Date": "Completion Date",
    "Study Results": "Has Results",
    "Sponsor": "Sponsor",
}
df_trials.rename(columns=_COLUMN_ALIASES, inplace=True)
# --- Load FAISS Index ---
# `dimension` is only used to size the empty fallback index.
# NOTE(review): presumably 768 matches the retrieval model's hidden size -- confirm.
dimension = 768
faiss_index_path = "clinical_trials.index"
if os.path.exists(faiss_index_path):
    index = faiss.read_index(faiss_index_path)
    print("✅ FAISS Index Loaded!")
else:
    # Fall back to an empty L2 index so the module still imports; it holds
    # no vectors (ntotal == 0) until something adds them.
    index = faiss.IndexFlatL2(dimension)
    print("⚠ FAISS Index Not Found! Using Empty Index.")
# --- Load Retrieval Model ---
# Tokenizer and encoder are both loaded from the same checkpoint name; the
# tokenizer is loaded first, then the model.
retrieval_model_name = "priyanandanwar/fine-tuned-gatortron"
retrieval_tokenizer, retrieval_model = (
    AutoTokenizer.from_pretrained(retrieval_model_name),
    AutoModel.from_pretrained(retrieval_model_name),
)
# --- Request Models ---
# Request body accepted by POST /retrieve.
class QueryRequest(BaseModel):
    # Free-text query; it is embedded and searched against the FAISS index.
    text: str
# Request body accepted by POST /extract-timeline/.
class StudyText(BaseModel):
    # Study text that is scanned for phase durations expressed in weeks.
    text: str
# --- Generate Embedding for Query ---
def generate_embedding(text):
    """Encode ``text`` with the retrieval model and return its CLS embedding.

    The input is truncated/padded to exactly 512 tokens, run through the
    model without gradient tracking, and the hidden state at token
    position 0 is returned as a NumPy array.
    """
    tokenizer_options = {
        "return_tensors": "pt",
        "truncation": True,
        "padding": "max_length",
        "max_length": 512,
    }
    encoded = retrieval_tokenizer(text, **tokenizer_options)
    with torch.no_grad():
        hidden_states = retrieval_model(**encoded).last_hidden_state
        # Position 0 along the sequence axis is the CLS token.
        cls_embedding = hidden_states[:, 0, :]
        return cls_embedding.numpy()
# --- Retrieve Clinical Trial Info ---
def get_trial_info(nct_id):
    """Return the first trial row whose ``NCTID`` equals ``nct_id`` as a dict.

    Missing values in the row are replaced by the string ``"N/A"``.
    Returns ``None`` when no row matches.
    """
    matching_rows = df_trials[df_trials["NCTID"] == nct_id]
    records = matching_rows.fillna("N/A").to_dict(orient="records")
    if not records:
        return None
    return records[0]
# --- Summary Extraction Function ---
def extract_summary(text, max_sentences=2):
    """Return an extractive summary of ``text``.

    Sentences are scored by the sum of their TF-IDF cosine similarities to
    every sentence in the text; the ``max_sentences`` best-scoring ones are
    returned in their original order, joined by single spaces.

    Args:
        text: Text to summarise. Anything that is not a non-blank string
            yields the placeholder summary.
        max_sentences: Maximum number of sentences to keep.

    Returns:
        ``"No summary available."`` for missing/blank input, ``text`` itself
        when it already has at most ``max_sentences`` sentences, otherwise
        the selected sentences.
    """
    if not isinstance(text, str) or not text.strip():
        return "No summary available."

    # Strip before splitting: trailing whitespace after the final period
    # would otherwise produce an empty "sentence" and inflate the count.
    sentences = [s for s in re.split(r'(?<=[.!?])\s+', text.strip()) if s]
    if len(sentences) <= max_sentences:
        return text

    try:
        tfidf_matrix = TfidfVectorizer(stop_words="english").fit_transform(sentences)
    except ValueError:
        # Raised when no usable terms remain (e.g. only stop words or
        # one-character tokens); fall back to the leading sentences rather
        # than failing the whole request.
        return " ".join(sentences[:max_sentences])

    sentence_scores = cosine_similarity(tfidf_matrix, tfidf_matrix).sum(axis=1)

    # Rank by (score, sentence) descending, then restore document order.
    # Working with indices avoids a linear `list.index` lookup per sentence.
    ranked_indices = sorted(
        range(len(sentences)),
        key=lambda i: (sentence_scores[i], sentences[i]),
        reverse=True,
    )
    selected = sorted(ranked_indices[:max_sentences])
    return " ".join(sentences[i] for i in selected)
# --- Timeline Extraction ---
# A duration in weeks: "12 weeks", "1 week" or the hyphenated "12-week".
_WEEKS_RE = r'(\d+)[\s-]*weeks?'


def _build_phase_patterns(keywords):
    """Compile the keyword-first and duration-first regexes for one phase."""
    return (
        # Keyword first, e.g. "Screening period of 2 weeks".
        re.compile(
            rf'(?:{keywords})[^.\n]*?(?:of|:|-|is|lasts)?\s*{_WEEKS_RE}',
            re.IGNORECASE,
        ),
        # Duration first, e.g. "2 weeks of screening" / "2-week screening".
        re.compile(rf'{_WEEKS_RE}[^.\n]*?(?:{keywords})', re.IGNORECASE),
    )


# Compiled once at import time; keys are the names used in the result dict.
_PHASE_PATTERNS = {
    "Screening": _build_phase_patterns(
        r'Screening|Pre-study observation|Initial Check'
    ),
    "Treatment": _build_phase_patterns(
        r'Treatment|Intervention|Therapy|Dosing phase|Main study(?:\s*period)?'
    ),
    "Follow-Up": _build_phase_patterns(
        r'Follow[-\s]*up|Recovery|Post-study monitoring|Observation phase|After-treatment'
    ),
}


def _first_week_count(patterns, text):
    """Return the week count from the first pattern that matches, else None."""
    for pattern in patterns:
        match = pattern.search(text)
        if match:
            return int(match.group(1))
    return None


def extract_study_timeline(text: str):
    """Extract screening, treatment and follow-up durations (in weeks).

    For each phase the text is searched case-insensitively for a phase
    keyword and a week count that are not separated by a period or a line
    break. The keyword-first form is tried before the duration-first form.

    Args:
        text: Free-text study description.

    Returns:
        A dict with the keys ``"Screening"``, ``"Treatment"`` and
        ``"Follow-Up"``; each value is the number of weeks as an ``int``,
        or ``None`` when no duration was found for that phase.
    """
    return {
        phase: _first_week_count(patterns, text)
        for phase, patterns in _PHASE_PATTERNS.items()
    }
@app.post("/extract-timeline/")
async def extract_timeline(request: StudyText):
    # Thin HTTP wrapper: hand the posted text to the regex-based extractor
    # and return its phase -> weeks mapping unchanged.
    timeline = extract_study_timeline(request.text)
    return timeline
# --- Retrieval Endpoint with Summary ---
# Placeholder that get_trial_info() substitutes for missing values.
_MISSING_VALUE = "N/A"
# Free-text fields tried, in order, as the source for the summary.
_SUMMARY_SOURCE_FIELDS = ("Brief Summary", "Description", "Detailed Description")


def _pick_summary_text(trial_data):
    """Return the first populated summary source field, or ``""`` if none."""
    for field in _SUMMARY_SOURCE_FIELDS:
        value = trial_data.get(field)
        # "N/A" is a non-empty string, so it must be rejected explicitly or
        # the later fields would never be consulted.
        if isinstance(value, str) and value.strip() and value != _MISSING_VALUE:
            return value
    return ""


@app.post("/retrieve")
async def retrieve_trial(request: QueryRequest):
    """Return every indexed trial, ordered by similarity to the query text.

    The query is embedded, the FAISS index is searched for all of its
    vectors, and each hit is mapped (by row position) to a row of the trial
    table. Each result carries a fixed set of fields plus an extractive
    summary.

    Returns:
        ``{"matched_trials": [...]}``; the list is empty when the index
        holds no vectors.
    """
    total_trials = index.ntotal
    if total_trials == 0:
        # FAISS rejects a search with k == 0, which is what an empty
        # (fallback) index would otherwise ask for.
        return {"matched_trials": []}

    query_vector = generate_embedding(request.text)
    _, indices = index.search(query_vector, total_trials)

    row_count = len(df_trials)
    matched_trials = []
    for idx in indices[0]:
        # FAISS pads missing results with -1; without the lower bound that
        # would silently select the last row of the table.
        if not 0 <= idx < row_count:
            continue
        nct_id = df_trials.iloc[idx]["NCTID"]
        trial_data = get_trial_info(nct_id)
        if not trial_data:
            continue
        matched_trials.append({
            "NCTID": trial_data["NCTID"],
            "Intervention": trial_data.get("Intervention", "N/A"),
            "Phase": trial_data.get("Phase", "N/A"),
            "Status": trial_data.get("Status", "N/A"),
            "Completion Date": trial_data.get("Completion Date", "N/A"),
            "Has Results": trial_data.get("Has Results", "N/A"),
            "Sponsor": trial_data.get("Sponsor", "N/A"),
            "Summary": extract_summary(_pick_summary_text(trial_data)),
        })
    return {"matched_trials": matched_trials}
@app.get("/trial/{nct_id}")
async def get_trial_details(nct_id: str):
    # Look the trial up by its NCT identifier; an unknown id produces an
    # error payload rather than an HTTP error status.
    details = get_trial_info(nct_id)
    if details:
        return {"trial_details": details}
    return {"error": "Trial not found"}
@app.get("/")
async def root():
    # Landing route: returns a static status message.
    status = {
        "message": "TrialGPT API is Running with FAISS-based Retrieval, Timeline Extraction, and Summary Mining! 🎯"
    }
    return status