| from fastapi import FastAPI |
| from pydantic import BaseModel |
| import pandas as pd |
| from sentence_transformers import SentenceTransformer |
| import chromadb |
| from fastapi.middleware.cors import CORSMiddleware |
| import uvicorn |
| import requests |
| |
# Browser origins that are allowed to call this API (consumed by the CORS
# middleware registered below).
# NOTE(review): the second entry has no URL scheme; browsers send the Origin
# header with a scheme, so it probably never matches - confirm before removing.
origins = [
    "http://localhost:5173",
    "localhost:5173",
]

# The ASGI application object; every route in this module is registered on it.
app = FastAPI()

# Enable CORS for the origins above, with credentials and with any HTTP
# method or request header permitted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
|
|
| |
def _split_symptoms(raw):
    """Return the trimmed, non-empty symptoms from one comma-separated cell.

    A cell that is not a string (for example a missing value, which pandas
    represents as NaN) yields an empty list instead of raising, and empty
    fragments produced by stray commas are dropped.
    """
    if not isinstance(raw, str):
        return []
    return [part.strip() for part in raw.split(',') if part.strip()]


# Disease table; the endpoints below read its 'Name', 'Symptoms' and
# 'Treatments' columns.
df = pd.read_csv("hf://datasets/QuyenAnhDE/Diseases_Symptoms/Diseases_Symptoms.csv")
# Replace the comma-separated text with a list of symptom strings per disease.
df['Symptoms'] = df['Symptoms'].apply(_split_symptoms)

# Sentence-embedding model used to turn query text into vectors.
model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')

# On-disk Chroma store; the collection is only opened (or created empty) here.
# NOTE(review): nothing in the visible code adds documents to the collection -
# presumably it is populated by a separate script; confirm.
client = chromadb.PersistentClient(path='./chromadb')
collection = client.get_or_create_collection(name="symptomsvector")
|
|
# Request body shared by the symptom-search endpoints.
#   symptom: free-text symptom description. /find_matching_symptoms splits it
#            on commas; the other endpoints embed the whole string as one query.
class SymptomQuery(BaseModel):
    symptom: str
|
|
| |
@app.post("/find_matching_symptoms")
def find_matching_symptoms(query: SymptomQuery):
    """Return stored documents closest to each comma-separated symptom.

    The query text is split on commas; every non-empty fragment is embedded
    and used to fetch its 3 nearest documents from the ``symptomsvector``
    collection (presumably symptom strings; the collection is populated
    outside the visible code).

    Args:
        query: Request body whose ``symptom`` field holds one or more
            comma-separated symptom descriptions.

    Returns:
        ``{"matching_symptoms": [...]}`` with duplicates removed and the
        first-seen order preserved. The list is empty when the input
        contains no non-empty fragment.
    """
    symptoms = [part.strip() for part in query.symptom.split(',')]
    # Blank fragments (empty input, trailing or doubled commas) would be
    # embedded as "" and pull in unrelated neighbours, so drop them.
    symptoms = [symptom for symptom in symptoms if symptom]
    if not symptoms:
        return {"matching_symptoms": []}

    # Encode and query all fragments in one batch; the result holds one list
    # of documents per query embedding, in the same order as the input.
    query_embeddings = model.encode(symptoms)
    results = collection.query(
        query_embeddings=query_embeddings.tolist(),
        n_results=3
    )
    all_results = [doc for docs in results['documents'] for doc in docs]

    # dict.fromkeys removes duplicates while keeping first-seen order.
    matching_symptoms = list(dict.fromkeys(all_results))
    return {"matching_symptoms": matching_symptoms}
