tannu038 commited on
Commit
5c0b8b6
·
verified ·
1 Parent(s): f5ada3d

Upload main.py

Browse files
Files changed (1) hide show
  1. main.py +149 -0
main.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import faiss
4
+ import torch
5
+ import numpy as np
6
+ import pandas as pd
7
+ from fastapi import FastAPI
8
+ from pydantic import BaseModel
9
+ from transformers import AutoModel, AutoTokenizer
10
+
11
+ # Hugging Face Cache Directory
12
+ os.environ["HF_HOME"] = "/app/huggingface"
13
+ os.environ["HF_HUB_DOWNLOAD_TIMEOUT"] = "60"
14
+
15
+ app = FastAPI()
16
+
17
+ # --- Load Clinical Trials CSV ---
18
+ csv_path = "ctg-studies-obesity.csv"
19
+ if os.path.exists(csv_path):
20
+ df_trials = pd.read_csv(csv_path)
21
+ print("✅ CSV File Loaded Successfully!")
22
+ else:
23
+ raise FileNotFoundError(f"❌ CSV File Not Found: {csv_path}. Upload it first.")
24
+
25
+ # --- Rename columns to match the required fields ---
26
+ df_trials.rename(columns={
27
+ "NCT Number": "NCTID",
28
+ "Interventions": "Intervention",
29
+ "Phases": "Phase",
30
+ "Study Status": "Status",
31
+ "Completion Date": "Completion Date",
32
+ "Study Results": "Has Results",
33
+ "Sponsor": "Sponsor"
34
+ }, inplace=True)
35
+
36
+ # --- Load FAISS Index ---
37
+ dimension = 768
38
+ faiss_index_path = "clinical_trials.index"
39
+
40
+ if os.path.exists(faiss_index_path):
41
+ index = faiss.read_index(faiss_index_path)
42
+ print("✅ FAISS Index Loaded!")
43
+ else:
44
+ index = faiss.IndexFlatL2(dimension)
45
+ print("⚠ FAISS Index Not Found! Using Empty Index.")
46
+
47
+ # --- Load Retrieval Model ---
48
+ retrieval_model_name = "priyanandanwar/fine-tuned-gatortron"
49
+ retrieval_tokenizer = AutoTokenizer.from_pretrained(retrieval_model_name)
50
+ retrieval_model = AutoModel.from_pretrained(retrieval_model_name)
51
+
52
+ # --- Request Model ---
53
+ class QueryRequest(BaseModel):
54
+ text: str
55
+ top_k: int = 5
56
+
57
+ # --- Generate Embedding for Query ---
58
+ def generate_embedding(text):
59
+ inputs = retrieval_tokenizer(text, return_tensors="pt", truncation=True, padding="max_length", max_length=512)
60
+ with torch.no_grad():
61
+ outputs = retrieval_model(**inputs)
62
+ return outputs.last_hidden_state[:, 0, :].numpy() # CLS Token Embedding
63
+
64
+ # --- Retrieve Clinical Trial Info ---
65
+ def get_trial_info(nct_id):
66
+ trial_info = df_trials[df_trials["NCTID"] == nct_id].fillna("N/A").to_dict(orient="records")
67
+ return trial_info[0] if trial_info else None
68
+
69
+ # --- Retrieval Endpoint (Returns Only Limited Data) ---
70
+ @app.post("/retrieve")
71
+ async def retrieve_trial(request: QueryRequest):
72
+ """Retrieve Clinical Trial based on text (Shows Limited Info)"""
73
+ query_vector = generate_embedding(request.text)
74
+ distances, indices = index.search(query_vector, request.top_k)
75
+
76
+ matched_trials = []
77
+ for idx in indices[0]:
78
+ if idx < len(df_trials):
79
+ nct_id = df_trials.iloc[idx]["NCTID"]
80
+ trial_data = get_trial_info(nct_id)
81
+
82
+ if trial_data:
83
+ # Extract only required fields
84
+ filtered_trial_data = {
85
+ "NCTID": trial_data["NCTID"],
86
+ "Intervention": trial_data.get("Intervention", "N/A"),
87
+ "Phase": trial_data.get("Phase", "N/A"),
88
+ "Status": trial_data.get("Status", "N/A"),
89
+ "Completion Date": trial_data.get("Completion Date", "N/A"),
90
+ "Has Results": trial_data.get("Has Results", "N/A"),
91
+ "Sponsor": trial_data.get("Sponsor", "N/A"),
92
+ }
93
+ matched_trials.append(filtered_trial_data)
94
+
95
+ return {"matched_trials": matched_trials}
96
+
97
+ class StudyText(BaseModel):
98
+ text: str
99
+
100
+ def extract_study_timeline(text: str):
101
+ """
102
+ Extracts Screening, Treatment, and Follow-up durations from a study timeline description.
103
+ Handles both structured and unstructured formats.
104
+ """
105
+
106
+ # Screening phase patterns
107
+ screening = re.search(
108
+ r'(?:Screening|Pre-study observation|Initial Check)[^.\n]?(?:of|:|-|is|lasts)?\s(\d+)\s*weeks?',
109
+ text, re.IGNORECASE
110
+ )
111
+
112
+ # Treatment phase patterns
113
+ treatment = re.search(
114
+ r'(?:Treatment|Intervention|Therapy|Dosing phase|Main study(?:\s*period)?)[^.\n]?(?:of|:|-|is|lasts)?\s(\d+)\s*weeks?',
115
+ text, re.IGNORECASE
116
+ )
117
+
118
+ # Follow-up phase patterns
119
+ follow_up = re.search(
120
+ r'(?:Follow[-\s]up|Recovery|Post-study monitoring|Observation phase|After-treatment)[^.\n]?(?:of|:|-|is|lasts)?[^.\n]*?(\d+)\s*weeks?',
121
+ text, re.IGNORECASE
122
+ )
123
+
124
+ # Final timeline dictionary
125
+ timeline = {
126
+ "Screening": int(screening.group(1)) if screening else None,
127
+ "Treatment": int(treatment.group(1)) if treatment else None,
128
+ "Follow-Up": int(follow_up.group(1)) if follow_up else None
129
+ }
130
+
131
+ return timeline
132
+
133
+ @app.post("/extract-timeline/")
134
+ async def extract_timeline(request: StudyText):
135
+ return extract_study_timeline(request.text)
136
+
137
+
138
+
139
+ # --- Fetch Full Trial Details When Clicked ---
140
+ @app.get("/trial/{nct_id}")
141
+ async def get_trial_details(nct_id: str):
142
+ """Fetch Full Details of a Clinical Trial"""
143
+ trial_data = get_trial_info(nct_id)
144
+ return {"trial_details": trial_data} if trial_data else {"error": "Trial not found"}
145
+
146
+ # --- Root Endpoint ---
147
+ @app.get("/")
148
+ async def root():
149
+ return {"message": "TrialGPT API is Running with FAISS-based Retrieval!"}