# clovax-tax-chatbot / local_cpu_test.py
# bissal's picture
# large modification
# 3a15ede
# local_cpu_test.py - RAG test script for a local CPU environment
"""
Script that tests the RAG system in a local CPU environment.
Intended for local verification before deploying to Hugging Face Spaces.
"""
import os
import sys
import time
from datetime import datetime
# Explicitly mark this process as running in a local environment.
os.environ.pop('SPACE_ID', None)  # remove the Hugging Face environment variable
def test_imports():
    """Check that the required third-party libraries can be imported.

    Imports torch, sentence_transformers, faiss and sklearn one after the
    other, printing a status line for each. The check stops at the first
    library that raises ImportError.

    Returns:
        True when all four imports succeed, False as soon as one fails.
    """
    rule = "=" * 50
    print(rule)
    print("๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ import ํ…Œ์ŠคํŠธ")
    print(rule)

    # PyTorch: also report whether CUDA is visible and the resulting device.
    try:
        import torch
        print(f"โœ“ PyTorch {torch.__version__}")
        print(f" CUDA ์‚ฌ์šฉ ๊ฐ€๋Šฅ: {torch.cuda.is_available()}")
        device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
        print(f" ๋””๋ฐ”์ด์Šค: {device_name}")
    except ImportError as import_error:
        print(f"โœ— PyTorch import ์‹คํŒจ: {import_error}")
        return False

    try:
        import sentence_transformers
        print(f"โœ“ sentence-transformers {sentence_transformers.__version__}")
    except ImportError as import_error:
        print(f"โœ— sentence-transformers import ์‹คํŒจ: {import_error}")
        return False

    # FAISS is only imported; no version is printed for it.
    try:
        import faiss
        print("โœ“ FAISS ์‚ฌ์šฉ ๊ฐ€๋Šฅ")
    except ImportError as import_error:
        print(f"โœ— FAISS import ์‹คํŒจ: {import_error}")
        return False

    try:
        import sklearn
        print(f"โœ“ scikit-learn {sklearn.__version__}")
    except ImportError as import_error:
        print(f"โœ— scikit-learn import ์‹คํŒจ: {import_error}")
        return False

    print("โœ“ ๋ชจ๋“  ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ import ์„ฑ๊ณต")
    return True
def test_config():
    """Load the project configuration and print the RAG settings.

    Imports RAG_CONFIG and IS_HUGGINGFACE_SPACE from the project's ``config``
    module and prints the environment flag, the first embedding model, the
    batch size, top-k, similarity threshold and hybrid weights.

    Returns:
        True when every value could be read, False if the import or any
        lookup raised an exception (the error is printed).
    """
    rule = "=" * 50
    print("\n" + rule)
    print("์„ค์ • ํ…Œ์ŠคํŠธ")
    print(rule)
    try:
        from config import RAG_CONFIG, IS_HUGGINGFACE_SPACE

        print(f"ํ—ˆ๊น…ํŽ˜์ด์Šค ํ™˜๊ฒฝ: {IS_HUGGINGFACE_SPACE}")
        # Each value is read right before it is printed, so earlier lines
        # still appear if a later key is missing.
        first_model = RAG_CONFIG['embedding_models'][0]
        print(f"์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ: {first_model}")
        batch_size = RAG_CONFIG['batch_size']
        print(f"๋ฐฐ์น˜ ํฌ๊ธฐ: {batch_size}")
        top_k = RAG_CONFIG['top_k']
        print(f"Top-K: {top_k}")
        threshold = RAG_CONFIG['similarity_threshold']
        print(f"์ž„๊ณ„๊ฐ’: {threshold}")
        weights = RAG_CONFIG['hybrid_weights']
        print(f"ํ•˜์ด๋ธŒ๋ฆฌ๋“œ ๊ฐ€์ค‘์น˜: {weights}")
    except Exception as error:
        print(f"์„ค์ • ๋กœ๋“œ ์‹คํŒจ: {error}")
        return False
    return True
