Dongjin1203 committed on
Commit
fca8e5d
·
1 Parent(s): 856f2a5

Test GGUF with lightweight build

Browse files
Files changed (1) hide show
  1. src/generator/generator_gguf.py +56 -62
src/generator/generator_gguf.py CHANGED
@@ -1,4 +1,4 @@
1
- from llama_cpp import Llama
2
  from typing import Optional, Dict, Any, List
3
  import logging
4
  import time
@@ -152,7 +152,7 @@ class GGUFGenerator:
152
  logger.warning("โš ๏ธ system_prompt๊ฐ€ None! ๊ธฐ๋ณธ ํ”„๋กฌํ”„ํŠธ ์‚ฌ์šฉ")
153
  else:
154
  # ๋™์  ํ”„๋กฌํ”„ํŠธ ๋ฏธ๋ฆฌ๋ณด๊ธฐ (์ฒ˜์Œ 150์ž๋งŒ)
155
- logger.info(f"โœ… ๋™์  ํ”„๋กฌํ”„ํŠธ ์ ์šฉ:\n{system_prompt[:150]}...") # โ† ์ถ”๊ฐ€
156
 
157
  # ์ปจํ…์ŠคํŠธ ํฌํ•จ ์—ฌ๋ถ€
158
  if context is not None:
@@ -246,92 +246,86 @@ class GGUFGenerator:
246
 
247
  Args:
248
  question: ์‚ฌ์šฉ์ž ์งˆ๋ฌธ
249
- context: ์„ ํƒ์  ์ปจํ…์ŠคํŠธ (RAG ๊ฒฐ๊ณผ)
250
- **kwargs: generate() ํŒŒ๋ผ๋ฏธํ„ฐ
 
251
 
252
  Returns:
253
  ์ƒ์„ฑ๋œ ์‘๋‹ต
254
  """
255
- # ํ”„๋กฌํ”„ํŠธ ํฌ๋งทํŒ… (system_prompt ์ „๋‹ฌ)
256
- formatted_prompt = self.format_prompt(
257
  question=question,
258
  context=context,
259
- system_prompt=system_prompt # โ† ์ถ”๊ฐ€!
260
  )
261
 
262
  # ์‘๋‹ต ์ƒ์„ฑ
263
- response = self.generate(formatted_prompt, **kwargs)
264
-
265
- return response
266
-
267
- def get_model_info(self) -> Dict[str, Any]:
268
- """
269
- ๋ชจ๋ธ ์ •๋ณด ๋ฐ˜ํ™˜
270
 
271
- Returns:
272
- ๋ชจ๋ธ ์ •๋ณด ๋”•์…”๋„ˆ๋ฆฌ
273
- """
274
- info = {
275
- "model_path": self.model_path,
276
- "n_gpu_layers": self.n_gpu_layers,
277
- "n_ctx": self.n_ctx,
278
- "n_threads": self.n_threads,
279
- "is_loaded": self.model is not None,
280
- "max_new_tokens": self.max_new_tokens,
281
- "temperature": self.temperature,
282
- "top_p": self.top_p,
283
- }
284
-
285
- return info
286
-
287
- def __repr__(self):
288
- return f"GGUFGenerator(model={self.model_path}, loaded={self.model is not None})"
289
-
290
 
291
- # ===== GGUF RAGPipeline: chatbot_app.py 호환용 =====
292
 
293
  class GGUFRAGPipeline:
294
  """
295
- GGUF ๋ชจ๋ธ ๊ธฐ๋ฐ˜ RAG ํŒŒ์ดํ”„๋ผ์ธ
296
 
297
- RAGPipeline(API 버전)과 동일한 인터페이스를 제공하여
298
- chatbot_app.py와 호환됩니다.
299
  """
300
 
301
- def __init__(self, config=None, model: str = None, top_k: int = None):
 
 
 
 
 
 
 
 
 
 
 
 
 
302
  """
303
  초기화
304
 
305
  Args:
