# bipolar/src/Rag.py — RAG pipeline for the depression assistant.
# Provenance (from the original file-viewer header): commit 6dffeff,
# "use openai oss", author ymali.
import os
import json
import time
import requests
import numpy as np
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
from together import Together
from openai import OpenAI
# NOTE(review): `global` at module scope is a no-op in Python; these names
# are actually bound at module level by launch_depression_assistant's own
# `global` statement. Kept for readability as a declaration of shared state.
global db, referenced_tables_db, embedder, index, llm_client
def load_json_to_db(file_path):
    """Load a JSON file and return its parsed content.

    Args:
        file_path: Path to a JSON file (expected to hold a list of chunk dicts).

    Returns:
        The deserialized JSON object.
    """
    # Explicit UTF-8 avoids platform-dependent default encodings (e.g. cp1252
    # on Windows) breaking on non-ASCII guideline text.
    with open(file_path, encoding="utf-8") as f:
        return json.load(f)
# -------- Embedding Functions --------
def make_embeddings(embedder, embedder_name, db):
    """Encode every chunk's text in *db* and return the embedding matrix."""
    # Gather the raw chunk texts in database order so rows align with db.
    corpus = [entry['text'] for entry in db]
    # NOTE: batch_size=1 keeps peak memory low on CPU-only hosts.
    return embedder.encode(
        corpus,
        convert_to_numpy=True,
        batch_size=1,
        show_progress_bar=True,
    )
def get_project_root():
    """Return the absolute path of the repository root (parent of src/)."""
    here = os.path.dirname(__file__)
    return os.path.abspath(os.path.join(here, ".."))
def save_embeddings(embedder_name, embeddings):
    """Persist *embeddings* under data/embeddings/, keyed by model name.

    Slashes in the model name are replaced with underscores so the name
    is a valid filename.
    """
    safe_name = embedder_name.replace('/', '_')
    target = os.path.join(get_project_root(), "data", "embeddings", f"{safe_name}.npy")
    # Create the destination directory on first use.
    os.makedirs(os.path.dirname(target), exist_ok=True)
    np.save(target, embeddings)
    print(f"Saved embeddings to: {target}")
def load_embeddings(embedder_name):
    """Load cached embeddings for *embedder_name*; recompute on cache miss.

    NOTE(review): the recompute path reads the module-level globals
    ``embedder`` and ``db`` (set by launch_depression_assistant) — calling
    this before launch on a cache miss raises NameError. Confirm intended.
    """
    root = get_project_root()
    # Cache filename mirrors save_embeddings: slashes -> underscores.
    file_path = os.path.join(root, "data", "embeddings", f"{embedder_name.replace('/', '_')}.npy")
    try:
        embeddings = np.load(file_path, allow_pickle=True)
        print(f"Loaded embeddings from: {file_path}")
    except FileNotFoundError:
        # Cache miss: embed the whole corpus and persist it for next run.
        print(f"Embeddings not found. Recomputing for: {embedder_name}")
        embeddings = make_embeddings(embedder, embedder_name, db)
        save_embeddings(embedder_name, embeddings)
    return embeddings
def load_embedder_with_fallbacks(embedder_name):
    """Instantiate a SentenceTransformer on CPU with left-side padding."""
    print(f"Loading embedder {embedder_name}")
    # trust_remote_code is required by some embedder repos; CPU keeps the
    # app deployable on hosts without a GPU.
    return SentenceTransformer(
        embedder_name,
        trust_remote_code=True,
        tokenizer_kwargs={"padding_side": "left"},
        device='cpu',
    )
# -------- Cosine Similarity Index (no FAISS) --------
def build_cosine_index(embeddings):
    """L2-normalize each row so dot products equal cosine similarity.

    Args:
        embeddings: 2-D array, one embedding per row.

    Returns:
        Row-normalized copy of *embeddings*; all-zero rows are left as
        zero vectors instead of producing NaN via division by zero.
    """
    norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
    # Guard 0/0 -> NaN for degenerate zero rows.
    safe_norms = np.where(norms == 0, 1.0, norms)
    return embeddings / safe_norms
def load_cosine_index(embedder_name):
    """Return the row-normalized embedding matrix for brute-force search."""
    return build_cosine_index(load_embeddings(embedder_name))
# -------- Cosine Similarity Search (Brute Force) --------
def vector_search(query, embedder, db, index, referenced_table_db, k=6):
    """Brute-force cosine search over the chunk database.

    Args:
        query: User query string.
        embedder: Model exposing an ``encode`` method (SentenceTransformer-like).
        db: List of chunk dicts with 'text' and 'metadata'.
        index: Row-normalized embedding matrix aligned with *db*.
        referenced_table_db: Chunks holding tables referenced by text chunks.
        k: Number of top text chunks to return.

    Returns:
        Top-k chunk dicts (with 'similarity') plus any referenced table
        chunks that were not already among the retrieved chunks.
    """
    def get_detailed_instruct(task_description: str, query: str) -> str:
        return f'Instruct: {task_description}\nQuery:{query}'

    task = 'Given a search query, retrieve relevant passages that answer the query'
    query_embedding = embedder.encode([get_detailed_instruct(task, query)], convert_to_numpy=True)
    query_vec = query_embedding / np.linalg.norm(query_embedding)
    cosine_similarities = np.dot(index, query_vec.T).flatten()
    top_k_indices = np.argsort(-cosine_similarities)[:k]
    results = []
    referenced_tables = set()
    existed_tables = set()
    for i in top_k_indices:
        meta = db[i]['metadata']
        results.append({
            "text": db[i]['text'],
            "section": meta['section'],
            "chunk_id": meta['chunk_id'],
            "similarity": float(cosine_similarities[i]),
        })
        # Track which tables this chunk IS and which tables it REFERENCES,
        # so referenced tables that weren't retrieved directly can be added.
        # Both keys are optional metadata: .get() avoids the KeyError the
        # old code raised when 'referee_id' was absent.
        if meta.get('referee_id'):
            existed_tables.add(meta['referee_id'])
        referenced_tables.update(meta.get('referenced_tables') or [])
    # Tables referenced by results but not already present among them.
    tables_to_add = referenced_tables - existed_tables
    for chunk in referenced_table_db:
        if chunk['metadata'].get('referee_id') in tables_to_add:
            results.append({
                "text": chunk['text'],
                "section": chunk['metadata']['section'],
                "chunk_id": chunk['metadata']['chunk_id'],
            })
    return results
def load_together_llm_client():
    """Build a Together client using TOGETHER_API_KEY from the environment."""
    load_dotenv()
    api_key = os.getenv("TOGETHER_API_KEY")
    return Together(api_key=api_key)
def load_nvidia_llm_client():
    """Build an OpenAI-compatible client pointed at NVIDIA's API endpoint."""
    load_dotenv()
    client = OpenAI(
        base_url="https://integrate.api.nvidia.com/v1",
        api_key=os.getenv("NVIDIA_API_KEY"),
    )
    return client
# -------- Prompt Construction --------
def construct_prompt(query, faiss_results):
    """Assemble the LLM prompt: system prompt, user query, retrieved context."""
    # The system prompt path is relative to the working directory.
    with open("src/system_prompt.txt", "r") as f:
        system_prompt = f.read().strip()
    parts = [
        f"\n### System Prompt\n{system_prompt}\n### User Query\n{query}\n### Clinical Guidelines Context\n"
    ]
    for res in faiss_results:
        parts.append(f"- reference: {res['section']}\n- This paragraph is from section: {res['text']}\n")
    return "".join(parts)
def construct_prompt_with_memory(query, faiss_results, chat_history=None, history_limit=4):
    """Like construct_prompt, but prepends up to *history_limit* chat turns."""
    with open("src/system_prompt.txt", "r") as f:
        system_prompt = f.read().strip()
    sections = [f"### System Prompt\n{system_prompt}\n\n"]
    if chat_history:
        # Only the most recent turns are kept, to bound prompt length.
        recent = chat_history[-history_limit:]
        turns = "".join(f"{m['role'].title()}: {m['content']}\n" for m in recent)
        sections.append(f"### Chat History\n{turns}\n")
    sections.append(f"### User Query\n{query}\n\n")
    sections.append("### Clinical Guidelines Context\n")
    for res in faiss_results:
        sections.append(f"- reference: {res['section']}\n- This paragraph is from section: {res['text']}\n")
    return "".join(sections)
def call_llm(llm_client, prompt, stream_flag=False, max_tokens=500, temperature=0.05, top_p=0.9, model_name="meta-llama/Llama-3.3-70B-Instruct-Turbo-Free"):
    """Send *prompt* to a Together-style chat client.

    Returns a generator of text deltas when stream_flag is True, otherwise
    the full completion string. Any client error is printed with a
    traceback and re-raised.
    """
    print(f"Calling LLM with model: {model_name}")
    request_kwargs = dict(
        model=model_name,
        messages=[{"role": "user", "content": prompt}],
        max_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
    )
    try:
        if stream_flag:
            def stream_generator():
                response = llm_client.chat.completions.create(stream=True, **request_kwargs)
                for chunk in response:
                    # Skip keep-alive chunks with no choices or empty deltas.
                    piece = chunk.choices[0].delta.content if chunk.choices else None
                    if piece:
                        yield piece
            return stream_generator()
        response = llm_client.chat.completions.create(stream=False, **request_kwargs)
        return response.choices[0].message.content
    except Exception as e:
        print("Error in call_llm:", str(e))
        import traceback
        traceback.print_exc()
        raise
def call_nvidia_llm(llm_client, prompt, stream_flag=False, max_tokens=4096, temperature=0.6, top_p=0.7, model_name="openai/gpt-oss-20b"):
    """Send *prompt* to an OpenAI-compatible (NVIDIA) chat client.

    Streaming mode yields content deltas; non-streaming returns the full
    message text. Errors are printed with a traceback and re-raised.
    """
    print(f"Calling NVIDIA LLM with model: {model_name}")
    params = {
        "model": model_name,
        "messages": [{"role": "user", "content": prompt}],
        "temperature": temperature,
        "top_p": top_p,
        "max_tokens": max_tokens,
    }
    try:
        if stream_flag:
            def stream_generator():
                for chunk in llm_client.chat.completions.create(stream=True, **params):
                    piece = chunk.choices[0].delta.content
                    if piece is not None:
                        yield piece
            return stream_generator()
        completion = llm_client.chat.completions.create(stream=False, **params)
        return completion.choices[0].message.content
    except Exception as e:
        print("Error in call_nvidia_llm:", str(e))
        import traceback
        traceback.print_exc()
        raise
def call_ollama(prompt, model="mistral", stream_flag=False, max_tokens=500, temperature=0.05, top_p=0.9):
    """Generate a completion from a local Ollama server, yielding text pieces.

    Fixes vs. the previous version:
    - Sampling parameters are sent inside "options": Ollama's /api/generate
      ignores top-level temperature/top_p, and its token limit is named
      "num_predict" (there is no "max_tokens" parameter).
    - "stream" now honors *stream_flag* instead of being hard-coded True.
      The line-iterating reader handles both modes, since a non-streaming
      response arrives as a single JSON line.

    Yields:
        Text fragments from the model ("response" field of each JSON line).
    """
    url = "http://localhost:11434/api/generate"
    payload = {
        "model": model,
        "prompt": prompt,
        "stream": stream_flag,
        "options": {
            "temperature": temperature,
            "top_p": top_p,
            "num_predict": max_tokens,
        },
    }
    with requests.post(url, json=payload, stream=True) as response:
        for line in response.iter_lines():
            if not line:
                continue
            try:
                data = json.loads(line.decode("utf-8"))
            except ValueError:
                # Skip malformed/partial lines, best-effort as before.
                continue
            if "response" in data:
                yield data["response"]
# -------- Main Assistant Entry Points --------
def launch_depression_assistant(embedder_name, designated_client=None):
    """Initialize module-level RAG state: db, tables, embedder, index, client.

    Client selection: a designated client wins; otherwise NVIDIA is tried
    first with Together as fallback. llm_client stays None if both fail.
    """
    global db, referenced_tables_db, embedder, index, llm_client
    db = load_json_to_db("data/processed/guideline_db.json")
    referenced_tables_db = load_json_to_db("data/processed/referenced_table_chunks.json")
    embedder = load_embedder_with_fallbacks(embedder_name)
    index = load_cosine_index(embedder_name)
    if designated_client is not None:
        llm_client = designated_client
        print(f"Using designated client: {type(llm_client).__name__}")
    else:
        print("Attempting to load NVIDIA LLM client...")
        try:
            llm_client = load_nvidia_llm_client()
            print("Successfully loaded NVIDIA LLM client.")
        except Exception as e:
            print(f"Failed to load NVIDIA LLM client: {e}")
            print("Attempting to load Together LLM client as a fallback...")
            try:
                llm_client = load_together_llm_client()
                print("Successfully loaded Together LLM client.")
            except Exception as e:
                print(f"Failed to load Together LLM client: {e}")
                llm_client = None
    if llm_client is None:
        print("Warning: No LLM client could be loaded. The assistant will not be able to generate responses.")
    print("---------Depression Assistant is ready to use!--------------\n\n")
def depression_assistant(query, model_name=None, max_tokens=None, temperature=None, top_p=None, stream_flag=False, chat_history=None):
    """Answer *query* using retrieval plus the configured LLM client.

    Returns:
        (retrieved_results, response) — *response* is a string, or a
        generator of text pieces when stream_flag is True.

    Raises:
        ValueError: if no LLM client was initialized.

    Relies on module globals set by launch_depression_assistant.
    """
    results = vector_search(query, embedder, db, index, referenced_tables_db, k=3)
    prompt = construct_prompt_with_memory(query, results, chat_history=chat_history)
    # Forward only explicitly-provided overrides. `is not None` (rather than
    # truthiness, previously used for max_tokens/top_p/model_name) so falsy
    # values like top_p=0 or max_tokens=0 are not silently dropped.
    kwargs = {}
    if model_name is not None:
        kwargs['model_name'] = model_name
    if max_tokens is not None:
        kwargs['max_tokens'] = max_tokens
    if temperature is not None:
        kwargs['temperature'] = temperature
    if top_p is not None:
        kwargs['top_p'] = top_p
    if llm_client == "Run Ollama Locally":
        # call_ollama names its model parameter 'model', not 'model_name'.
        if 'model_name' in kwargs:
            kwargs['model'] = kwargs.pop('model_name')
        return results, call_ollama(prompt, stream_flag=stream_flag, **kwargs)
    elif isinstance(llm_client, OpenAI):  # NVIDIA Client
        return results, call_nvidia_llm(llm_client, prompt, stream_flag=stream_flag, **kwargs)
    elif isinstance(llm_client, Together):  # Together Client
        return results, call_llm(llm_client, prompt, stream_flag=stream_flag, **kwargs)
    else:
        if llm_client is None:
            raise ValueError("LLM client not initialized. Please check API keys.")
        # Unknown client type: assume an OpenAI-compatible (NVIDIA) interface.
        return results, call_nvidia_llm(llm_client, prompt, stream_flag=stream_flag, **kwargs)
def load_queries_and_answers(query_file, answers_file):
    """Read the evaluation queries and gold answers, one per line each.

    Returns:
        (queries, answers) — two lists of lines (newlines preserved).
    """
    with open(query_file, 'r') as qf, open(answers_file, 'r') as af:
        return qf.readlines(), af.readlines()
def write_batched_results(embedder_name, result_path):
    """Run the assistant over the eval queries and write two markdown reports.

    One report holds the retrieved chunks per query, the other the LLM
    responses. Currently stops after the first query (see the break below).
    """
    launch_depression_assistant(embedder_name)
    queries, answers = load_queries_and_answers("data/raw/queries.txt", "data/raw/answers.txt")
    embedder_filename = embedder_name.replace('/', '_')
    retrieval_file = f"{result_path}Retrieved_Results_by_{embedder_filename}.md"
    response_file = f"{result_path}Response_by_{embedder_filename}.md"
    with open(retrieval_file, "w") as f1, open(response_file, "w") as f2:
        for i, query in enumerate(queries):
            result, response = depression_assistant(query)
            # Shared per-query header for both reports.
            header = f"## Query {i+1}\n{query.strip()}\n\n## Answer\n{answers[i].strip()}\n\n"
            f1.write(header + "## Retrieved Results\n")
            for res in result:
                f1.write(f"\n\n#### {res['section']}\n\n{res['text']}\n")
            f1.write("\n\n---\n\n")
            f2.write(header + f"## Response\n{response}\n\n---\n\n")
            break  # remove this `break` if you want to process all queries