|
|
| |
@app.post("/find_matching_diseases")
def find_matching_diseases(query: SymptomQuery):
    """Return the names of diseases that share a symptom with the query.

    The whole ``symptom`` string is embedded as a single query, the 5 nearest
    documents are fetched from the vector collection, and every disease whose
    symptom list contains at least one of those documents (exact string
    match) is selected.

    Returns:
        ``{"matching_diseases": [...]}`` with the selected rows' ``Name``
        values.
    """
    embedding = model.encode([query.symptom]).tolist()
    hits = collection.query(query_embeddings=embedding, n_results=5)
    nearest_symptoms = hits['documents'][0]

    def shares_symptom(disease_symptoms):
        # True as soon as one of the disease's symptoms is among the hits.
        for candidate in disease_symptoms:
            if candidate in nearest_symptoms:
                return True
        return False

    selected_rows = df[df['Symptoms'].apply(shares_symptom)]
    names = selected_rows['Name'].tolist()
    return {"matching_diseases": names}
|
|
| |
@app.post("/find_disease_list")
def find_disease_list(query: SymptomQuery):
    """Return candidate diseases for a symptom plus all their symptoms.

    The whole ``symptom`` string is embedded as one query, the 5 nearest
    documents are fetched from the vector collection, and every disease whose
    symptom list contains at least one of them (exact string match) becomes a
    candidate.

    Returns:
        A dict with two keys:

        * ``"disease_list"``: one ``{'Disease', 'Symptoms', 'Treatments'}``
          dict per candidate row.
        * ``"unique_symptoms_list"``: every symptom of the candidates,
          lowercased and de-duplicated, in first-seen order.
    """
    query_embedding = model.encode([query.symptom])
    results = collection.query(
        query_embeddings=query_embedding.tolist(),
        n_results=5
    )
    matching_symptoms = results['documents'][0]

    matching_diseases = df[df['Symptoms'].apply(lambda x: any(s in matching_symptoms for s in x))]

    disease_list = [
        {
            'Disease': row['Name'],
            'Symptoms': row['Symptoms'],
            'Treatments': row['Treatments']
        }
        for _, row in matching_diseases.iterrows()
    ]

    # Lowercase BEFORE de-duplicating: comparing the original-cased symptom
    # against already-lowercased entries never matches for capitalised
    # symptoms and lets duplicates through.
    lowered_symptoms = (
        symptom.lower()
        for disease_info in disease_list
        for symptom in disease_info['Symptoms']
    )
    unique_symptoms_list = list(dict.fromkeys(lowered_symptoms))

    return {"disease_list": disease_list, "unique_symptoms_list": unique_symptoms_list}
|
|
# Request body for /find_disease.
#   selected_symptoms: the symptoms chosen by the caller; each entry is
#                      compared against the symptom lists in the disease table.
class SelectedSymptomsQuery(BaseModel):
    selected_symptoms: list
|
|
@app.post("/find_disease")
def find_disease(query: SelectedSymptomsQuery):
    """Rank diseases by how many of the selected symptoms they exhibit.

    Matching ignores letter case and surrounding whitespace, so the
    lowercased symptoms returned by ``/find_disease_list`` still match
    capitalised entries in the disease table.

    Args:
        query: Request body carrying ``selected_symptoms``.

    Returns:
        A dict with two keys:

        * ``"disease_list"``: one ``{'Disease', 'Symptoms', 'Treatments',
          'MatchCount'}`` dict per disease with at least one match, sorted by
          ``MatchCount`` descending.
        * ``"max_match_count_disease"``: the first (best) entry of that list,
          or ``None`` when nothing matched.
    """
    # Normalised lookup set. Non-string entries can never equal a symptom
    # string, so they are skipped rather than failing on .lower().
    selected = {
        symptom.strip().lower()
        for symptom in query.selected_symptoms
        if isinstance(symptom, str)
    }

    # Count, per disease, how many of its symptoms were selected.
    match_counts = df['Symptoms'].apply(
        lambda symptoms: sum(s.lower() in selected for s in symptoms)
    )
    has_match = match_counts > 0

    # assign() builds a new frame, so no column is written onto a filtered
    # slice of df (which triggers pandas' chained-assignment warning).
    # mergesort is stable: diseases with equal counts keep their table order.
    matching_diseases = (
        df[has_match]
        .assign(match_count=match_counts[has_match])
        .sort_values(by='match_count', ascending=False, kind='mergesort')
    )

    disease_list = [
        {
            'Disease': row['Name'],
            'Symptoms': row['Symptoms'],
            'Treatments': row['Treatments'],
            # Plain int keeps the value JSON-serialisable whatever dtype
            # pandas hands back.
            'MatchCount': int(row['match_count'])
        }
        for _, row in matching_diseases.iterrows()
    ]

    # The list is sorted by count descending, so its head is the best match.
    max_match_count_disease = disease_list[0] if disease_list else None

    return {"disease_list": disease_list, "max_match_count_disease": max_match_count_disease}
