Dongjin1203 commited on
Commit
f7b258b
·
1 Parent(s): b59368d

이전으로 롤백

Browse files
Files changed (1) hide show
  1. src/visualization/chatbot_app.py +2 -279
src/visualization/chatbot_app.py CHANGED
@@ -151,15 +151,6 @@ if 'available_models' not in st.session_state:
151
  if 'selected_gpt_model' not in st.session_state:
152
  st.session_state.selected_gpt_model = "gpt-4o-mini"
153
 
154
- if 'custom_db_path' not in st.session_state:
155
- st.session_state.custom_db_path = None
156
-
157
- if 'db_uploaded' not in st.session_state:
158
- st.session_state.db_uploaded = False
159
-
160
- if 'last_db_file' not in st.session_state:
161
- st.session_state.last_db_file = None
162
-
163
 
164
  # ===== API 키로 사용 가능한 모델 조회 함수 =====
165
  def get_available_models(api_key: str) -> tuple:
@@ -297,130 +288,9 @@ def validate_api_key(api_key: str) -> tuple:
297
  return False, f"❌ API 키 검증 실패: {error_msg}", []
298
 
299
 
300
- # ===== 벡터 DB 업로드 및 검증 함수 =====
301
- def upload_and_extract_vectordb(uploaded_file):
302
- """
303
- 업로드된 ZIP 파일을 압축 해제하고 ChromaDB 경로 반환
304
-
305
- Args:
306
- uploaded_file: Streamlit UploadedFile 객체
307
-
308
- Returns:
309
- Path: ChromaDB 경로 (chroma.sqlite3가 있는 폴더)
310
-
311
- Raises:
312
- FileNotFoundError: chroma.sqlite3를 찾을 수 없는 경우
313
- """
314
- import zipfile
315
- import tempfile
316
- from pathlib import Path
317
-
318
- # 임시 폴더 생성
319
- temp_dir = tempfile.mkdtemp(prefix="chroma_db_")
320
- temp_path = Path(temp_dir)
321
-
322
- # ZIP 파일 저장
323
- zip_path = temp_path / "uploaded.zip"
324
- with open(zip_path, "wb") as f:
325
- f.write(uploaded_file.getbuffer())
326
-
327
- # 압축 해제
328
- extract_path = temp_path / "chromadb"
329
- extract_path.mkdir(exist_ok=True)
330
-
331
- with zipfile.ZipFile(zip_path, 'r') as zip_ref:
332
- zip_ref.extractall(extract_path)
333
-
334
- # chroma.sqlite3 찾기
335
- # 경우 1: 루트에 있는 경우
336
- if (extract_path / "chroma.sqlite3").exists():
337
- return str(extract_path)
338
-
339
- # 경우 2: 하위 폴더에 있는 경우
340
- for root in extract_path.rglob("*"):
341
- if root.is_dir():
342
- if (root / "chroma.sqlite3").exists():
343
- return str(root)
344
-
345
- # 찾지 못한 경우
346
- raise FileNotFoundError(
347
- "chroma.sqlite3를 찾을 수 없습니다. "
348
- "올바른 ChromaDB 폴더를 압축했는지 확인하세요."
349
- )
350
-
351
-
352
- def get_vectordb_info(db_path: str) -> dict:
353
- """
354
- 벡터 DB 정보 조회 (chroma_parser.py의 ChromaDBParser 방식)
355
-
356
- Args:
357
- db_path: ChromaDB 경로
358
-
359
- Returns:
360
- 정보 딕셔너리
361
- """
362
- try:
363
- import chromadb
364
- from pathlib import Path
365
-
366
- # ChromaDB 클라이언트 생성
367
- client = chromadb.PersistentClient(path=str(Path(db_path)))
368
-
369
- # Collection 리스트 가져오기
370
- collections = client.list_collections()
371
-
372
- if not collections:
373
- return {
374
- 'doc_count': 0,
375
- 'metadata_keys': [],
376
- 'collection_name': 'N/A',
377
- 'error': 'Collection이 없습니다'
378
- }
379
-
380
- # 첫 번째 Collection 사용
381
- collection = collections[0]
382
- collection_name = collection.name
383
-
384
- # 총 문서 수
385
- count = collection.count()
386
-
387
- if count == 0:
388
- return {
389
- 'doc_count': 0,
390
- 'metadata_keys': [],
391
- 'collection_name': collection_name
392
- }
393
-
394
- # 샘플 데이터로 정보 확인
395
- sample = collection.get(
396
- limit=1,
397
- include=['metadatas']
398
- )
399
-
400
- # 메타데이터 키
401
- metadata_keys = []
402
- if sample.get('metadatas') and len(sample['metadatas']) > 0:
403
- if sample['metadatas'][0]:
404
- metadata_keys = list(sample['metadatas'][0].keys())
405
-
406
- return {
407
- 'doc_count': count,
408
- 'metadata_keys': metadata_keys,
409
- 'collection_name': collection_name
410
- }
411
-
412
- except Exception as e:
413
- return {
414
- 'doc_count': 0,
415
- 'metadata_keys': [],
416
- 'collection_name': 'N/A',
417
- 'error': str(e)
418
- }
419
-
420
-
421
  # ===== RAG 파이프라인 초기화 =====
