qyle commited on
Commit
038efd6
·
verified ·
1 Parent(s): c1300fe

enable cuda and added pyinstrument

Browse files
Files changed (3) hide show
  1. main.py +3 -0
  2. rag.py +34 -0
  3. requirements.txt +1 -0
main.py CHANGED
@@ -1,5 +1,7 @@
1
  import os
2
  import asyncio
 
 
3
  from contextlib import asynccontextmanager
4
 
5
  from typing import AsyncGenerator, List, Optional, Tuple, Dict, Any
@@ -97,6 +99,7 @@ def convert_messages_langchain(messages: List[ChatMessage]):
97
  return list_chatmessages
98
 
99
 
 
100
  champ = ChampService(base_dir=BASE_DIR, hf_token=HF_TOKEN)
101
 
102
 
 
1
  import os
2
  import asyncio
3
+ import torch
4
+
5
  from contextlib import asynccontextmanager
6
 
7
  from typing import AsyncGenerator, List, Optional, Tuple, Dict, Any
 
99
  return list_chatmessages
100
 
101
 
102
+ print(f"CUDA available: {torch.cuda.is_available()}")
103
  champ = ChampService(base_dir=BASE_DIR, hf_token=HF_TOKEN)
104
 
105
 
rag.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/champ/rag.py
2
+ import torch
3
+
4
+ from pathlib import Path
5
+
6
+ from langchain_community.vectorstores import FAISS as LCFAISS
7
+ from langchain_huggingface import HuggingFaceEmbeddings
8
+
9
+ from constants import BASE_DIR, HF_TOKEN
10
+
11
+
12
+ def load_vector_store(
13
+ base_dir: Path = BASE_DIR,
14
+ hf_token: str = HF_TOKEN,
15
+ rag_relpath: str = "rag_data/FAISS_ALLEN_20260129",
16
+ embedding_model: str = "BAAI/bge-large-en-v1.5",
17
+ device: str = "cuda" if torch.cuda.is_available() else "cpu",
18
+ ) -> LCFAISS:
19
+ rag_path = base_dir / rag_relpath
20
+
21
+ model_embedding_kwargs = {"device": device, "use_auth_token": hf_token}
22
+ encode_kwargs = {"normalize_embeddings": True}
23
+
24
+ embeddings = HuggingFaceEmbeddings(
25
+ model_name=embedding_model,
26
+ model_kwargs=model_embedding_kwargs,
27
+ encode_kwargs=encode_kwargs,
28
+ )
29
+
30
+ return LCFAISS.load_local(
31
+ str(rag_path),
32
+ embeddings,
33
+ allow_dangerous_deserialization=True, # safe because you built the files
34
+ )
requirements.txt CHANGED
@@ -125,3 +125,4 @@ xxhash==3.6.0
125
  yarl==1.22.0
126
  zstandard==0.25.0
127
  pytz==2025.2
 
 
125
  yarl==1.22.0
126
  zstandard==0.25.0
127
  pytz==2025.2
128
+ pyinstrument==5.1.2