# Request body wrapping a list of disease entries.
# NOTE(review): no endpoint in the visible code accepts this model - confirm
# whether it is still needed.
class DiseaseListQuery(BaseModel):
    disease_list: list
|
|
# Request body for /pass2llm: a single disease record. The field names are the
# same as the keys of the entries that /find_disease returns.
#   Disease:    disease name
#   Symptoms:   list of the disease's symptoms
#   Treatments: treatment text
#   MatchCount: number of matched symptoms
class DiseaseDetail(BaseModel):
    Disease: str
    Symptoms: list
    Treatments: str
    MatchCount: int
|
|
@app.post("/pass2llm")
def pass2llm(query: DiseaseDetail):
    """Ask the LLM behind the current ngrok tunnel to summarise a disease.

    The function looks up the first endpoint listed by the ngrok API and
    posts a ``llama3`` generation request to ``<public_url>/api/generate``.

    Configuration:
        The ngrok API key is read from the ``NGROK_API_KEY`` environment
        variable.

    Args:
        query: The disease record to summarise.

    Returns:
        On success ``{"message": "Successfully passed to LLM!",
        "llm_response": <text>}``. On any failure a dict with a ``"message"``
        naming the step that failed (ngrok lookup or LLM call) and an
        ``"error"`` string with the details. Failures are reported in the
        response body rather than raised.
    """
    import os

    ngrok_failure = "Failed to get public URL from Ngrok!"
    llm_failure = "Failed to get response from LLM!"

    # Credentials must not live in source control, so the key comes from the
    # environment.
    # NOTE(review): a key was previously hard-coded here; treat it as leaked
    # and rotate it.
    api_key = os.environ.get("NGROK_API_KEY")
    if not api_key:
        return {"message": ngrok_failure, "error": "NGROK_API_KEY environment variable is not set"}

    headers = {
        "Authorization": f"Bearer {api_key}",
        "Ngrok-Version": "2"
    }
    try:
        # Without a timeout, requests waits indefinitely on a stalled server.
        response = requests.get("https://api.ngrok.com/endpoints", headers=headers, timeout=10)
    except requests.RequestException as exc:
        return {"message": ngrok_failure, "error": str(exc)}

    if response.status_code != 200:
        return {"message": ngrok_failure, "error": response.text}

    try:
        public_url = response.json()['endpoints'][0]['public_url']
    except (ValueError, KeyError, IndexError, TypeError):
        # Body was not JSON, or no tunnel endpoint is currently listed.
        return {"message": ngrok_failure, "error": "No active ngrok endpoint found in the API response"}

    prompt = f"Here is a list of diseases and their details: {query}. Please generate a summary."

    llm_headers = {
        "Content-Type": "application/json"
    }
    llm_payload = {
        "model": "llama3",
        "prompt": prompt,
        "stream": False
    }
    try:
        # (connect, read) timeouts: generation is not streamed, so the read
        # side gets a generous allowance.
        llm_response = requests.post(
            f"{public_url}/api/generate",
            headers=llm_headers,
            json=llm_payload,
            timeout=(10, 300),
        )
    except requests.RequestException as exc:
        return {"message": llm_failure, "error": str(exc)}

    if llm_response.status_code != 200:
        return {"message": llm_failure, "error": llm_response.text}

    try:
        llm_response_json = llm_response.json()
    except ValueError as exc:
        return {"message": llm_failure, "error": str(exc)}

    return {"message": "Successfully passed to LLM!", "llm_response": llm_response_json.get("response")}
| |
| |
| |
|
|