Rulga committed on
Commit
d45acb2
·
1 Parent(s): cd4842a

Refactor generation settings in the chat interface and update embedding model configuration

Browse files
Files changed (3) hide show
  1. app.py +16 -60
  2. config/settings.py +1 -1
  3. src/training/model_manager.py +8 -8
app.py CHANGED
@@ -152,7 +152,7 @@ with gr.Blocks() as demo:
152
  chatbot = gr.Chatbot(
153
  label="Chat",
154
  bubble_full_width=False,
155
- avatar_images=["user.png", "assistant.png"] # optional
156
  )
157
 
158
  with gr.Row():
@@ -167,68 +167,11 @@ with gr.Blocks() as demo:
167
  gr.Markdown("### Knowledge Base Management")
168
  build_kb_btn = gr.Button("Create/Update Knowledge Base", variant="primary")
169
  kb_status = gr.Textbox(label="Knowledge Base Status", interactive=False)
170
-
171
- gr.Markdown("### Generation Settings")
172
- max_tokens = gr.Slider(
173
- minimum=1,
174
- maximum=2048,
175
- value=512,
176
- step=1,
177
- label="Maximum Response Length",
178
- info="Limits the number of tokens in response. More tokens = longer response"
179
- )
180
- temperature = gr.Slider(
181
- minimum=0.1,
182
- maximum=2.0,
183
- value=0.7,
184
- step=0.1,
185
- label="Temperature",
186
- info="Controls creativity. Lower value = more predictable responses"
187
- )
188
- top_p = gr.Slider(
189
- minimum=0.1,
190
- maximum=1.0,
191
- value=0.95,
192
- step=0.05,
193
- label="Top-p",
194
- info="Controls diversity. Lower value = more focused responses"
195
- )
196
-
197
- clear_btn = gr.Button("Clear Chat History")
198
-
199
- def respond_and_clear(
200
- message,
201
- history,
202
- conversation_id,
203
- max_tokens,
204
- temperature,
205
- top_p,
206
- ):
207
- # Use existing respond function
208
- response_generator = respond(
209
- message,
210
- history,
211
- conversation_id,
212
- DEFAULT_SYSTEM_MESSAGE,
213
- max_tokens,
214
- temperature,
215
- top_p,
216
- )
217
-
218
- # Return result and empty string to clear input field
219
- for response in response_generator:
220
- yield response[0], response[1], "" # chatbot, conversation_id, empty string for msg
221
 
222
- # Event handlers
223
- msg.submit(
224
- respond_and_clear,
225
- [msg, chatbot, conversation_id, max_tokens, temperature, top_p],
226
- [chatbot, conversation_id, msg] # Add msg to output parameters
227
- )
228
  submit_btn.click(
229
  respond_and_clear,
230
- [msg, chatbot, conversation_id, max_tokens, temperature, top_p],
231
- [chatbot, conversation_id, msg] # Add msg to output parameters
232
  )
233
  build_kb_btn.click(build_kb, None, kb_status)
234
  clear_btn.click(lambda: ([], None), None, [chatbot, conversation_id])
@@ -247,6 +190,9 @@ with gr.Blocks() as demo:
247
  **Description:** {MODEL_CONFIG['description']}
248
 
249
  **Type:** {MODEL_CONFIG['type']}
 
 
 
250
  """)
251
 
252
  gr.Markdown("### Model Parameters")
@@ -284,6 +230,16 @@ with gr.Blocks() as demo:
284
  label="Repetition Penalty",
285
  interactive=False
286
  )
 
 
 
 
 
 
 
 
 
 
287
 
288
  with gr.Column(scale=1):
289
  gr.Markdown("### Training Configuration")
 
152
  chatbot = gr.Chatbot(
153
  label="Chat",
154
  bubble_full_width=False,
155
+ avatar_images=["user.png", "assistant.png"]
156
  )
157
 
158
  with gr.Row():
 
167
  gr.Markdown("### Knowledge Base Management")
168
  build_kb_btn = gr.Button("Create/Update Knowledge Base", variant="primary")
169
  kb_status = gr.Textbox(label="Knowledge Base Status", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
 
 
 
 
 
 
171
  submit_btn.click(
172
  respond_and_clear,
173
+ [msg, chatbot, conversation_id], # Remove generation parameters
174
+ [chatbot, conversation_id, msg]
175
  )
176
  build_kb_btn.click(build_kb, None, kb_status)
177
  clear_btn.click(lambda: ([], None), None, [chatbot, conversation_id])
 
190
  **Description:** {MODEL_CONFIG['description']}
191
 
192
  **Type:** {MODEL_CONFIG['type']}
193
+
194
+ **Embeddings Model:** `{EMBEDDING_MODEL}`
195
+ *Used for vector store creation and similarity search*
196
  """)
197
 
198
  gr.Markdown("### Model Parameters")
 
230
  label="Repetition Penalty",
231
  interactive=False
232
  )
233
+
234
+ gr.Markdown("""
235
+ <small>
236
+ **Parameters explanation:**
237
+ - **Maximum Length**: Maximum number of tokens in the generated response
238
+ - **Temperature**: Controls randomness (0.1 = very focused, 2.0 = very creative)
239
+ - **Top-p**: Controls diversity via nucleus sampling (lower = more focused)
240
+ - **Repetition Penalty**: Prevents word repetition (higher = less repetition)
241
+ </small>
242
+ """)
243
 
244
  with gr.Column(scale=1):
245
  gr.Markdown("### Training Configuration")
config/settings.py CHANGED
@@ -36,7 +36,7 @@ MODEL_CONFIG = {
36
  }
37
 
38
  # Embedding model for vector store
39
- EMBEDDING_MODEL = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2"
40
 
41
  # Request settings
42
  USER_AGENT = "Status-Law-Assistant/1.0"
 
36
  }
37
 
38
  # Embedding model for vector store
39
+ EMBEDDING_MODEL = "intfloat/multilingual-e5-large"
40
 
41
  # Request settings
42
  USER_AGENT = "Status-Law-Assistant/1.0"
src/training/model_manager.py CHANGED
@@ -204,21 +204,21 @@ def get_model(
204
  raise ValueError(f"Failed to load model: {str(e)}")
205
 
206
  if __name__ == "__main__":
207
- # Пример использования
208
  manager = ModelManager()
209
 
210
- # Регистрация базовой модели
211
  success, message = manager.register_model(
212
- model_id="saiga",
213
- version="7b",
214
- source="hf://IlyaGusev/saiga_7b_lora",
215
- description="Базовая модель Saiga 7B с LoRA адаптерами",
216
  is_active=True
217
  )
218
  print(message)
219
 
220
- # Вывод списка моделей
221
  models = manager.list_models()
222
- print(f"В реестре {len(models)} моделей:")
223
  for model in models:
224
  print(f" - {model['model_id']} v{model['version']}: {model['description']}")
 
204
  raise ValueError(f"Failed to load model: {str(e)}")
205
 
206
  if __name__ == "__main__":
207
+ # Usage example
208
  manager = ModelManager()
209
 
210
+ # Register base model from config
211
  success, message = manager.register_model(
212
+ model_id=MODEL_CONFIG["id"].split("/")[-1], # Extract model name from full HF path
213
+ version=MODEL_CONFIG["type"],
214
+ source=MODEL_CONFIG["id"],
215
+ description=MODEL_CONFIG["description"],
216
  is_active=True
217
  )
218
  print(message)
219
 
220
+ # Print models list
221
  models = manager.list_models()
222
+ print(f"Registry contains {len(models)} models:")
223
  for model in models:
224
  print(f" - {model['model_id']} v{model['version']}: {model['description']}")