# (scrape artifact) Hugging Face Space status banner — "Spaces: Sleeping" —
# captured by the extractor; not part of the application source.
# --- Standard library ---
import os
import re

# --- Third-party ---
import faiss
import numpy as np
import pandas as pd
import torch
from fastapi import FastAPI
from pydantic import BaseModel
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoModel, AutoTokenizer

# Hugging Face cache configuration: keep model downloads inside the container
# and cap hub download time so a stalled fetch fails fast.
os.environ["HF_HOME"] = "/app/huggingface"
os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60"

app = FastAPI()
# --- Load Clinical Trials CSV ---
# The CSV ships alongside the app; fail loudly at startup if it is absent,
# since every endpoint depends on it.
csv_path = "ctg-studies-obesity.csv"
if not os.path.exists(csv_path):
    raise FileNotFoundError(f"β CSV File Not Found: {csv_path}. Upload it first.")
df_trials = pd.read_csv(csv_path)
print("β CSV File Loaded Successfully!")
# --- Normalise CSV column names to the field names used by the API ---
# Identity entries ("Completion Date", "Sponsor") keep the full mapping
# explicit for every field the endpoints read.
column_map = {
    "NCT Number": "NCTID",
    "Interventions": "Intervention",
    "Phases": "Phase",
    "Study Status": "Status",
    "Completion Date": "Completion Date",
    "Study Results": "Has Results",
    "Sponsor": "Sponsor",
}
df_trials = df_trials.rename(columns=column_map)
# --- Load FAISS Index ---
# 768 matches the encoder's CLS-vector width; the index must agree with it.
dimension = 768
faiss_index_path = "clinical_trials.index"
if not os.path.exists(faiss_index_path):
    # No prebuilt index: start empty so the service still boots
    # (searches will simply return no hits).
    index = faiss.IndexFlatL2(dimension)
    print("β FAISS Index Not Found! Using Empty Index.")
else:
    index = faiss.read_index(faiss_index_path)
    print("β FAISS Index Loaded!")
# --- Load Retrieval Model ---
# Fine-tuned GatorTron checkpoint used to embed queries for FAISS lookup.
# Loaded once at import time (downloads into HF_HOME on first run).
retrieval_model_name = "priyanandanwar/fine-tuned-gatortron"
retrieval_tokenizer = AutoTokenizer.from_pretrained(retrieval_model_name)
retrieval_model = AutoModel.from_pretrained(retrieval_model_name)
# --- Request Models ---
class QueryRequest(BaseModel):
    """Request body for trial retrieval: free-text query to embed and match."""

    # Raw query text, embedded by generate_embedding() for the FAISS search.
    text: str
class StudyText(BaseModel):
    """Request body for timeline extraction: free-text study description."""

    # Text scanned for Screening/Treatment/Follow-Up durations in weeks.
    text: str
# --- Generate Embedding for Query ---
def generate_embedding(text):
    """Embed *text* with the GatorTron encoder and return its [CLS] vector.

    Returns a numpy array of shape (1, hidden_size) — the first token's
    last-layer hidden state, used as a sentence-level representation for
    the FAISS search. Input is truncated/padded to 512 tokens.
    """
    encoded = retrieval_tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=512,
    )
    # Inference only: skip autograd graph construction.
    with torch.no_grad():
        hidden = retrieval_model(**encoded).last_hidden_state
    return hidden[:, 0, :].numpy()
# --- Retrieve Clinical Trial Info ---
def get_trial_info(nct_id):
    """Return the first CSV row whose NCTID equals *nct_id* as a dict.

    Missing values are replaced with the string "N/A"; returns None when
    no row matches.
    """
    matches = df_trials[df_trials["NCTID"] == nct_id]
    if matches.empty:
        return None
    return matches.fillna("N/A").to_dict(orient="records")[0]
# β --- Summary Extraction Function ---
def extract_summary(text, max_sentences=2):
    """Extractive summary: return the *max_sentences* most central sentences.

    Sentences are scored by the row-sums of their pairwise TF-IDF cosine
    similarities (a TextRank-style centrality proxy) and re-emitted in
    original document order.

    Returns the full text when it has no more than *max_sentences*
    sentences, and "No summary available." for empty/non-string input.
    """
    if not isinstance(text, str) or not text.strip():
        return "No summary available."
    sentences = re.split(r'(?<=[.!?])\s+', text)
    if len(sentences) <= max_sentences:
        return text
    try:
        vectorizer = TfidfVectorizer(stop_words="english")
        tfidf_matrix = vectorizer.fit_transform(sentences)
    except ValueError:
        # fit_transform raises "empty vocabulary" when every token is an
        # English stop word (or non-alphabetic); fall back to the leading
        # sentences instead of crashing the endpoint.
        return " ".join(sentences[:max_sentences])
    similarity_matrix = cosine_similarity(tfidf_matrix, tfidf_matrix)
    sentence_scores = similarity_matrix.sum(axis=1)
    # Rank positions, not sentence strings: the original re-located ranked
    # sentences via sentences.index(), which always resolves a duplicate
    # sentence to its first occurrence and could mis-order the summary.
    top_idx = sorted(
        range(len(sentences)), key=lambda i: sentence_scores[i], reverse=True
    )[:max_sentences]
    return " ".join(sentences[i] for i in sorted(top_idx))
