# TrialGPT / main.py
# Uploaded by tannu038 ("Upload main.py", commit 5c0b8b6, verified).
import os
import re
import faiss
import torch
import numpy as np
import pandas as pd
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer
# Hugging Face cache location and hub download timeout (seconds).
# NOTE(review): huggingface_hub reads these variables when it is first imported, and it
# is already pulled in (via the transformers import) above this point, so these values
# may not take effect. Setting them in the container environment, or before the
# transformers import, would be more reliable — confirm.
os.environ.update({
    "HF_HOME": "/app/huggingface",
    "HF_HUB_DOWNLOAD_TIMEOUT": "60",
})
app = FastAPI()
# --- Load Clinical Trials CSV ---
# The service cannot run without the trial table, so fail fast at startup.
csv_path = "ctg-studies-obesity.csv"
if not os.path.exists(csv_path):
    raise FileNotFoundError(f"❌ CSV File Not Found: {csv_path}. Upload it first.")
df_trials = pd.read_csv(csv_path)
print("✅ CSV File Loaded Successfully!")

# --- Rename columns to match the required fields ---
# Maps the CSV's headers (presumably a ClinicalTrials.gov export) to the column names
# the endpoints read. Columns that already carry the target name ("Completion Date",
# "Sponsor") need no entry; rename() leaves unlisted or absent columns untouched.
df_trials = df_trials.rename(columns={
    "NCT Number": "NCTID",
    "Interventions": "Intervention",
    "Phases": "Phase",
    "Study Status": "Status",
    "Study Results": "Has Results",
})
# --- Load FAISS Index ---
# Vector width for the empty fallback index; presumably must equal the retrieval
# model's embedding size — TODO confirm.
dimension = 768
faiss_index_path = "clinical_trials.index"
if os.path.exists(faiss_index_path):
    index = faiss.read_index(faiss_index_path)
    print("✅ FAISS Index Loaded!")
    # /retrieve maps a hit's position straight to a df_trials row, so a size
    # mismatch means results can point at the wrong trial.
    if index.ntotal != len(df_trials):
        print(
            f"⚠ FAISS index has {index.ntotal} vectors but the CSV has "
            f"{len(df_trials)} rows; retrieval results may be misaligned."
        )
else:
    # Searching an empty index yields no real hits (FAISS pads results with -1).
    index = faiss.IndexFlatL2(dimension)
    print("⚠ FAISS Index Not Found! Using Empty Index.")
# --- Load Retrieval Model ---
# Query encoder. Its embeddings are searched against the FAISS index, so it presumably
# has to be the same model the index was built with — TODO confirm.
retrieval_model_name = "priyanandanwar/fine-tuned-gatortron"
retrieval_tokenizer, retrieval_model = (
    AutoTokenizer.from_pretrained(retrieval_model_name),
    AutoModel.from_pretrained(retrieval_model_name),
)
# --- Request Model ---
class QueryRequest(BaseModel):
    """JSON body accepted by POST /retrieve."""

    # Free-text query; it is embedded and searched against the FAISS index.
    text: str
    # Number of nearest neighbours requested from the index.
    top_k: int = 5
# --- Generate Embedding for Query ---
def generate_embedding(text):
    """Encode ``text`` with the retrieval model and return its CLS-token vector.

    The input is truncated/padded to exactly 512 tokens, run through the model with
    gradients disabled, and the hidden state at position 0 of every sequence is
    returned as a 2-D NumPy array (one row per input text).
    """
    encoded = retrieval_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    with torch.no_grad():
        hidden_states = retrieval_model(**encoded).last_hidden_state
    # Position 0 holds the CLS token.
    cls_vectors = hidden_states[:, 0, :]
    return cls_vectors.numpy()
# --- Retrieve Clinical Trial Info ---
def get_trial_info(nct_id):
    """Return the first ``df_trials`` row whose NCTID equals ``nct_id`` as a dict.

    Missing values in that row are replaced with "N/A". Returns None when no row
    matches.
    """
    matches = df_trials.loc[df_trials["NCTID"] == nct_id]
    if matches.empty:
        return None
    records = matches.fillna("N/A").to_dict(orient="records")
    return records[0]
# --- Retrieval Endpoint (Returns Only Limited Data) ---
# Summary columns returned per hit (after NCTID), in response order.
_SUMMARY_FIELDS = (
    "Intervention",
    "Phase",
    "Status",
    "Completion Date",
    "Has Results",
    "Sponsor",
)