def test_embedding_model():
    """Load a sentence-embedding model on CPU and encode two sample texts.

    Loads the hard-coded model 'paraphrase-multilingual-MiniLM-L12-v2' with
    device='cpu' and cache_folder='./model_cache', encodes two sample
    sentences, and prints the load time, encode time and embedding shape.

    Returns:
        True on success, False if any step raised an exception (the error
        is printed).
    """
    rule = "=" * 50
    print("\n" + rule)
    print("์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ํ…Œ์ŠคํŠธ")
    print(rule)
    try:
        from sentence_transformers import SentenceTransformer
        # NOTE(review): RAG_CONFIG is not referenced below; as written, this
        # import only makes the test fail when config cannot be imported.
        from config import RAG_CONFIG

        # Test with the lightest model.
        model_name = 'paraphrase-multilingual-MiniLM-L12-v2'
        print(f"๋ชจ๋ธ ๋กœ๋”ฉ: {model_name}")

        load_started = time.time()
        # device='cpu' forces CPU use explicitly.
        model = SentenceTransformer(model_name, device='cpu', cache_folder='./model_cache')
        load_seconds = time.time() - load_started
        print(f"โœ“ ๋ชจ๋ธ ๋กœ๋“œ ์™„๋ฃŒ ({load_seconds:.2f}์ดˆ)")

        # Simple embedding test on two sample sentences.
        sample_texts = ["์ทจ๋“์„ธ์œจ์ด ์–ผ๋งˆ์ธ๊ฐ€์š”?", "์ฃผํƒ ๊ตฌ์ž…์‹œ ์„ธ๊ธˆ์€?"]
        encode_started = time.time()
        vectors = model.encode(sample_texts, convert_to_numpy=True)
        encode_seconds = time.time() - encode_started
        print(f"โœ“ ์ž„๋ฒ ๋”ฉ ์ƒ์„ฑ ์™„๋ฃŒ ({encode_seconds:.2f}์ดˆ)")
        print(f" ์ž„๋ฒ ๋”ฉ ํ˜•ํƒœ: {vectors.shape}")
    except Exception as error:
        print(f"์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ ํ…Œ์ŠคํŠธ ์‹คํŒจ: {error}")
        return False
    return True
def test_law_fetcher():
    """Instantiate the law fetcher and print a summary of its cache.

    Creates ``HFLawAPIFetcher()`` from the project's ``law_fetcher`` module,
    prints its ``cache_dir`` and the number of entries in ``cache_info``,
    and, when the cache is non-empty, prints one line per entry with its
    size in KB and the first 10 characters of its ``cached_at`` value.

    Returns:
        True on success, False if any step raised an exception (the error
        is printed).
    """
    rule = "=" * 50
    print("\n" + rule)
    print("๋ฒ•๋ น ํŽ˜์ฒ˜ ํ…Œ์ŠคํŠธ")
    print(rule)
    try:
        from law_fetcher import HFLawAPIFetcher

        fetcher = HFLawAPIFetcher()
        print("โœ“ ๋ฒ•๋ น ํŽ˜์ฒ˜ ์ดˆ๊ธฐํ™” ์™„๋ฃŒ")
        print(f" ์บ์‹œ ๋””๋ ‰ํ† ๋ฆฌ: {fetcher.cache_dir}")
        print(f" ์บ์‹œ๋œ ๋ฒ•๋ น: {len(fetcher.cache_info)}๊ฐœ")

        # List the individual entries only when something is cached.
        if fetcher.cache_info:
            print(" ์บ์‹œ ์ •๋ณด:")
            for law_name, entry in fetcher.cache_info.items():
                size_kb = entry.get('data_size', 0) / 1024
                # Presumably an ISO-style timestamp, so [:10] keeps the date
                # part; verify against the fetcher's cache format.
                cached_day = entry.get('cached_at', 'Unknown')[:10]
                print(f" - {law_name}: {size_kb:.1f}KB ({cached_day})")
    except Exception as error:
        print(f"๋ฒ•๋ น ํŽ˜์ฒ˜ ํ…Œ์ŠคํŠธ ์‹คํŒจ: {error}")
        return False
    return True