306
- config: RAGConfig ๊ฐ์ฒด
307
- model: ๋ชจ๋ธ ์ด๋ฆ„ (์‚ฌ์šฉ ์•ˆ ํ•จ, ํ˜ธํ™˜์„ฑ์šฉ)
308
- top_k: ๊ธฐ๋ณธ ๊ฒ€์ƒ‰ ๋ฌธ์„œ ์ˆ˜
 
 
 
 
 
 
 
309
  """
310
- # Config import (지연 import로 순환 참조 방지)
311
- from src.utils.config import RAGConfig
312
- from src.retriever.retriever import RAGRetriever
313
-
314
  self.config = config or RAGConfig()
315
- self.top_k = top_k or self.config.DEFAULT_TOP_K
316
-
317
- # 검색 설정
318
- self.search_mode = self.config.DEFAULT_SEARCH_MODE
319
- self.alpha = self.config.DEFAULT_ALPHA
320
 
321
  # Retriever 초기화
322
- logger.info("RAGRetriever 초기화 중...")
323
- self.retriever = RAGRetriever(config=self.config)
 
 
 
 
 
324
 
325
- # GGUFGenerator 초기화
326
- logger.info("GGUFGenerator 초기화 중...")
327
  self.generator = GGUFGenerator(
328
  model_path=self.config.GGUF_MODEL_PATH,
329
- n_gpu_layers=self.config.GGUF_N_GPU_LAYERS,
330
- n_ctx=self.config.GGUF_N_CTX,
331
- n_threads=self.config.GGUF_N_THREADS,
332
- max_new_tokens=self.config.GGUF_MAX_NEW_TOKENS,
333
- temperature=self.config.GGUF_TEMPERATURE,
334
- top_p=self.config.GGUF_TOP_P,
 
335
  system_prompt=self.config.SYSTEM_PROMPT
336
  )
337
 
@@ -487,7 +481,7 @@ class GGUFRAGPipeline:
487
  answer = self.generator.chat(
488
  question=query,
489
  context=context,
490
- system_prompt=system_prompt # โ† ์ถ”๊ฐ€!
491
  )
492
 
493
  elapsed_time = time.time() - start_time
@@ -501,7 +495,7 @@ class GGUFRAGPipeline:
501
  'answer': answer,
502
  'sources': self._format_sources(self._last_retrieved_docs),
503
  'used_retrieval': used_retrieval,
504
- 'query_type': query_type, # โ† ์ถ”๊ฐ€!
505
  'search_mode': self.search_mode if used_retrieval else 'direct',
506
  'routing_info': classification,
507
  'elapsed_time': elapsed_time,
 
1
+ from llama_cpp import Llama # ← 주석 해제!
2
  from typing import Optional, Dict, Any, List
3
  import logging
4
  import time
 
152
  logger.warning("โš ๏ธ system_prompt๊ฐ€ None! ๊ธฐ๋ณธ ํ”„๋กฌํ”„ํŠธ ์‚ฌ์šฉ")
153
  else:
154
  # ๋™์  ํ”„๋กฌํ”„ํŠธ ๋ฏธ๋ฆฌ๋ณด๊ธฐ (์ฒ˜์Œ 150์ž๋งŒ)
155
+ logger.info(f"✅ 동적 프롬프트 적용:\n{system_prompt[:150]}...")
156
 
157
  # ์ปจํ…์ŠคํŠธ ํฌํ•จ ์—ฌ๋ถ€
158
  if context is not None:
 
246
 
247
  Args:
248
  question: ์‚ฌ์šฉ์ž ์งˆ๋ฌธ
249
+ context: ์„ ํƒ์  ์ปจํ…์ŠคํŠธ
250
+ system_prompt: ์„ ํƒ์  ์‹œ์Šคํ…œ ํ”„๋กฌํ”„ํŠธ
251
+ **kwargs: generate() ๋ฉ”์„œ๋“œ์— ์ „๋‹ฌ๋  ์ถ”๊ฐ€ ํŒŒ๋ผ๋ฏธํ„ฐ
252
 
253
  Returns:
254
  ์ƒ์„ฑ๋œ ์‘๋‹ต
255
  """