@app.post("/retrieve")
async def retrieve_trial(request: QueryRequest):
    """Retrieve Clinical Trial based on text (Shows Limited Info)

    Embeds ``request.text``, asks the FAISS index for up to ``request.top_k`` nearest
    vectors, and maps each hit position to a row of ``df_trials``. Returns
    ``{"matched_trials": [...]}``; each entry holds NCTID plus the summary fields,
    with "N/A" for any that are missing.
    """
    # NOTE(review): embedding and search are synchronous, CPU-bound work inside an
    # async handler and block the event loop while they run — consider offloading.

    # Never ask for more neighbours than the index holds. FAISS pads the result with
    # -1 ids otherwise (always, for the empty fallback index), and a non-positive k
    # is not a valid search.
    k = min(request.top_k, index.ntotal)
    if k <= 0:
        return {"matched_trials": []}

    query_vector = generate_embedding(request.text)
    _, indices = index.search(query_vector, k)

    matched_trials = []
    for idx in indices[0]:
        # -1 is FAISS's "no result" id; iloc[-1] would silently return the last
        # trial, so negative (and out-of-range) ids are skipped.
        if not 0 <= idx < len(df_trials):
            continue
        trial_data = get_trial_info(df_trials.iloc[idx]["NCTID"])
        if not trial_data:
            continue
        filtered_trial_data = {"NCTID": trial_data["NCTID"]}
        filtered_trial_data.update(
            {field: trial_data.get(field, "N/A") for field in _SUMMARY_FIELDS}
        )
        matched_trials.append(filtered_trial_data)
    return {"matched_trials": matched_trials}
class StudyText(BaseModel):
    """JSON body accepted by POST /extract-timeline/."""

    # Free-text study timeline description to parse for phase durations.
    text: str
# Keyword alternations for each study phase (regex syntax, matched case-insensitively).
_SCREENING_KEYWORDS = r"Screening|Pre-study observation|Initial Check"
_TREATMENT_KEYWORDS = r"Treatment|Intervention|Therapy|Dosing phase|Main study(?:\s*period)?"
_FOLLOW_UP_KEYWORDS = (
    r"Follow[-\s]up|Recovery|Post-study monitoring|Observation phase|After-treatment"
)
# Any phase keyword as a whole word (optional plural "s"). Used to stop one phase's
# match from running on into the text that describes a different phase.
_ANY_PHASE_KEYWORD = (
    rf"\b(?:{_SCREENING_KEYWORDS}|{_TREATMENT_KEYWORDS}|{_FOLLOW_UP_KEYWORDS})s?\b"
)


def _phase_duration_pattern(keywords):
    """Compile a regex whose group 1 is the first "<N> week(s)" count after a keyword.

    Any wording may sit between the keyword and the number ("period of", ": ",
    "lasts", " - ", ...) as long as it stays in the same sentence/line (no "." or
    newline) and does not cross another phase keyword. Keyword, number and the word
    "week"/"weeks" must each be whole words, so "4 weekly" is not read as 4 weeks.
    """
    return re.compile(
        rf"\b(?:{keywords})s?\b"
        rf"(?:(?!{_ANY_PHASE_KEYWORD})[^.\n])*?"
        r"\b(\d+)\s*weeks?\b",
        re.IGNORECASE,
    )


_SCREENING_PATTERN = _phase_duration_pattern(_SCREENING_KEYWORDS)
_TREATMENT_PATTERN = _phase_duration_pattern(_TREATMENT_KEYWORDS)
_FOLLOW_UP_PATTERN = _phase_duration_pattern(_FOLLOW_UP_KEYWORDS)


def extract_study_timeline(text: str):
    """
    Extracts Screening, Treatment, and Follow-up durations from a study timeline description.
    Handles both structured ("Screening: 4 weeks") and unstructured
    ("a treatment period of 24 weeks") formats.

    Returns a dict with keys "Screening", "Treatment" and "Follow-Up". Each value is
    the first week count found for that phase, or None when none is found.
    """
    def first_weeks(pattern):
        match = pattern.search(text)
        return int(match.group(1)) if match else None

    return {
        "Screening": first_weeks(_SCREENING_PATTERN),
        "Treatment": first_weeks(_TREATMENT_PATTERN),
        "Follow-Up": first_weeks(_FOLLOW_UP_PATTERN),
    }
@app.post("/extract-timeline/")
async def extract_timeline(request: StudyText):
    """Parse ``request.text`` and return its Screening/Treatment/Follow-Up durations in weeks."""
    parsed_timeline = extract_study_timeline(request.text)
    return parsed_timeline
# --- Fetch Full Trial Details When Clicked ---
@app.get("/trial/{nct_id}")
async def get_trial_details(nct_id: str):
    """Return every column of the trial whose NCTID equals ``nct_id``.

    Responds with {"trial_details": {...}} on a hit and {"error": "Trial not found"}
    otherwise. The miss case is still an HTTP 200 response.
    """
    details = get_trial_info(nct_id)
    if not details:
        return {"error": "Trial not found"}
    return {"trial_details": details}
# --- Root Endpoint ---
@app.get("/")
async def root():
    """Landing route: return a fixed message saying the API is up."""
    status_message = "TrialGPT API is Running with FAISS-based Retrieval!"
    return {"message": status_message}