sergey21000 committed
Commit ac7d372 · verified · 1 Parent(s): 79fb1e4

Upload 2 files

Files changed (2)
  1. Dockerfile +13 -0
  2. config.py +266 -0
Dockerfile ADDED
@@ -0,0 +1,13 @@
+ FROM ghcr.io/sergey21000/chatbot-rag:main-cpu
+
+ RUN useradd -m -u 1000 user \
+     && chown -R user:user /app
+
+ USER user
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ WORKDIR /app
+ COPY --chown=user config.py ./
+
+ CMD ["python3", "app.py"]
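
Note: the user setup here is the standard Hugging Face Spaces Docker pattern. Spaces run containers as a non-root user with UID 1000, so the image creates that user, hands it ownership of /app, and puts ~/.local/bin on PATH so user-level pip installs resolve. Only config.py is copied on top of the prebuilt chatbot-rag:main-cpu base image, which by all appearances already ships app.py and the dependencies (the base image's contents are not shown in this commit).
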
config.py ADDED
@@ -0,0 +1,266 @@
+ import os
+ import sys
+ from pathlib import Path
+ from typing import Any, ClassVar
+
+ import gradio as gr
+ import torch
+ from llama_cpp import Llama
+ from chromadb import EmbeddingFunction
+
+ from dotenv import load_dotenv
+ load_dotenv()
+
+
+ class ModelStorage:
+     '''Global model storage'''
+     LLM_MODEL: ClassVar[dict[str, Llama]] = {}
+     EMBED_MODEL: ClassVar[dict[str, EmbeddingFunction]] = {}
+
+
+ class UiBlocksConfig:
+     '''Gradio settings for gr.Blocks()'''
+     CSS: str | None = '''
+     .gradio-container {
+         width: 70% !important;
+         margin: 0 auto !important;
+     }
+     '''
+     if hasattr(sys, 'getandroidapilevel') or 'ANDROID_ROOT' in os.environ:
+         CSS = None
+     UI_BLOCKS_KWARGS: dict[str, Any] = dict(
+         theme=None,
+         css=CSS,
+         analytics_enabled=False,
+     )
+
+
+ class InferenceConfig:
+     '''Model inference settings'''
+     def __init__(self):
+         self.encode_kwargs: dict[str, Any] = dict(
+             batch_size=300,
+             normalize_embeddings=None,
+         )
+         self.sampling_kwargs: dict[str, Any] = dict(
+             temperature=0.2,
+             top_p=0.95,
+             top_k=40,
+             repeat_penalty=1.0,
+         )
+         self.do_sample: bool = False
+         self.rag_mode: bool = False
+         self.history_len: int = 0
+         self.show_thinking: bool = False
+
+
+ class TextLoadConfig:
+     '''Settings for loading texts from documents'''
+     def __init__(self):
+         self.partition_kwargs: dict[str, str | int | bool | None] = dict(
+             chunking_strategy='basic',
+             max_characters=800,
+             new_after_n_chars=500,
+             overlap=0,
+             clean=True,
+             bullets=True,
+             extra_whitespace=True,
+             dashes=False,
+             trailing_punctuation=True,
+             lowercase=False,
+         )
+         self.SUPPORTED_FILE_EXTS: str = '.csv .tsv .docx .md .org .pdf .pptx .xlsx'
+         self.subtitle_lang: str = 'ru'
+         self.SUBTITLE_LANGS: list[str] = ['ru', 'en']
+         self.max_lines_text_view: int = 200
+
+
+ class DbConfig:
+     '''Vector database parameters (Chroma)'''
+     def __init__(self):
+         self.create_collection_kwargs: dict[str, Any] = dict(
+             configuration=dict(
+                 hnsw=dict(
+                     space='cosine',  # l2, ip, cosine; default l2
+                     ef_construction=200,
+                 )
+             )
+         )
+         self.query_kwargs: dict[str, Any] = dict(
+             n_results=2,
+             max_distance_treshold=0.5,
+         )
+
+
+ class PromptConfig:
+     '''Prompts'''
+     def __init__(self):
+         self.system_prompt: str | None = None
+         self.user_msg_with_context: str = ''
+         self.context_template: str = '''Ответь на вопрос при условии контекста.
+
+ Контекст:
+ {context}
+
+ Вопрос:
+ {user_message}
+
+ Ответ:'''  # EN: 'Answer the question given the context. Context: {context} Question: {user_message} Answer:'
+
+
+ class ModelConfig:
+     '''Configuration of paths, models and generation parameters'''
+     def __init__(self):
+         self.LLM_MODELS_PATH: Path = Path('models')
+         self.EMBED_MODELS_PATH: Path = Path('embed_models')
+         self.LLM_MODELS_PATH.mkdir(exist_ok=True)
+         self.EMBED_MODELS_PATH.mkdir(exist_ok=True)
+         self.llm_model_repo: str = 'bartowski/google_gemma-3-1b-it-GGUF'
+         self.llm_model_file: str = 'google_gemma-3-1b-it-Q8_0.gguf'
+         self.embed_model_repo: str = 'Alibaba-NLP/gte-multilingual-base'
+         self.embed_model_kwargs: dict[str, Any] = dict(
+             device='cuda:0',
+             trust_remote_code=True,
+             cache_folder=self.EMBED_MODELS_PATH,
+             token=os.getenv('HF_TOKEN'),
+             model_kwargs=dict(
+                 torch_dtype='auto',
+             )
+         )
+         self.llm_model_kwargs: dict[str, Any] = dict(
+             n_gpu_layers=-1,
+             n_ctx=4096,
+             verbose=False,
+             local_dir=self.LLM_MODELS_PATH,
+         )
+
+
+ class ReposConfig:
+     '''Links to repositories with GGUF models'''
+     def __init__(self):
+         self.llm_model_repos: list[str] = [
+             'bartowski/google_gemma-3-1b-it-GGUF',
+             'bartowski/google_gemma-3-4b-it-GGUF',
+             'bartowski/Qwen_Qwen3-1.7B-GGUF',
+             'bartowski/Qwen_Qwen3-4B-GGUF',
+         ]
+         self.embed_model_repos: list[str] = [
+             'Alibaba-NLP/gte-multilingual-base',
+             'sergeyzh/rubert-tiny-turbo',
+             'intfloat/multilingual-e5-large',
+             'intfloat/multilingual-e5-base',
+             'intfloat/multilingual-e5-small',
+             'intfloat/multilingual-e5-large-instruct',
+             'sentence-transformers/all-mpnet-base-v2',
+             'sentence-transformers/paraphrase-multilingual-mpnet-base-v2',
+             'ai-forever/ruElectra-medium',
+             'ai-forever/sbert_large_nlu_ru',
+             'deepvk/USER2-small',
+             'BAAI/bge-m3-retromae',
+         ]
+
+
+ class Config:
+     '''General config'''
+     def __init__(self):
+         self.Inference: InferenceConfig = InferenceConfig()
+         self.TextLoad: TextLoadConfig = TextLoadConfig()
+         self.Prompt: PromptConfig = PromptConfig()
+         self.Db: DbConfig = DbConfig()
+         self.Model: ModelConfig = ModelConfig()
+         self.Repos: ReposConfig = ReposConfig()
+
+         self.generation_kwargs: dict[str, Any] = dict(
+             do_sample=self.Inference.do_sample,
+             temperature=self.Inference.sampling_kwargs['temperature'],
+             top_p=self.Inference.sampling_kwargs['top_p'],
+             top_k=self.Inference.sampling_kwargs['top_k'],
+             repeat_penalty=self.Inference.sampling_kwargs['repeat_penalty'],
+             history_len=self.Inference.history_len,
+             system_prompt=self.Prompt.system_prompt,
+             context_template=self.Prompt.context_template,
+             show_thinking=self.Inference.show_thinking,
+             n_results=self.Db.query_kwargs['n_results'],
+             max_distance_treshold=self.Db.query_kwargs['max_distance_treshold'],
+             user_msg_with_context=self.Prompt.user_msg_with_context,
+             rag_mode=self.Inference.rag_mode,
+         )
+         self.load_text_kwargs: dict[str, Any] = dict(
+             chunking_strategy=self.TextLoad.partition_kwargs['chunking_strategy'],
+             max_characters=self.TextLoad.partition_kwargs['max_characters'],
+             new_after_n_chars=self.TextLoad.partition_kwargs['new_after_n_chars'],
+             overlap=self.TextLoad.partition_kwargs['overlap'],
+             clean=self.TextLoad.partition_kwargs['clean'],
+             bullets=self.TextLoad.partition_kwargs['bullets'],
+             extra_whitespace=self.TextLoad.partition_kwargs['extra_whitespace'],
+             dashes=self.TextLoad.partition_kwargs['dashes'],
+             trailing_punctuation=self.TextLoad.partition_kwargs['trailing_punctuation'],
+             lowercase=self.TextLoad.partition_kwargs['lowercase'],
+             subtitle_lang=self.TextLoad.subtitle_lang,
+         )
+         self.load_model_kwargs: dict[str, Any] = dict(
+             llm_model_repo=self.Model.llm_model_repo,
+             llm_model_file=self.Model.llm_model_file,
+             embed_model_repo=self.Model.embed_model_repo,
+             n_gpu_layers=self.Model.llm_model_kwargs['n_gpu_layers'],
+             n_ctx=self.Model.llm_model_kwargs['n_ctx'],
+         )
+         self.view_text_kwargs: dict[str, Any] = dict(
+             max_lines_text_view=self.TextLoad.max_lines_text_view,
+         )
+
+     def get_sampling_kwargs(self) -> dict[str, Any]:
+         return dict(
+             temperature=self.generation_kwargs['temperature'],
+             top_p=self.generation_kwargs['top_p'],
+             top_k=self.generation_kwargs['top_k'],
+             repeat_penalty=self.generation_kwargs['repeat_penalty'],
+         )
+     def get_rag_kwargs(self) -> dict[str, Any]:
+         return dict(
+             n_results=self.generation_kwargs['n_results'],
+             max_distance_treshold=self.generation_kwargs['max_distance_treshold'],
+             user_msg_with_context=self.generation_kwargs['user_msg_with_context'],
+             context_template=self.generation_kwargs['context_template'],
+         )
+     def get_partition_kwargs(self) -> dict[str, Any]:
+         return dict(
+             chunking_strategy=self.load_text_kwargs['chunking_strategy'],
+             max_characters=self.load_text_kwargs['max_characters'],
+             new_after_n_chars=self.load_text_kwargs['new_after_n_chars'],
+             overlap=self.load_text_kwargs['overlap'],
+             clean=self.load_text_kwargs['clean'],
+             bullets=self.load_text_kwargs['bullets'],
+             extra_whitespace=self.load_text_kwargs['extra_whitespace'],
+             dashes=self.load_text_kwargs['dashes'],
+             trailing_punctuation=self.load_text_kwargs['trailing_punctuation'],
+             lowercase=self.load_text_kwargs['lowercase'],
+         )
+     def get_clean_kwargs(self) -> dict[str, Any]:
+         return dict(
+             bullets=self.load_text_kwargs['bullets'],
+             extra_whitespace=self.load_text_kwargs['extra_whitespace'],
+             dashes=self.load_text_kwargs['dashes'],
+             trailing_punctuation=self.load_text_kwargs['trailing_punctuation'],
+             lowercase=self.load_text_kwargs['lowercase'],
+         )
+     def get_chunking_kwargs(self) -> dict[str, Any]:
+         return dict(
+             max_characters=self.load_text_kwargs['max_characters'],
+             new_after_n_chars=self.load_text_kwargs['new_after_n_chars'],
+             overlap=self.load_text_kwargs['overlap'],
+         )
+     def get_embed_model_kwargs(self) -> dict[str, Any]:
+         return self.Model.embed_model_kwargs
+
+     def get_encode_kwargs(self) -> dict[str, Any]:
+         return self.Inference.encode_kwargs
+
+     def get_llm_model_kwargs(self) -> dict[str, Any]:
+         return self.Model.llm_model_kwargs
+
+     def get_query_kwargs(self) -> dict[str, Any]:
+         return dict(
+             n_results=self.generation_kwargs['n_results'],
+             max_distance_treshold=self.generation_kwargs['max_distance_treshold'],
+         )
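
Note: app.py is not part of this commit, so how this config is consumed can only be read off the parameter names. The sketches below are reconstructions under that caveat, not code from the repository. First, UI_BLOCKS_KWARGS maps one-to-one onto gr.Blocks arguments (theme, css and analytics_enabled are all real gr.Blocks parameters):

    import gradio as gr

    from config import UiBlocksConfig

    # The kwargs unpack directly into gr.Blocks; CSS is set to None on Android,
    # presumably because the 70%-width centering wastes space on small screens.
    with gr.Blocks(**UiBlocksConfig.UI_BLOCKS_KWARGS) as demo:
        gr.Markdown('Chatbot UI goes here')

    demo.launch()
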
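Loading the default GGUF model with llama-cpp-python would plausibly look like this. Llama.from_pretrained and create_chat_completion are real llama-cpp-python APIs and get_llm_model_kwargs() matches their parameters exactly; the 'model' storage key is a guess:

    from llama_cpp import Llama

    from config import Config, ModelStorage

    config = Config()

    # Downloads google_gemma-3-1b-it-Q8_0.gguf into models/ and loads it;
    # get_llm_model_kwargs() supplies n_gpu_layers, n_ctx, verbose and local_dir.
    ModelStorage.LLM_MODEL['model'] = Llama.from_pretrained(
        repo_id=config.Model.llm_model_repo,
        filename=config.Model.llm_model_file,
        **config.get_llm_model_kwargs(),
    )

    # get_sampling_kwargs() yields temperature/top_p/top_k/repeat_penalty,
    # all accepted by create_chat_completion.
    response = ModelStorage.LLM_MODEL['model'].create_chat_completion(
        messages=[{'role': 'user', 'content': 'Привет!'}],
        **config.get_sampling_kwargs(),
    )
    print(response['choices'][0]['message']['content'])
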
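The partition_kwargs field names (chunking_strategy, max_characters, new_after_n_chars, overlap, bullets, extra_whitespace, dashes, trailing_punctuation, lowercase) are exactly the signatures of the unstructured library's partition() and cleaners.core.clean(), so document loading presumably splits them along these lines (the input filename is hypothetical):

    from unstructured.cleaners.core import clean
    from unstructured.partition.auto import partition

    from config import Config

    config = Config()

    # Chunk the document with the configured 'basic' strategy.
    elements = partition(
        filename='document.pdf',  # hypothetical input file
        chunking_strategy=config.load_text_kwargs['chunking_strategy'],
        **config.get_chunking_kwargs(),  # max_characters, new_after_n_chars, overlap
    )

    # Apply the configured cleaning flags to every chunk's text.
    texts = [clean(el.text, **config.get_clean_kwargs()) for el in elements]
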
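On the Chroma side, create_collection_kwargs matches the `configuration` argument of create_collection in recent chromadb releases, and of the two query_kwargs only n_results is a Chroma parameter; max_distance_treshold (spelled that way throughout the file) must be an app-side post-filter on the returned distances. A sketch, with a hypothetical collection name and Chroma's default embedding function:

    import chromadb

    from config import Config

    config = Config()
    client = chromadb.Client()

    # Cosine HNSW index with ef_construction=200, from create_collection_kwargs.
    collection = client.create_collection(
        name='documents',  # hypothetical collection name
        **config.Db.create_collection_kwargs,
    )

    query_kwargs = config.get_query_kwargs()
    res = collection.query(query_texts=['пример запроса'], n_results=query_kwargs['n_results'])

    # Keep only hits closer than the configured distance threshold.
    ids = [
        i for i, d in zip(res['ids'][0], res['distances'][0])
        if d <= query_kwargs['max_distance_treshold']
    ]
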