256
+ # ํ”„๋กฌํ”„ํŠธ ํฌ๋งทํŒ…
257
+ prompt = self.format_prompt(
258
  question=question,
259
  context=context,
260
+ system_prompt=system_prompt
261
  )
262
 
263
  # ์‘๋‹ต ์ƒ์„ฑ
264
+ response = self.generate(prompt, **kwargs)
 
 
 
 
 
 
265
 
266
+ return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
 
268
 
269
  class GGUFRAGPipeline:
270
  """
271
+ GGUF ์ƒ์„ฑ๊ธฐ + RAG ํ†ตํ•ฉ ํŒŒ์ดํ”„๋ผ์ธ
272
 
273
+ chatbot_app.py์™€ ํ˜ธํ™˜๋˜๋Š” ์ธํ„ฐํŽ˜์ด์Šค ์ œ๊ณต
 
274
  """
275
 
276
+ def __init__(
277
+ self,
278
+ config: RAGConfig = None,
279
+ model: str = None, # 호환성용 (사용 안 함)
280
+ top_k: int = 10,
281
+ n_gpu_layers: int = 0, # GPU 레이어 수
282
+ n_ctx: int = 2048,
283
+ n_threads: int = 8,
284
+ max_new_tokens: int = 256,
285
+ temperature: float = 0.7,
286
+ top_p: float = 0.9,
287
+ search_mode: str = "hybrid_rerank",
288
+ alpha: float = 0.5
289
+ ):
290
  """
291
  초기화
292
 
293
  Args:
294
+ config: RAGConfig ์ธ์Šคํ„ด์Šค
295
+ n_gpu_layers: GPU ๋ ˆ์ด์–ด ์ˆ˜
296
+ n_ctx: ์ปจํ…์ŠคํŠธ ๊ธธ์ด
297
+ n_threads: CPU ์Šค๋ ˆ๋“œ ์ˆ˜
298
+ max_new_tokens: ์ตœ๋Œ€ ์ƒ์„ฑ ํ† ํฐ
299
+ temperature: ์ƒ์„ฑ ๋‹ค์–‘์„ฑ
300
+ top_p: Nucleus sampling
301
+ search_mode: ๊ฒ€์ƒ‰ ๋ชจ๋“œ
302
+ top_k: ๊ฒ€์ƒ‰ํ•  ๋ฌธ์„œ ์ˆ˜
303
+ alpha: ์ž„๋ฒ ๋”ฉ ๊ฐ€์ค‘์น˜
304
  """
 
 
 
 
305
  self.config = config or RAGConfig()
306
+ self.search_mode = search_mode
307
+ self.top_k = top_k
308
+ self.alpha = alpha
 
 
309
 
310
  # Retriever 초기화
311
+ from src.retriever.hybrid_retriever import HybridRetriever
312
+ self.retriever = HybridRetriever(
313
+ collection_name=self.config.COLLECTION_NAME,
314
+ persist_directory=self.config.CHROMA_DB_DIR,
315
+ embedding_model_name=self.config.EMBEDDING_MODEL,
316
+ reranker_model_name=self.config.RERANKER_MODEL
317
+ )
318
 
319
+ # Generator 초기화
 
320
  self.generator = GGUFGenerator(
321
  model_path=self.config.GGUF_MODEL_PATH,
322
+ n_gpu_layers=n_gpu_layers,
323
+ n_ctx=n_ctx,
324
+ n_threads=n_threads,
325
+ config=self.config,
326
+ max_new_tokens=max_new_tokens,
327
+ temperature=temperature,
328
+ top_p=top_p,
329
  system_prompt=self.config.SYSTEM_PROMPT
330
  )
331
 
 
481
  answer = self.generator.chat(
482
  question=query,
483
  context=context,
484
+ system_prompt=system_prompt
485
  )
486
 
487
  elapsed_time = time.time() - start_time
 
495
  'answer': answer,
496
  'sources': self._format_sources(self._last_retrieved_docs),
497
  'used_retrieval': used_retrieval,
498
+ 'query_type': query_type,
499
  'search_mode': self.search_mode if used_retrieval else 'direct',
500
  'routing_info': classification,
501
  'elapsed_time': elapsed_time,