def test_rag_system_minimal():
    """Run a minimal smoke test of the RAG system.

    Constructs ``HFSpacesTaxRAG()`` from the project's ``rag_system`` module,
    prints its environment, device and embedding dimension, and, when both
    ``vector_db`` and ``documents`` are truthy, runs one search with
    ``top_k=2`` and prints each result's hybrid score and a 50-character
    preview of the document. A missing vector DB is reported but still
    counts as success.

    Returns:
        True on success, False if any step raised an exception (the error
        and a traceback are printed).
    """
    rule = "=" * 50
    print("\n" + rule)
    print("RAG ์‹œ์Šคํ…œ ์ตœ์†Œ ํ…Œ์ŠคํŠธ")
    print(rule)
    try:
        from rag_system import HFSpacesTaxRAG

        print("RAG ์‹œ์Šคํ…œ ์ดˆ๊ธฐํ™” ์ค‘...")
        init_started = time.time()
        rag = HFSpacesTaxRAG()
        init_seconds = time.time() - init_started
        print(f"โœ“ RAG ์‹œ์Šคํ…œ ์ดˆ๊ธฐํ™” ์™„๋ฃŒ ({init_seconds:.2f}์ดˆ)")

        environment_label = 'ํ—ˆ๊น…ํŽ˜์ด์Šค' if rag.is_huggingface_space else '๋กœ์ปฌ'
        print(f" ํ™˜๊ฒฝ: {environment_label}")
        print(f" ๋””๋ฐ”์ด์Šค: {rag.device}")
        embedding_dim = rag.embedding_model.get_sentence_embedding_dimension()
        print(f" ์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ: {embedding_dim}์ฐจ์›")

        # Check the vector DB state; without an index there is nothing to search.
        if not (rag.vector_db and rag.documents):
            print(" ๋ฒกํ„ฐ DB ์—†์Œ - ์‹œ์Šคํ…œ ๊ตฌ์ถ• ํ•„์š”")
            return True

        print(f" ๋ฒกํ„ฐ DB: {rag.vector_db.ntotal}๊ฐœ ๋ฒกํ„ฐ")
        print(f" ๋ฌธ์„œ ์ˆ˜: {len(rag.documents)}๊ฐœ")

        # Simple search test with a fixed query.
        test_query = "์ทจ๋“์„ธ ์„ธ์œจ"
        print(f"\n๊ฒ€์ƒ‰ ํ…Œ์ŠคํŠธ: '{test_query}'")
        search_started = time.time()
        hits = rag.search(test_query, top_k=2)
        search_seconds = time.time() - search_started
        print(f"โœ“ ๊ฒ€์ƒ‰ ์™„๋ฃŒ ({search_seconds:.2f}์ดˆ)")
        print(f" ๊ฒฐ๊ณผ: {len(hits)}๊ฐœ ๋ฌธ์„œ")

        for rank, hit in enumerate(hits, start=1):
            score = hit['hybrid_score']
            preview = hit['document'][:50]
            print(f" {rank}. ์ ์ˆ˜: {score:.3f} - {preview}...")
        return True
    except Exception as error:
        print(f"RAG ์‹œ์Šคํ…œ ํ…Œ์ŠคํŠธ ์‹คํŒจ: {error}")
        import traceback
        traceback.print_exc()
        return False
