import os
import re

import faiss
import numpy as np
import pandas as pd
import torch
from fastapi import FastAPI
from pydantic import BaseModel, Field
from transformers import AutoModel, AutoTokenizer


# Hugging Face cache location and hub download timeout for the model loaded below.
# NOTE(review): huggingface_hub reads HF_HOME and HF_HUB_DOWNLOAD_TIMEOUT once, when
# it is first imported, and the `transformers` import at the top of this file pulls
# it in. These assignments therefore most likely do not take effect for this
# process. Move them above the imports, or set them in the container environment.
os.environ["HF_HOME"] = "/app/huggingface"
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60"

# ASGI application; every route decorator below registers on this instance.
app = FastAPI()


# Trial metadata served by the endpoints; the service refuses to start without it.
csv_path = "ctg-studies-obesity.csv"
if not os.path.exists(csv_path):
    raise FileNotFoundError(f"❌ CSV File Not Found: {csv_path}. Upload it first.")
df_trials = pd.read_csv(csv_path)
print("✅ CSV File Loaded Successfully!")


# Map the CSV's export headers onto the short names the endpoints read.
# "Completion Date" and "Sponsor" already carry the expected names, so they need
# no entry. Labels absent from the CSV are silently ignored by DataFrame.rename.
df_trials = df_trials.rename(columns={
    "NCT Number": "NCTID",
    "Interventions": "Intervention",
    "Phases": "Phase",
    "Study Status": "Status",
    "Study Results": "Has Results",
})


# Vector index over trial embeddings. /retrieve treats result ids as row positions
# in df_trials, so the index was presumably built from the same CSV in the same
# row order. TODO confirm.
dimension = 768  # NOTE(review): must equal the retrieval model's hidden size; verify.
faiss_index_path = "clinical_trials.index"

if os.path.exists(faiss_index_path):
    index = faiss.read_index(faiss_index_path)
    print("✅ FAISS Index Loaded!")
else:
    # An empty index keeps the API running, but it holds no vectors to search.
    index = faiss.IndexFlatL2(dimension)
    print("❌ FAISS Index Not Found! Using Empty Index.")


# Query encoder for /retrieve; its embeddings are searched against the FAISS index.
retrieval_model_name = "priyanandanwar/fine-tuned-gatortron"

# huggingface_hub resolves its cache directory when first imported. That happens
# during the `transformers` import at the top of the file, before HF_HOME is
# assigned. Passing cache_dir explicitly makes the configured location take
# effect; None keeps the library default.
_hf_cache_dir = os.environ.get("HF_HUB_CACHE") or None
if _hf_cache_dir is None and os.environ.get("HF_HOME"):
    _hf_cache_dir = os.path.join(os.path.expanduser(os.environ["HF_HOME"]), "hub")

retrieval_tokenizer = AutoTokenizer.from_pretrained(retrieval_model_name, cache_dir=_hf_cache_dir)
retrieval_model = AutoModel.from_pretrained(retrieval_model_name, cache_dir=_hf_cache_dir)


class QueryRequest(BaseModel):
    """Request body for POST /retrieve."""

    # Free text that is embedded and searched against the FAISS index.
    text: str
    # Number of nearest neighbours requested from FAISS. FAISS search requires
    # k >= 1, so reject smaller values here (HTTP 422) instead of failing
    # inside the search (HTTP 500).
    top_k: int = Field(5, ge=1)


def generate_embedding(text):
    """Embed *text* with the retrieval encoder for FAISS search.

    The input is tokenised into a fixed 512-token window (longer input is
    truncated, shorter input is padded). The model runs without gradient
    tracking, and the final hidden state at token position 0 is used as the
    embedding.

    Args:
        text: A string (or list of strings) accepted by the tokenizer.

    Returns:
        A C-contiguous float32 ``np.ndarray`` of shape ``(batch, hidden_size)``;
        ``(1, hidden_size)`` for a single string.
    """
    encoded = retrieval_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    with torch.no_grad():
        hidden_states = retrieval_model(**encoded).last_hidden_state
    # FAISS's Python API expects C-contiguous float32 input. Slicing token 0 out
    # of (batch, seq, hidden) gives a strided view, so normalise it here.
    first_token = hidden_states[:, 0, :].cpu().numpy()
    return np.ascontiguousarray(first_token, dtype=np.float32)


def get_trial_info(nct_id, trials=None):
    """Return the first trial row whose NCTID equals *nct_id*, or None.

    Missing values in the returned row are replaced with the string "N/A" so
    they serialise as text rather than NaN.

    Args:
        nct_id: Trial identifier, compared exactly against the "NCTID" column.
        trials: DataFrame to search; defaults to the module-level ``df_trials``.

    Returns:
        A ``{column: value}`` dict for the first matching row, or None if no
        row matches.
    """
    if trials is None:
        trials = df_trials
    # Keep only the first match before converting, so duplicates cost nothing.
    match = trials.loc[trials["NCTID"] == nct_id].head(1)
    if match.empty:
        return None
    return match.fillna("N/A").to_dict(orient="records")[0]


@app.post("/retrieve")
async def retrieve_trial(request: QueryRequest):
    """Retrieve Clinical Trial based on text (Shows Limited Info)

    Embeds ``request.text``, asks FAISS for the ``request.top_k`` nearest
    vectors, maps each hit to a row of ``df_trials`` by position, and returns
    a reduced set of fields for each trial found.
    """
    # An empty index cannot match anything; skip the model call entirely.
    if index.ntotal == 0:
        return {"matched_trials": []}

    # NOTE(review): embedding and search are synchronous and run on the event
    # loop, so requests are processed one at a time. Moving them to a thread
    # pool needs care: HF fast tokenizers reportedly fail ("Already borrowed")
    # when called concurrently with padding/truncation options.
    query_vector = generate_embedding(request.text)
    _, indices = index.search(query_vector, request.top_k)

    summary_fields = ("Intervention", "Phase", "Status", "Completion Date", "Has Results", "Sponsor")
    row_count = len(df_trials)
    matched_trials = []
    for idx in indices[0]:
        # FAISS pads missing neighbours with -1. Without the lower bound,
        # iloc[-1] would silently return the last trial in the CSV.
        if not 0 <= idx < row_count:
            continue
        trial_data = get_trial_info(df_trials.iloc[idx]["NCTID"])
        if trial_data:
            matched_trials.append({
                "NCTID": trial_data["NCTID"],
                **{field: trial_data.get(field, "N/A") for field in summary_fields},
            })

    return {"matched_trials": matched_trials}


class StudyText(BaseModel):
    """Request body for POST /extract-timeline/: the study timeline text to parse."""

    text: str


# Shared tail for the screening/treatment patterns: an optional "period"/"phase"
# qualifier, at most one stray character, an optional connector, then
# "<N> week(s)". \s* (rather than exactly one \s) also accepts "Screening:  4 weeks".
_TIMELINE_TAIL = r"(?:\s+(?:period|phase))?[^.\n]?(?:of|:|-|is|lasts)?\s*(\d+)\s*weeks?"

_SCREENING_RE = re.compile(
    r"(?:Screening|Pre-study observation|Initial Check)" + _TIMELINE_TAIL,
    re.IGNORECASE,
)
# The look-behind stops "treatment" matching inside compounds such as
# "After-treatment" (a follow-up keyword below) or "Pre-treatment".
_TREATMENT_RE = re.compile(
    r"(?<![\w-])(?:Treatment|Intervention|Therapy|Dosing phase|Main study(?:\s*period)?)"
    + _TIMELINE_TAIL,
    re.IGNORECASE,
)
_FOLLOW_UP_RE = re.compile(
    r'(?:Follow[-\s]up|Recovery|Post-study monitoring|Observation phase|After-treatment)'
    r'[^.\n]?(?:of|:|-|is|lasts)?[^.\n]*?(\d+)\s*weeks?',
    re.IGNORECASE,
)


def extract_study_timeline(text: str) -> dict:
    """Extract Screening, Treatment and Follow-up durations (in weeks) from text.

    Each phase is located by a case-insensitive keyword regex, and the first
    match wins. Recognises phrasings such as "Screening: 4 weeks", "Treatment
    period of 12 weeks" or "Follow-up lasts 8 weeks".

    Args:
        text: Free-text description of a study timeline.

    Returns:
        A dict with keys "Screening", "Treatment" and "Follow-Up". Each maps to
        the number of weeks as an int, or None when that phase was not found.
    """

    def _weeks(pattern):
        found = pattern.search(text)
        return int(found.group(1)) if found else None

    return {
        "Screening": _weeks(_SCREENING_RE),
        "Treatment": _weeks(_TREATMENT_RE),
        "Follow-Up": _weeks(_FOLLOW_UP_RE),
    }


@app.post("/extract-timeline/")
async def extract_timeline(request: StudyText):
    """Parse Screening / Treatment / Follow-up durations (in weeks) from free text.

    Phases whose duration could not be found are returned as null.
    """
    return extract_study_timeline(request.text)


@app.get("/trial/{nct_id}")
async def get_trial_details(nct_id: str):
    """Fetch Full Details of a Clinical Trial

    Returns every CSV column for the first row whose NCTID matches, with
    missing values shown as "N/A".
    """
    trial_data = get_trial_info(nct_id)
    if trial_data is None:
        # NOTE(review): this answers 200 with an error body instead of 404; the
        # behaviour is kept for compatibility with existing clients.
        return {"error": "Trial not found"}
    return {"trial_details": trial_data}


@app.get("/")
async def root():
    """Return a static status message; usable as a basic health check."""
    return {"message": "TrialGPT API is Running with FAISS-based Retrieval!"}