# β --- Timeline Extraction ---
def extract_study_timeline(text: str):
    """Scan free text for study-phase durations expressed in weeks.

    Returns {"Screening": n, "Treatment": n, "Follow-Up": n} with each value
    an int number of weeks, or None when no pattern for that phase matches.
    For every phase a forward form ("Screening of 4 weeks") is tried before
    a reversed form ("4 weeks of screening"); matching is case-insensitive.
    """
    # Keyword alternation per phase; combined below into the same two
    # pattern shapes for every phase.
    phase_keywords = {
        "Screening": r'Screening|Pre-study observation|Initial Check',
        "Treatment": r'Treatment|Intervention|Therapy|Dosing phase|Main study(?:\s*period)?',
        "Follow-Up": r'Follow[-\s]*up|Recovery|Post-study monitoring|Observation phase|After-treatment',
    }
    timeline = {}
    for phase, keywords in phase_keywords.items():
        candidates = [
            # keyword ... <n> weeks
            rf'(?:{keywords})[^.\n]*?(?:of|:|-|is|lasts)?\s*(\d+)\s*weeks?',
            # <n> weeks ... keyword
            rf'(\d+)\s*weeks[^.\n]*?(?:{keywords})',
        ]
        weeks = None
        for pattern in candidates:
            hit = re.search(pattern, text, re.IGNORECASE)
            if hit:
                weeks = int(hit.group(1))
                break
        timeline[phase] = weeks
    return timeline
# NOTE(review): the route decorator was missing in the scraped source, leaving
# the handler unregistered; path reconstructed — confirm against the client.
@app.post("/extract_timeline")
async def extract_timeline(request: StudyText):
    """POST endpoint: extract Screening/Treatment/Follow-Up week counts from free text."""
    return extract_study_timeline(request.text)
# β --- Retrieval Endpoint with Summary ---
# NOTE(review): the route decorator was missing in the scraped source; path
# reconstructed — confirm against the client.
@app.post("/retrieve")
async def retrieve_trial(request: QueryRequest):
    """Embed the query, rank every indexed trial by L2 distance, and return
    the matched trials with key fields plus a mined summary for each."""
    query_vector = generate_embedding(request.text)
    total_trials = index.ntotal
    if total_trials == 0:
        # Empty index: a k=0 search returns nothing useful; answer cleanly.
        return {"matched_trials": []}
    distances, indices = index.search(query_vector, total_trials)
    matched_trials = []
    for idx in indices[0]:
        # FAISS pads missing neighbours with -1. The original `idx < len(...)`
        # check let -1 through, so iloc[-1] silently served the LAST CSV row;
        # require 0 <= idx as well.
        if not 0 <= idx < len(df_trials):
            continue
        nct_id = df_trials.iloc[idx]["NCTID"]
        trial_data = get_trial_info(nct_id)
        if not trial_data:
            continue
        # Prefer Brief Summary, then Description / Detailed Description.
        text_for_summary = (
            trial_data.get("Brief Summary")
            or trial_data.get("Description")
            or trial_data.get("Detailed Description")
            or ""
        )
        matched_trials.append({
            "NCTID": trial_data["NCTID"],
            "Intervention": trial_data.get("Intervention", "N/A"),
            "Phase": trial_data.get("Phase", "N/A"),
            "Status": trial_data.get("Status", "N/A"),
            "Completion Date": trial_data.get("Completion Date", "N/A"),
            "Has Results": trial_data.get("Has Results", "N/A"),
            "Sponsor": trial_data.get("Sponsor", "N/A"),
            "Summary": extract_summary(text_for_summary),
        })
    return {"matched_trials": matched_trials}
# NOTE(review): the route decorator was missing in the scraped source; path
# reconstructed — confirm against the client.
@app.get("/trial/{nct_id}")
async def get_trial_details(nct_id: str):
    """GET endpoint: full CSV record for one trial, or an error payload."""
    trial_data = get_trial_info(nct_id)
    return {"trial_details": trial_data} if trial_data else {"error": "Trial not found"}
# NOTE(review): the route decorator was missing in the scraped source —
# without it the app exposes no landing route at all.
@app.get("/")
async def root():
    """Health/landing endpoint."""
    return {"message": "TrialGPT API is Running with FAISS-based Retrieval, Timeline Extraction, and Summary Mining! π―"}