422
  @st.cache_resource
423
- def initialize_rag(model_type, _user_api_key=None, gpt_model_name=None, custom_db_path=None):
424
  """
425
  RAG 파이프라인 초기화
426
 
@@ -428,7 +298,6 @@ def initialize_rag(model_type, _user_api_key=None, gpt_model_name=None, custom_d
428
  model_type: "API 모델 (GPT)" 또는 "로컬 모델 (GGUF)"
429
  _user_api_key: 사용자가 입력한 API 키 (None이면 .env 사용)
430
  gpt_model_name: 사용할 GPT 모델 이름 (예: "gpt-4o-mini")
431
- custom_db_path: 사용자가 업로드한 벡터 DB 경로 (None이면 기본 경로)
432
 
433
  Returns:
434
  (rag_pipeline, error_message, model_name)
@@ -445,10 +314,6 @@ def initialize_rag(model_type, _user_api_key=None, gpt_model_name=None, custom_d
445
  if gpt_model_name:
446
  config.LLM_MODEL_NAME = gpt_model_name
447
 
448
- # 커스텀 벡터 DB 경로 설정
449
- if custom_db_path:
450
- config.DB_DIRECTORY = custom_db_path
451
-
452
  if model_type == "API 모델 (GPT)":
453
  # API 모델 사용
454
  from src.generator.generator import RAGPipeline
@@ -739,147 +604,6 @@ def main():
739
 
740
  st.markdown("---")
741
 
742
- # ===== 📊 벡터 DB 설정 =====
743
- st.markdown("### 📊 벡터 DB 설정")
744
-
745
- # 현재 DB 상태 확인
746
- has_server_db = os.path.exists(config.DB_DIRECTORY)
747
-
748
- if has_server_db:
749
- st.success("✅ 서버 벡터 DB 사용 중")
750
- else:
751
- st.warning("⚠️ 서버 벡터 DB가 없습니다. 아래에 업로드하세요.")
752
-
753
- # 벡터 DB 업로드 옵션
754
- use_custom_db = st.checkbox(
755
- "📤 내 벡터 DB 업로드하기",
756
- value=not has_server_db,
757
- help="자신의 ChromaDB를 ZIP 파일로 업로드하여 사용합니다."
758
- )
759
-
760
- if use_custom_db:
761
- st.markdown("""
762
- **업로드 방법:**
763
- 1. ChromaDB 폴더를 ZIP으로 압축
764
- 2. 아래에 업로드
765
-
766
- **필수 파일:**
767
- - `chroma.sqlite3`
768
- - Collection 폴더
769
- """)
770
-
771
- uploaded_db = st.file_uploader(
772
- "ChromaDB ZIP 파일 업로드",
773
- type=['zip'],
774
- help="chroma_db 폴더를 압축한 ZIP 파일을 업로드하세요",
775
- key="vectordb_uploader"
776
- )
777
-
778
- # 파일 업로드 처리
779
- if uploaded_db is not None:
780
- # 새 파일이거나 처음 업로드
781
- if ('last_db_file' not in st.session_state or
782
- st.session_state.last_db_file != uploaded_db.name):
783
-
784
- with st.spinner("📦 파일 처리 중..."):
785
- try:
786
- db_path = upload_and_extract_vectordb(uploaded_db)
787
- st.session_state.custom_db_path = db_path
788
- st.session_state.last_db_file = uploaded_db.name
789
- st.session_state.db_uploaded = True
790
-
791
- # RAG 파이프라인 재초기화 강제
792
- st.session_state.rag_pipeline = None
793
- st.session_state.model_type = None
794
-
795
- st.success("✅ 벡터 DB 업로드 완료!")
796
-
797
- # DB 정보 표시
798
- db_info = get_vectordb_info(db_path)
799
-
800
- if 'error' not in db_info:
801
- st.info(f"""
802
- 📋 **DB 정보:**
803
- - 문서 수: {db_info['doc_count']:,}개
804
- - 컬렉션: {db_info['collection_name']}
805
- - 메타데이터: {', '.join(db_info['metadata_keys'][:5])}
806
- """)
807
- else:
808
- st.warning(f"⚠️ DB 정보 조회 실패: {db_info['error']}")
809
-
810
- st.info("💡 모델을 다시 선택하면 새 벡터 DB로 초기화됩니다.")
811
-
812
- except FileNotFoundError as e:
813
- st.error(str(e))
814
- except Exception as e:
815
- st.error(f"❌ 업로드 실패: {e}")
816
-
817
- else:
818
- # 이미 업로드된 파일
819
- st.success(f"✅ 업로드됨: {uploaded_db.name}")
820
- if st.session_state.custom_db_path:
821
- st.info(f"경로: {st.session_state.custom_db_path}")
822
-
823
- # 벡터 DB 생성 가이드
824
- with st.expander("📖 벡터 DB 생성 방법"):
825
- st.markdown("""
826
- **1. 데이터 준비**
827
- ```bash
828
- # 문서 파일을 data/files/ 폴더에 저장
829
- ```
830
-
831
- **2. 벡터 DB 생성**
832
- ```bash
833
- # 전체 파이프라인 실행
834
- python main.py --step all
835
-
836
- # 또는 임베딩만
837
- python main.py --step embed
838
- ```
839
-
840
- **3. ZIP 압축**
841
- ```bash
842
- # Windows
843
- Compress-Archive -Path chroma_db -DestinationPath chroma_db.zip
844
-
845
- # Mac/Linux
846
- zip -r chroma_db.zip chroma_db/
847
- ```
848
-
849
- **4. 업로드**
850
- - 생성된 `chroma_db.zip` 파일을 위에서 업로드
851
- """)
852
-
853
- else:
854
- # 서버 DB 사용 중
855
- if has_server_db:
856
- st.info("ℹ️ 서버에 있는 벡터 DB를 사용합니다.")
857
-
858
- # 서버 DB 정보 표시
859
- if st.button("🔍 DB 정보 보기", key="view_server_db"):
860
- with st.spinner("🔄 정보 조회 중..."):
861
- db_info = get_vectordb_info(config.DB_DIRECTORY)
862
-
863
- if 'error' in db_info:
864
- st.error(f"❌ 정보 조회 실패: {db_info['error']}")
865
- else:
866
- st.success(f"""
867
- 📋 **서버 DB 정보:**
868
- - 문서 수: {db_info['doc_count']:,}개
869
- - 컬렉션: {db_info['collection_name']}
870
- - 메타데이터: {', '.join(db_info['metadata_keys'][:5])}
871
- """)
872
-
873
- # 사용자 DB 초기화
874
- if st.session_state.custom_db_path:
875
- st.session_state.custom_db_path = None
876
- st.session_state.db_uploaded = False
877
- st.session_state.last_db_file = None
878
- st.session_state.rag_pipeline = None
879
- st.session_state.model_type = None
880
-
881
- st.markdown("---")
882
-
883
  # ===== 🤖 모델 설정 =====
884
  st.markdown("### 🤖 모델 설정")
885
 
@@ -1119,8 +843,7 @@ def main():
1119
  rag, error, rag_type = initialize_rag(
1120
  model_type,
1121
  _user_api_key=st.session_state.user_api_key,
1122
- gpt_model_name=selected_gpt_model,
1123
- custom_db_path=st.session_state.custom_db_path
1124
  )
1125
 
1126
  if error:
 
151
  if 'selected_gpt_model' not in st.session_state:
152
  st.session_state.selected_gpt_model = "gpt-4o-mini"
153
 
 
 
 
 
 
 
 
 
 
154
 
155
  # ===== API 키로 사용 가능한 모델 조회 함수 =====
156
  def get_available_models(api_key: str) -> tuple:
 
288
  return False, f"❌ API 키 검증 실패: {error_msg}", []
289
 
290
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  # ===== RAG 파이프라인 초기화 =====
292
  @st.cache_resource
293
+ def initialize_rag(model_type, _user_api_key=None, gpt_model_name=None):
294
  """
295
  RAG 파이프라인 초기화
296
 
 
298
  model_type: "API 모델 (GPT)" 또는 "로컬 모델 (GGUF)"
299
  _user_api_key: 사용자가 입력한 API 키 (None이면 .env 사용)
300
  gpt_model_name: 사용할 GPT 모델 이름 (예: "gpt-4o-mini")
 
301
 
302
  Returns:
303
  (rag_pipeline, error_message, model_name)
 
314
  if gpt_model_name:
315
  config.LLM_MODEL_NAME = gpt_model_name
316
 
 
 
 
 
317
  if model_type == "API 모델 (GPT)":
318
  # API 모델 사용
319
  from src.generator.generator import RAGPipeline
 
604
 
605
  st.markdown("---")
606
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
607
  # ===== 🤖 모델 설정 =====
608
  st.markdown("### 🤖 모델 설정")
609
 
 
843
  rag, error, rag_type = initialize_rag(
844
  model_type,
845
  _user_api_key=st.session_state.user_api_key,
846
+ gpt_model_name=selected_gpt_model
 
847
  )
848
 
849
  if error: