Wenye He committed on
Commit
2f7f89f
·
1 Parent(s): 678dc55

Upload 15 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,9 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ datas/bge_onnx/tokenizer.json filter=lfs diff=lfs merge=lfs -text
37
+ vector_stores/anatomical_regions_head_and_neck.duckdb filter=lfs diff=lfs merge=lfs -text
38
+ vector_stores/anatomical_regions_torso.duckdb filter=lfs diff=lfs merge=lfs -text
39
+ vector_stores/CFIR.duckdb filter=lfs diff=lfs merge=lfs -text
40
+ vector_stores/injury_typology_neurological_injuries.duckdb filter=lfs diff=lfs merge=lfs -text
41
+ vector_stores/injury_typology_soft_tissue_injuries.duckdb filter=lfs diff=lfs merge=lfs -text
Dockerfile ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Base image with CUDA 12.1 and Ubuntu 22.04
FROM nvidia/cuda:12.1.1-base-ubuntu22.04

# Install Python 3.10 and essential dependencies
RUN apt-get update && \
    apt-get install -y --no-install-recommends \
        python3.10 \
        python3.10-dev \
        python3.10-distutils \
        curl \
        git \
    && rm -rf /var/lib/apt/lists/*

# Make Python 3.10 the default
RUN update-alternatives --install /usr/bin/python3 python3 /usr/bin/python3.10 1

# Install pip for Python 3.10
RUN curl -sS https://bootstrap.pypa.io/get-pip.py | python3.10

# Install Ollama with GPU layers
ENV OLLAMA_GPU_LAYERS=100
RUN curl -fsSL https://ollama.com/install.sh | sh

# Set up application directory
WORKDIR /app

# Install Python dependencies BEFORE copying the application sources so this
# (slow) layer is only rebuilt when requirements.txt itself changes.
# `python3 -m pip` guarantees the packages land in the interpreter that
# `python3` resolves to at runtime.
COPY requirements.txt .
RUN python3 -m pip install --no-cache-dir -r requirements.txt

# Verify CUDA and Python versions (depends only on the installed packages)
RUN python3 -c "import torch; print(f'PyTorch CUDA available: {torch.cuda.is_available()}')" && \
    python3 --version

# Copy the application (app.py, config.json, start.sh, model and vector store data)
COPY . .

# Configure environment variables (FROM YOUR ORIGINAL SETUP)
ENV VECTOR_STORE_DIR=/app/vector_stores \
    EMBED_MODEL_PATH=/app/datas/bge_onnx \
    PYTHONUNBUFFERED=1 \
    GRADIO_SERVER_NAME="0.0.0.0"

# Expose ports for Ollama and Gradio
EXPOSE 11434 7860

# start.sh is already in /app via `COPY . .`; it only needs the executable bit
RUN chmod +x /app/start.sh

# Start services using the startup script
CMD ["/app/start.sh"]
app.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ from llama_index.embeddings.huggingface_optimum import OptimumEmbedding
3
+ import gradio as gr
4
+ from llama_index.core import Settings
5
+ from llama_index.core import VectorStoreIndex, StorageContext
6
+ from llama_index.vector_stores.duckdb import DuckDBVectorStore
7
+ from llama_index.llms.ollama import Ollama
8
+ from llama_index.core.memory import ChatMemoryBuffer
9
+ import json
10
+ import ollama
11
+ import os
12
+ import uuid
13
+
14
+
15
+
16
# Configuration
# The two data paths honour the VECTOR_STORE_DIR / EMBED_MODEL_PATH environment
# variables (the Dockerfile exports both); the defaults are the relative paths
# used when running from a local checkout.
VECTOR_STORE_DIR = os.environ.get("VECTOR_STORE_DIR", "./vector_stores")
EMBED_MODEL_PATH = os.environ.get("EMBED_MODEL_PATH", "./datas/bge_onnx")
CONFIG_PATH = "config.json"

DEFAULT_LLM = "Jatin19K/unsloth-q5_k_m-mistral-nemo-instruct-2407"
DEFAULT_VECTOR_STORE = "CFIR"
23
+
24
class ModelManager:
    """Resolve the models listed in the JSON config to Ollama model names.

    On construction the config file is read and every configured model is
    either matched against the models Ollama already has locally or pulled.
    """

    def __init__(self):
        self.config = self._load_config()
        self.available_models = self._initialize_models()

    def _load_config(self):
        """Load model configuration from JSON file.

        Returns:
            The parsed JSON object, or ``{"models": []}`` when the file is
            missing, unreadable, not valid JSON, or not a JSON object.
        """
        try:
            with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
                config = json.load(f)
        except (OSError, ValueError) as e:
            # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
            print(f"Error loading config: {e}")
            return {"models": []}

        if not isinstance(config, dict):
            # Later code calls config.get(...); anything but an object is unusable.
            print(f"Error loading config: expected a JSON object, got {type(config).__name__}")
            return {"models": []}
        return config

    def _initialize_models(self):
        """Initialize and verify all models from config.

        Returns:
            Dict mapping each usable configured model name to the name that
            should be passed to Ollama (the locally installed ``name:tag``
            when one is found, otherwise the configured name after a
            successful pull). Models that cannot be pulled are omitted.
        """
        config_models = self.config.get("models", [])
        available_models = {}

        # Get currently available Ollama models, keyed by name without tag.
        try:
            current_models = {m['name'].split(':')[0]: m['name'] for m in ollama.list()['models']}
            print(current_models)
        except Exception as e:
            # Boundary to the Ollama service: any failure is treated as
            # "nothing installed" so each model falls through to a pull attempt.
            print(f"Error fetching current models: {e}")
            current_models = {}

        installed_full_names = set(current_models.values())

        # Check each configured model
        for model_name in config_models:
            if model_name in installed_full_names:
                # Config names an exact installed "name:tag"; no pull needed.
                available_models[model_name] = model_name
                continue
            if model_name in current_models:
                # Untagged config name: resolve to the installed tag.
                available_models[model_name] = current_models[model_name]
                continue

            print(f"Model {model_name} not found locally. Attempting to pull...")
            try:
                ollama.pull(model_name)
            except Exception as e:
                print(f"Error pulling model {model_name}: {e}")
                continue
            available_models[model_name] = model_name
            print(f"Successfully pulled model {model_name}")

        return available_models

    def get_available_models(self):
        """Return dictionary of available models"""
        return self.available_models
70
+
71
+
72
+
73
class EmbeddingManager:
    """Holds the embedding model and registers it on llama-index ``Settings``."""

    def __init__(self):
        self.embed_model = None
        self._initialize_embed_model()

    def _initialize_embed_model(self):
        """Load the ONNX embedding model found at ``EMBED_MODEL_PATH``.

        On success the model is stored on the instance and assigned to
        ``Settings.embed_model``. Any failure (including a missing model
        path) is printed and ``self.embed_model`` keeps its prior value.
        """
        try:
            if os.path.exists(EMBED_MODEL_PATH):
                loaded_model = OptimumEmbedding(folder_name=EMBED_MODEL_PATH)
                self.embed_model = loaded_model
                Settings.embed_model = loaded_model
                print("Successfully initialized BGE embedding model")
            else:
                raise FileNotFoundError(f"BGE ONNX model not found at {EMBED_MODEL_PATH}")
        except Exception as exc:
            # Best effort: report and continue without a configured model.
            print(f"Embedding model error: {exc}")
90
+
91
+
92
+ # Initialize managers
93
+ model_manager = ModelManager()
94
+ embed_manager = EmbeddingManager()
95
+
96
+ # def get_available_models():
97
+ # """Check locally available Ollama models"""
98
+ # try:
99
+ # models = ollama.list()['models']
100
+ # model_dict = {m['name'].split(':')[0]: m['name'] for m in models}
101
+
102
+ # # Create ordered list with default first
103
+ # ordered_models = {}
104
+ # if DEFAULT_LLM in model_dict:
105
+ # ordered_models[DEFAULT_LLM] = model_dict[DEFAULT_LLM]
106
+
107
+ # # Add remaining models alphabetically
108
+ # for name in sorted(model_dict.keys()):
109
+ # if name != DEFAULT_LLM:
110
+ # ordered_models[name] = model_dict[name]
111
+
112
+ # return ordered_models
113
+
114
+ # except Exception as e:
115
+ # print(f"Error fetching models: {e}")
116
+ # return {DEFAULT_LLM: DEFAULT_LLM} # Fallback
117
+
118
def get_available_vector_stores(store_dir=None, default_store=None):
    """Scan vector store directory for DuckDB files.

    Args:
        store_dir: Directory to scan; defaults to ``VECTOR_STORE_DIR``.
        default_store: Store name (file stem) listed first when its file
            exists; defaults to ``DEFAULT_VECTOR_STORE``.

    Returns:
        Dict mapping store name (file name without ``.duckdb``) to
        ``{"path": ..., "display_name": ...}``. The default store, if
        present, comes first and keeps its name as display name; the others
        follow in alphabetical order with underscores shown as spaces.
        Empty when the directory is missing or cannot be listed.
    """
    if store_dir is None:
        store_dir = VECTOR_STORE_DIR
    if default_store is None:
        default_store = DEFAULT_VECTOR_STORE

    suffix = ".duckdb"
    try:
        # Sorted so the resulting order does not depend on the filesystem.
        file_names = sorted(os.listdir(store_dir))
    except OSError:
        # Missing directory, not a directory, or not readable.
        return {}

    vector_stores = {}
    default_file = f"{default_store}{suffix}"
    default_path = os.path.join(store_dir, default_file)
    if os.path.exists(default_path):
        vector_stores[default_store] = {
            "path": default_path,
            "display_name": default_store
        }

    # Add other stores
    for file_name in file_names:
        if file_name == default_file or not file_name.endswith(suffix):
            continue
        store_name = file_name[:-len(suffix)]
        vector_stores[store_name] = {
            "path": os.path.join(store_dir, file_name),
            "display_name": store_name.replace('_', ' ')
        }
    return vector_stores
139
+
140
class ChatSessionManager:
    """Build and cache one chat engine per UI session.

    ``sessions`` maps a session id to its chat engine; ``session_configs``
    remembers which (model, vector store) pair that engine was built for so
    that a changed selection triggers a rebuild instead of being ignored.
    """

    def __init__(self):
        self.sessions = {}
        self.session_configs = {}
        self.llm_options = model_manager.get_available_models()
        self.vector_stores = get_available_vector_stores()

    def refresh_models(self):
        """Re-read the model mapping held by the module-level model manager."""
        self.llm_options = model_manager.get_available_models()

    def refresh_vector_stores(self):
        """Re-scan the vector store directory."""
        self.vector_stores = get_available_vector_stores()

    def get_chat_engine(self, session_id, llm_choice, vector_store_choice):
        """Create chat engine with configured embeddings.

        Returns the cached engine for ``session_id`` when it was built for
        the same ``llm_choice`` and ``vector_store_choice``; otherwise builds
        a new engine (with a fresh chat memory) and caches it.

        Raises:
            ValueError: if ``vector_store_choice`` is not a known store name
                or ``llm_choice`` is not one of the available model names.
        """
        selection = (llm_choice, vector_store_choice)
        cached_engine = self.sessions.get(session_id)
        if cached_engine is not None and self.session_configs.get(session_id) == selection:
            return cached_engine

        # Verify vector store exists
        if vector_store_choice not in self.vector_stores:
            raise ValueError(f"Vector store {vector_store_choice} not found")

        # Verify model exists
        if llm_choice not in self.llm_options.values():
            raise ValueError(f"Model {llm_choice} not available")

        # Configure LLM
        Settings.llm = Ollama(
            model=llm_choice,
            request_timeout=120,
            temperature=0.3
        )

        # Load vector store
        vs_path = self.vector_stores[vector_store_choice]["path"]
        vector_store = DuckDBVectorStore.from_local(vs_path)
        storage_context = StorageContext.from_defaults(vector_store=vector_store)

        index = VectorStoreIndex.from_vector_store(
            vector_store=vector_store,
            storage_context=storage_context
        )

        memory = ChatMemoryBuffer.from_defaults()
        engine = index.as_chat_engine(
            chat_mode="context",
            memory=memory,
            # The trailing space keeps the two concatenated literals from
            # running together ("knowledgeabout").
            system_prompt=(
                "You are a helpful assistant which helps users to understand scientific knowledge "
                "about biomechanics of injuries to human bodies."
            ),
            similarity_top_k=3
        )

        self.sessions[session_id] = engine
        self.session_configs[session_id] = selection
        return engine
192
+
193
+ # Initialize session manager
194
+ session_manager = ChatSessionManager()
195
+
196
def chat_response(message, history, llm_choice, vector_store_choice, session_state):
    """Answer one user message and append the exchange to the chat history.

    Args:
        message: The user's query text.
        history: Existing list of ``(user, bot)`` pairs; ``None`` or empty
            is treated as no history.
        llm_choice: Model name passed to the session manager.
        vector_store_choice: Vector store name passed to the session manager.
        session_state: Dict holding ``"session_id"``; a falsy value causes a
            new session id to be generated.

    Returns:
        ``(new_history, session_state)``. ``new_history`` is a new list; the
        input ``history`` is not modified. Any exception raised while
        answering is reported as an ``"Error: ..."`` bot message instead of
        propagating.
    """
    # Guard against a None history so neither the success nor the error path
    # fails on list concatenation.
    history = list(history) if history else []

    try:
        # Manage session state
        if not session_state:
            session_state = {"session_id": str(uuid.uuid4())}
        session_id = session_state["session_id"]

        chat_engine = session_manager.get_chat_engine(session_id, llm_choice, vector_store_choice)
        response = chat_engine.chat(message)
        bot_message = f"{response.response}\n"
    except Exception as e:
        # UI boundary: surface the failure in the chat window.
        bot_message = f"Error: {str(e)}"

    return history + [(message, bot_message)], session_state
221
+
222
# Gradio interface: model / knowledge-base selectors on the left, chat on the right
with gr.Blocks(title="De-KCIB(Deep Knowledge Center for Injury Biomechanics)") as demo:

    # Per-browser-session dict holding the "session_id" used by chat_response
    session_state = gr.State()

    with gr.Row():
        gr.HTML("<img src='https://www.ussbchamber.org/wp-content/uploads/2021/04/innovisionlogo.png' />")
    with gr.Row():
        gr.Markdown("# De-KCIB(Deep Knowledge Center for Injury Biomechanics)")

    with gr.Row():
        with gr.Column(scale=1):
            llm_dropdown = gr.Dropdown(
                label="Select Language Model",
                choices=list(session_manager.llm_options.values()),
                value=next(iter(session_manager.llm_options.values()), None)
            )
            vector_dropdown = gr.Dropdown(
                label="Injury Biomechanics Knowledge Base",
                choices=[(v["display_name"], k) for k, v in session_manager.vector_stores.items()],
                value=next(iter(session_manager.vector_stores.keys()), None)
            )

        with gr.Column(scale=3):
            chatbot = gr.Chatbot(height=500)
            msg = gr.Textbox(label="Query")
            clear_btn = gr.Button("Clear Session")

    def _reset_session(state):
        """Drop the session's cached chat engine and reset chat and state."""
        if state:
            # Without this the engine stays in the manager's cache forever.
            session_manager.sessions.pop(state.get("session_id"), None)
        return None, None

    # Answer the query, then empty the textbox for the next one.
    msg.submit(
        chat_response,
        [msg, chatbot, llm_dropdown, vector_dropdown, session_state],
        [chatbot, session_state]
    ).then(lambda: "", None, msg)

    clear_btn.click(
        _reset_session,  # Reset both chat and session
        session_state,
        [chatbot, session_state],
        queue=False
    )

# Deployment settings
if __name__ == "__main__":
    demo.launch()
config.json ADDED
@@ -0,0 +1 @@
 
 
1
+ { "models": ["Jatin19K/unsloth-q5_k_m-mistral-nemo-instruct-2407", "Jatin19K/unsloth_q8_0_meta_llama_3.1_8b_instruct_bnb_4bit_innovision_dekcib"] }
datas/bge_onnx/config.json ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "Snowflake/snowflake-arctic-embed-l-v2.0",
3
+ "architectures": [
4
+ "XLMRobertaModel"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "bos_token_id": 0,
8
+ "classifier_dropout": null,
9
+ "eos_token_id": 2,
10
+ "hidden_act": "gelu",
11
+ "hidden_dropout_prob": 0.1,
12
+ "hidden_size": 1024,
13
+ "initializer_range": 0.02,
14
+ "intermediate_size": 4096,
15
+ "layer_norm_eps": 1e-05,
16
+ "max_position_embeddings": 8194,
17
+ "model_type": "xlm-roberta",
18
+ "num_attention_heads": 16,
19
+ "num_hidden_layers": 24,
20
+ "output_past": true,
21
+ "pad_token_id": 1,
22
+ "position_embedding_type": "absolute",
23
+ "torch_dtype": "float32",
24
+ "transformers_version": "4.46.3",
25
+ "type_vocab_size": 1,
26
+ "use_cache": true,
27
+ "vocab_size": 250002
28
+ }
datas/bge_onnx/special_tokens_map.json ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<s>",
4
+ "lstrip": false,
5
+ "normalized": false,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "cls_token": {
10
+ "content": "<s>",
11
+ "lstrip": false,
12
+ "normalized": false,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "eos_token": {
17
+ "content": "</s>",
18
+ "lstrip": false,
19
+ "normalized": false,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ },
23
+ "mask_token": {
24
+ "content": "<mask>",
25
+ "lstrip": true,
26
+ "normalized": false,
27
+ "rstrip": false,
28
+ "single_word": false
29
+ },
30
+ "pad_token": {
31
+ "content": "<pad>",
32
+ "lstrip": false,
33
+ "normalized": false,
34
+ "rstrip": false,
35
+ "single_word": false
36
+ },
37
+ "sep_token": {
38
+ "content": "</s>",
39
+ "lstrip": false,
40
+ "normalized": false,
41
+ "rstrip": false,
42
+ "single_word": false
43
+ },
44
+ "unk_token": {
45
+ "content": "<unk>",
46
+ "lstrip": false,
47
+ "normalized": false,
48
+ "rstrip": false,
49
+ "single_word": false
50
+ }
51
+ }
datas/bge_onnx/tokenizer_config.json ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "added_tokens_decoder": {
3
+ "0": {
4
+ "content": "<s>",
5
+ "lstrip": false,
6
+ "normalized": false,
7
+ "rstrip": false,
8
+ "single_word": false,
9
+ "special": true
10
+ },
11
+ "1": {
12
+ "content": "<pad>",
13
+ "lstrip": false,
14
+ "normalized": false,
15
+ "rstrip": false,
16
+ "single_word": false,
17
+ "special": true
18
+ },
19
+ "2": {
20
+ "content": "</s>",
21
+ "lstrip": false,
22
+ "normalized": false,
23
+ "rstrip": false,
24
+ "single_word": false,
25
+ "special": true
26
+ },
27
+ "3": {
28
+ "content": "<unk>",
29
+ "lstrip": false,
30
+ "normalized": false,
31
+ "rstrip": false,
32
+ "single_word": false,
33
+ "special": true
34
+ },
35
+ "250001": {
36
+ "content": "<mask>",
37
+ "lstrip": true,
38
+ "normalized": false,
39
+ "rstrip": false,
40
+ "single_word": false,
41
+ "special": true
42
+ }
43
+ },
44
+ "bos_token": "<s>",
45
+ "clean_up_tokenization_spaces": true,
46
+ "cls_token": "<s>",
47
+ "eos_token": "</s>",
48
+ "mask_token": "<mask>",
49
+ "max_length": 512,
50
+ "model_max_length": 8192,
51
+ "pad_to_multiple_of": null,
52
+ "pad_token": "<pad>",
53
+ "pad_token_type_id": 0,
54
+ "padding_side": "right",
55
+ "sep_token": "</s>",
56
+ "stride": 0,
57
+ "tokenizer_class": "XLMRobertaTokenizer",
58
+ "truncation_side": "right",
59
+ "truncation_strategy": "longest_first",
60
+ "unk_token": "<unk>"
61
+ }
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ --extra-index-url https://download.pytorch.org/whl/cu121
2
+ numpy==1.26.4
3
+ ollama==0.3.3
4
+ onnx==1.17.0
5
+ gradio==5.16.0
6
+ ollama
7
+ llama-index-core
8
+ llama-index-embeddings-huggingface-optimum
9
+ llama-index-llms-ollama
10
+ llama-index-vector-stores-duckdb
11
+ duckdb
12
+ torch==2.5.0+cu121
start.sh ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
#!/bin/sh
# Start the Ollama server in the background, wait for it to answer, then run
# the Gradio app in the foreground.
ollama serve > /dev/null 2>&1 &

# Poll for readiness instead of sleeping a fixed 10 seconds: a fixed sleep is
# too short on a slow start and wasted time on a fast one.
attempts=0
until ollama list > /dev/null 2>&1; do
    attempts=$((attempts + 1))
    if [ "$attempts" -ge 60 ]; then
        # Best effort, like the original fixed sleep: start the app anyway.
        echo "Ollama did not become ready within 60s; starting app anyway" >&2
        break
    fi
    sleep 1
done

# exec so the app replaces this shell and receives container signals directly
exec python3 app.py