def run_full_test():
    """Run every test in sequence and print a summary.

    Prints a header with the current time and Python version, then calls the
    five test functions in order. A test counts as passed when it returns a
    truthy value. An exception raised by a test is printed and counted as a
    failure; KeyboardInterrupt stops the remaining tests but the summary is
    still printed.

    Returns:
        True when every test passed, False otherwise.
    """
    print("๋กœ์ปฌ CPU ํ™˜๊ฒฝ RAG ์‹œ์Šคํ…œ ํ…Œ์ŠคํŠธ")
    started_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
    print(f"์‹คํ–‰ ์‹œ๊ฐ„: {started_at}")
    print(f"Python ๋ฒ„์ „: {sys.version}")

    # (label, callable) pairs, executed in this order.
    tests = [
        ("๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ Import", test_imports),
        ("์„ค์ •", test_config),
        ("์ž„๋ฒ ๋”ฉ ๋ชจ๋ธ", test_embedding_model),
        ("๋ฒ•๋ น ํŽ˜์ฒ˜", test_law_fetcher),
        ("RAG ์‹œ์Šคํ…œ", test_rag_system_minimal),
    ]
    wide_rule = '=' * 60
    passed = 0
    suite_started = time.time()

    for label, run_test in tests:
        try:
            print(f"\n{wide_rule}")
            print(f"ํ…Œ์ŠคํŠธ: {label}")
            print(wide_rule)
            if run_test():
                passed += 1
                print(f"โœ“ {label} ํ…Œ์ŠคํŠธ ์„ฑ๊ณต")
            else:
                print(f"โœ— {label} ํ…Œ์ŠคํŠธ ์‹คํŒจ")
        except KeyboardInterrupt:
            # The user aborted: skip the remaining tests but still summarise.
            print("\n์‚ฌ์šฉ์ž์— ์˜ํ•ด ํ…Œ์ŠคํŠธ ์ค‘๋‹จ")
            break
        except Exception as error:
            print(f"โœ— {label} ํ…Œ์ŠคํŠธ ์ค‘ ์˜ค๋ฅ˜: {error}")

    elapsed = time.time() - suite_started
    print(f"\n{wide_rule}")
    print("ํ…Œ์ŠคํŠธ ๊ฒฐ๊ณผ ์š”์•ฝ")
    print(wide_rule)
    print(f"์„ฑ๊ณต: {passed}/{len(tests)}")
    print(f"์ด ์†Œ์š”์‹œ๊ฐ„: {elapsed:.2f}์ดˆ")

    all_passed = passed == len(tests)
    if all_passed:
        print("โœ“ ๋ชจ๋“  ํ…Œ์ŠคํŠธ ํ†ต๊ณผ - ํ—ˆ๊น…ํŽ˜์ด์Šค ๋ฐฐํฌ ์ค€๋น„ ์™„๋ฃŒ")
    else:
        print("โœ— ์ผ๋ถ€ ํ…Œ์ŠคํŠธ ์‹คํŒจ - ๋ฌธ์ œ ํ•ด๊ฒฐ ํ›„ ์žฌ์‹œ๋„")
    return all_passed
if __name__ == "__main__":
    # Command-line handling: an optional first argument selects a single
    # test; with no argument the whole suite runs. The process exit code is
    # 0 on success and 1 on failure or on an unknown flag.
    #
    # Flag -> test function, in the same order as the suite in run_full_test.
    # '--fetcher' is included so the law-fetcher test can also be run alone.
    single_tests = {
        '--imports': test_imports,
        '--config': test_config,
        '--embedding': test_embedding_model,
        '--fetcher': test_law_fetcher,
        '--rag': test_rag_system_minimal,
    }
    if len(sys.argv) > 1:
        selected_test = single_tests.get(sys.argv[1])
        if selected_test is None:
            # Unknown flag: print the usage line (built from the table so it
            # cannot drift from the accepted flags) and exit with an error.
            print(f"์‚ฌ์šฉ๋ฒ•: python local_cpu_test.py [{'|'.join(single_tests)}]")
            sys.exit(1)
        success = selected_test()
    else:
        success = run_full_test()
    sys.exit(0 if success else 1)