Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Lily LLM API ์๋ฒ v2 (์ธํฐ๋ํฐ๋ธ ์ ํ ๋ณต์ ๋ฐ ์ฑ๋ฅ ์ต์ ํ ์ต์ข ๋ณธ) | |
| """ | |
| from fastapi import FastAPI, HTTPException, Request, UploadFile, File, Form, Depends, WebSocket, WebSocketDisconnect | |
| from fastapi.security import HTTPAuthorizationCredentials | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| import uvicorn | |
| import logging | |
| import time | |
| import torch | |
| from datetime import datetime | |
| from typing import Optional, List, Union | |
| import asyncio | |
| import concurrent.futures | |
| import sys | |
| from PIL import Image | |
| import io | |
| import os | |
| import json | |
| from pathlib import Path | |
| import warnings | |
# Silence RoPE-related warnings that come from the Kanana model's internal implementation.
for _ignored_warning in (
    "The attention layers in this model are transitioning",
    "rotary_pos_emb will be removed",
    "position_embeddings will be mandatory",
):
    warnings.filterwarnings("ignore", message=_ignored_warning)
del _ignored_warning
# Configure logging first, before the project imports that follow.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
| from .models import get_model_profile, list_available_models | |
| from lily_llm_core.rag_processor import rag_processor | |
| from lily_llm_core.document_processor import document_processor | |
| from lily_llm_core.hybrid_prompt_generator import hybrid_prompt_generator | |
| from lily_llm_core.database import db_manager | |
| from lily_llm_core.auth_manager import auth_manager | |
| from lily_llm_core.websocket_manager import connection_manager | |
| from lily_llm_core.celery_app import ( | |
| process_document_async, generate_ai_response_async, | |
| rag_query_async, batch_process_documents_async, | |
| get_task_status, cancel_task | |
| ) | |
| from lily_llm_core.performance_monitor import performance_monitor | |
| # ์ด๋ฏธ์ง OCR ์ ์ฉ ๋ชจ๋ ์ถ๊ฐ | |
| from lily_llm_core.image_rag_processor import image_rag_processor | |
| from lily_llm_core.latex_rag_processor import latex_rag_processor | |
| from lily_llm_core.vector_store_manager import vector_store_manager | |
| # LaTeX-OCR + FAISS ํตํฉ ์์คํ ์ถ๊ฐ | |
| # from latex_ocr_faiss_integrated import LatexOCRFAISSIntegrated | |
| # from latex_ocr_faiss_simple import LatexOCRFAISSSimple | |
| # ๋ฉํฐ๋ชจ๋ฌ RAG ํ๋ก์ธ์ ์ถ๊ฐ | |
| from lily_llm_core.hybrid_rag_processor import hybrid_rag_processor | |
| # ์ปจํ ์คํธ ๊ด๋ฆฌ์ ๋ฐ LoRA ๊ด๋ฆฌ์ ์ถ๊ฐ | |
| from lily_llm_core.context_manager import get_context_manager, context_manager | |
| # ๊ณ์ธต์ ๋ฉ๋ชจ๋ฆฌ ์์คํ ์ถ๊ฐ | |
| from lily_llm_core.integrated_memory_manager import integrated_memory_manager | |
| from lily_llm_core.text_summarizer import text_summarizer, SummaryConfig | |
# Global state for the model that is currently in use:
#   current_model   - the loaded model instance
#   current_profile - the selected model profile
#   model_loaded    - whether a model finished loading
current_model, current_profile, model_loaded = None, None, False
# Optional LoRA support: if the manager cannot be imported, LoRA features are
# disabled (LORA_AVAILABLE = False, manager names bound to None) and the server
# keeps running.
try:
    from lily_llm_core.lora_manager import get_lora_manager, lora_manager
except ImportError as e:
    logger.warning(f"โ ๏ธ LoRA ๊ด๋ฆฌ์ import ์คํจ: {e}")
    LORA_AVAILABLE = False
    lora_manager = None
    get_lora_manager = None
else:
    LORA_AVAILABLE = True
    logger.info("โ LoRA ๊ด๋ฆฌ์ import ์ฑ๊ณต")
# ===== Shared LoRA setup =====

# Local checkpoint directory and model type for every model id that supports
# LoRA when the profile reports a local (non-deployment) environment.
_LOCAL_LORA_MODELS = {
    "polyglot-ko-1.3b-chat": ("./lily_llm_core/models/polyglot_ko_1_3b_chat", "causal_lm"),
    "kanana-1.5-v-3b-instruct": ("./lily_llm_core/models/kanana_1_5_v_3b_instruct", "vision2seq"),
    "polyglot-ko-5.8b-chat": ("./lily_llm_core/models/polyglot_ko_5_8b_chat", "causal_lm"),
}


def _resolve_lora_model(profile):
    """Return ``(model_path, model_type)`` for LoRA; ``model_path`` is None when LoRA must be skipped."""
    model_type = "causal_lm"  # default
    model_id = getattr(profile, 'model_id', None)
    local_path = getattr(profile, 'local_path', None)
    model_path = None
    if local_path:
        # The profile points at a local checkpoint directly.
        model_path = local_path
        if model_id == "kanana-1.5-v-3b-instruct":
            model_type = "vision2seq"  # kanana is a vision2seq model
        logger.info(f"๐ ๋ชจ๋ธ ํ๋กํ์์ ๋ก์ปฌ ๊ฒฝ๋ก ์ฌ์ฉ: {model_path}")
    elif model_id:
        logger.info(f"๐ ๋ชจ๋ธ ID ๊ธฐ๋ฐ ๊ฒฝ๋ก ๊ฒฐ์ : {model_id}")
        if not getattr(profile, 'is_local', False):
            # Deployment: the model comes from the HF Hub (no local path), so no LoRA.
            logger.info(f"๐ ๋ฐฐํฌ ํ๊ฒฝ: LoRA ์ค์ ๊ฑด๋๋ (HF ๋ชจ๋ธ)")
            return None, model_type
        model_path, model_type = _LOCAL_LORA_MODELS.get(model_id, (None, model_type))
    if not model_path:
        logger.warning("โ ๏ธ ํ์ฌ ๋ชจ๋ธ์ ๊ฒฝ๋ก๋ฅผ ์ฐพ์ ์ ์์ด LoRA ์๋ ๋ก๋ ๊ฑด๋๋")
        return None, model_type
    logger.info(f"๐ LoRA ๋ชจ๋ธ ๊ฒฝ๋ก: {model_path}")
    logger.info(f"๐ LoRA ๋ชจ๋ธ ํ์ : {model_type}")
    return model_path, model_type


def _lora_target_modules(model_type, model_id):
    """Return the module names that LoRA adapters are attached to for this model family."""
    if model_type == "vision2seq" and "kanana" in model_id:
        # Kanana: Llama-based language model; only the first decoder layer is targeted.
        prefix = "language_model.model.layers.0."
        return [prefix + name for name in (
            "self_attn.q_proj",
            "self_attn.k_proj",
            "self_attn.v_proj",
            "self_attn.o_proj",
            "mlp.gate_proj",
            "mlp.up_proj",
            "mlp.down_proj",
        )]
    # Other models: GPTNeoX-style module names.
    return ["query_key_value", "mlp.dense_h_to_4h", "mlp.dense_4h_to_h"]


def setup_lora_for_model(profile, lora_manager):
    """Attach the automatic LoRA adapter ("auto_adapter") to the loaded main model.

    Args:
        profile: Model profile. Its optional ``local_path``, ``model_id`` and
            ``is_local`` attributes decide whether and where LoRA applies.
        lora_manager: LoRA manager. Its ``base_model`` attribute and its
            ``create_lora_config`` / ``apply_lora_to_model`` methods are used.

    Returns:
        True if the adapter was applied. False if LoRA is unavailable, if it is
        skipped (deployment profile, unknown local path, no loaded model), or if
        any step fails. Errors are logged, never raised.
    """
    if not LORA_AVAILABLE or not lora_manager:
        logger.warning("โ ๏ธ LoRA๊ฐ ์ฌ์ฉ ๋ถ๊ฐ๋ฅํ์ฌ ์๋ ์ค์ ๊ฑด๋๋")
        return False
    try:
        logger.info("๐ง LoRA ์๋ ์ค์ ์์...")
        model_path, model_type = _resolve_lora_model(profile)
        if model_path is None:
            return False

        # Apply LoRA to the main model that is already loaded (avoids loading a second copy).
        logger.info("๐ง ๊ธฐ์กด ๋ฉ์ธ ๋ชจ๋ธ์ LoRA ์ง์ ์ ์ฉ ์์...")
        if hasattr(lora_manager, 'base_model'):
            # `current_model` is this module's global, set by load_model_sync. The old
            # `from lily_llm_api.app import current_model` only worked when this module
            # had been imported under exactly that name.
            if current_model is None:
                logger.warning("โ ๏ธ ๋ฉ์ธ ๋ชจ๋ธ์ ์ฐพ์ ์ ์์ด LoRA ์ค์ ๊ฑด๋๋")
                return False
            if lora_manager.base_model is not current_model:
                # Replace a base model left over from a previously loaded model as well
                # as an empty one, so LoRA targets the model that is actually in use.
                lora_manager.base_model = current_model
                logger.info("โ ๊ธฐ์กด ๋ฉ์ธ ๋ชจ๋ธ์ LoRA ๊ด๋ฆฌ์์ ์ค์ ์๋ฃ")

        logger.info("๐ง LoRA ์ค์ ์์ฑ ์์...")
        # The returned config is not used here.
        # NOTE(review): presumably lora_manager keeps it for apply_lora_to_model - confirm.
        lora_manager.create_lora_config(
            r=16,
            lora_alpha=32,
            lora_dropout=0.1,
            bias="none",
            task_type="CAUSAL_LM" if model_type == "causal_lm" else "VISION_2_SEQ",
            target_modules=_lora_target_modules(model_type, getattr(profile, 'model_id', None)),
        )
        logger.info("โ LoRA ์ค์ ์์ฑ ์๋ฃ")

        logger.info("๐ง LoRA ์ด๋ํฐ ์ ์ฉ ์์...")
        if not lora_manager.apply_lora_to_model("auto_adapter"):
            logger.error("โ LoRA ์ด๋ํฐ ์ ์ฉ ์คํจ")
            return False
        logger.info("โ LoRA ์ด๋ํฐ ์ ์ฉ ์๋ฃ: auto_adapter")
        logger.info("๐ LoRA ์๋ ์ค์ ์๋ฃ!")
        return True
    except Exception as e:
        logger.error(f"โ LoRA ์๋ ์ค์ ์ค ์ค๋ฅ: {e}", exc_info=True)
        return False
| # ===== lifespan ์ปจํ ์คํธ ๋งค๋์ (์๋ฒ ์์/์ข ๋ฃ ์ด๋ฒคํธ) ===== | |
| from contextlib import asynccontextmanager | |
def _apply_context_manager_settings():
    """Configure the shared context manager (summary method and auto-cleanup); failures are logged, not raised."""
    try:
        # "smart" summarization: the most balanced summary method.
        context_manager.set_summary_method("smart")
        logger.info("โ ๊ณ ๊ธ ์ปจํ ์คํธ ๊ด๋ฆฌ์ ์ค์ ์๋ฃ: smart ์์ฝ ๋ฐฉ๋ฒ ํ์ฑํ")
        # Auto-cleanup settings; each one can be overridden via an environment variable.
        context_manager.set_auto_cleanup_config(
            enabled=os.getenv('LILY_CONTEXT_AUTOCLEAN_ENABLED', '1') in ['1', 'true', 'True'],
            interval_turns=int(os.getenv('LILY_CONTEXT_AUTOCLEAN_TURNS', '12')),
            interval_time=int(os.getenv('LILY_CONTEXT_AUTOCLEAN_TIME', '600')),
            strategy=os.getenv('LILY_CONTEXT_CLEANUP_STRATEGY', 'smart'),
        )
        logger.info("โ ์๋ ์ ๋ฆฌ ์ค์ ์ต์ ํ ์๋ฃ")
    except Exception as e:
        logger.warning(f"โ ๏ธ ๊ณ ๊ธ ์ปจํ ์คํธ ๊ด๋ฆฌ์ ์ค์ ์คํจ: {e}")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Server lifecycle for FastAPI (passed as ``FastAPI(lifespan=...)``).

    Start-up:
        1. Tune CPU threads.
        2. Pick a model with ``select_model_interactive``.
        3. Load it with ``load_model_async``.
        4. Configure the context manager.
        A failed load is logged and leaves ``model_loaded`` False; the server
        still starts.

    Shutdown:
        Release the shared worker thread pool via ``shutdown_event``.

    The ``@asynccontextmanager`` decorator was missing, although
    ``asynccontextmanager`` was imported for it. Without it, Starlette only
    accepts the bare async generator through a deprecated compatibility path.
    """
    global model_loaded

    # --- start-up ---
    logger.info("๐ ์๋ฒ ์์ ์ด๋ฒคํธ ์คํ ์ค...")
    try:
        configure_cpu_threads()
        logger.info("โ CPU ์ค๋ ๋ ์ต์ ํ ์๋ฃ")
    except Exception as e:
        logger.error(f"โ CPU ์ค๋ ๋ ์ค์ ์คํจ: {e}")

    # Model selection (the interactive prompt is currently disabled inside select_model_interactive).
    selected_model_id = select_model_interactive()
    logger.info(f"๐ ์๋ฒ ์์ ์ ์ ํ๋ ๋ชจ๋ธ: {selected_model_id}")
    try:
        await load_model_async(selected_model_id)
        model_loaded = True
        logger.info(f"โ ์๋ฒ๊ฐ '{current_profile.display_name}' ๋ชจ๋ธ๋ก ์ค๋น๋์์ต๋๋ค.")
        logger.info(f"โ model_loaded ์ํ: {model_loaded}")
        _apply_context_manager_settings()
        # LoRA auto-setup already runs inside load_model_async (via load_model_sync).
    except Exception as e:
        logger.error(f"โ ๋ชจ๋ธ ๋ก๋์ ์คํจํ์ต๋๋ค: {e}", exc_info=True)
        model_loaded = False
    logger.info("โ ์๋ฒ ์์ ์ด๋ฒคํธ ์๋ฃ")

    try:
        yield  # the server runs while suspended here
    finally:
        # --- shutdown ---
        logger.info("๐ ์๋ฒ ์ข ๋ฃ ์ด๋ฒคํธ ์คํ ์ค...")
        try:
            # shutdown_event has no on_event registration; release the pool explicitly.
            shutdown_event()
        except Exception as e:
            logger.warning(f"Executor shutdown failed: {e}")
        logger.info("โ ์๋ฒ ์ข ๋ฃ ์ด๋ฒคํธ ์๋ฃ")
# The FastAPI application; start-up and shutdown work is delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    version="2.0.0",
    title="Lily LLM API v2",
    description="๋ค์ค ๋ชจ๋ธ ์ง์ LLM API ์๋ฒ",
)
# CORS configuration.
# NOTE(security): "*" together with allow_credentials=True makes Starlette reflect
# the caller's Origin (preflights and cookie-bearing requests), so any website can
# make credentialed cross-origin calls. That is kept as the development default;
# set LILY_CORS_ORIGINS (comma-separated origins) to restrict it in deployments.
_DEFAULT_CORS_ORIGINS = [
    "http://localhost:8001",
    "http://127.0.0.1:8001",
    "http://localhost:3000",
    "http://127.0.0.1:3000",
    "*",  # allow every origin during development
]
_cors_origins = [
    origin.strip()
    for origin in os.getenv("LILY_CORS_ORIGINS", "").split(",")
    if origin.strip()
] or _DEFAULT_CORS_ORIGINS

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=True,
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["*"],
)
# Pydantic request/response models.
# Request body for text generation: `prompt` is required; every other field is
# optional and defaults to None.
class GenerateRequest(BaseModel):
    prompt: str
    model_id: Optional[str] = None  # no default id on purpose - the currently loaded model is used
    # NOTE(review): presumably forwarded to generate_sync, which applies it as
    # max_new_tokens (not total length) - confirm in the route handler.
    max_length: Optional[int] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    do_sample: Optional[bool] = None
# Response for a text-generation request; every field is required.
# NOTE(review): processing_time is presumably in seconds - confirm at the producer.
class GenerateResponse(BaseModel):
    generated_text: str
    processing_time: float
    model_name: str
    image_processed: bool  # presumably mirrors generate_sync's image_processed flag
# Response for a multimodal generation request. Unlike GenerateResponse,
# model_id is optional (None) and image_processed defaults to False.
class MultimodalGenerateResponse(BaseModel):
    generated_text: str
    processing_time: float
    model_name: str
    model_id: Optional[str] = None
    image_processed: bool = False
# Health-check payload.
# NOTE(review): available_models presumably holds the dicts returned by
# list_available_models() (which carry at least 'name' and 'model_id') - confirm.
class HealthResponse(BaseModel):
    status: str
    model_loaded: bool
    current_model: str
    available_models: List[dict]
# Result of a document upload; the optional fields default to None.
class DocumentUploadResponse(BaseModel):
    success: bool
    document_id: str
    message: str
    chunks: Optional[int] = None
    latex_count: Optional[int] = None  # number of LaTeX formulas (field added later)
    error: Optional[str] = None
    auto_response: Optional[str] = None  # automatic response text (field added later)
# Response model for RAG queries; every field is required.
# NOTE(review): field meanings are inferred from their names (answer, retrieved
# context, source list, hit count) - confirm in the route handler.
class RAGResponse(BaseModel):
    success: bool
    response: str
    context: str
    sources: List[dict]
    search_results: int
    processing_time: float
# User-related response models.
# User record response; only success and user_id are required, the rest default to None.
class UserResponse(BaseModel):
    success: bool
    user_id: str
    username: Optional[str] = None
    email: Optional[str] = None
    created_at: Optional[str] = None  # timestamp as a string; format not fixed here
    error: Optional[str] = None
# Chat-session response; only success and session_id are required, the rest default to None.
class SessionResponse(BaseModel):
    success: bool
    session_id: str
    session_name: Optional[str] = None
    created_at: Optional[str] = None  # timestamp as a string; format not fixed here
    error: Optional[str] = None
# Single chat-message response; all fields except `error` are required.
class ChatMessageResponse(BaseModel):
    success: bool
    message_id: int
    content: str
    message_type: str
    timestamp: str  # timestamp as a string; format not fixed here
    error: Optional[str] = None
# Authentication-related response models.
# Login result; only `success` is required, every other field defaults to None.
class LoginResponse(BaseModel):
    success: bool
    access_token: Optional[str] = None
    refresh_token: Optional[str] = None
    token_type: Optional[str] = None
    user_id: Optional[str] = None
    username: Optional[str] = None
    error: Optional[str] = None
# Access-token-only response (no refresh token); only `success` is required.
# NOTE(review): presumably used by a token-refresh endpoint - confirm.
class TokenResponse(BaseModel):
    success: bool
    access_token: Optional[str] = None
    token_type: Optional[str] = None
    error: Optional[str] = None
# Global variables.
# model / tokenizer / processor are (re)assigned by load_model_sync.
model = tokenizer = processor = None
current_profile = None
model_loaded = False
image_processor = None
# Shared thread pool used to run blocking model work off the event loop.
executor = concurrent.futures.ThreadPoolExecutor()
def _resolve_cpu_thread_count(max_threads):
    """Return CPU_THREADS when it is a valid integer, else the detected CPU count capped at max_threads (min 1)."""
    env_threads = os.getenv("CPU_THREADS")
    if env_threads is not None:
        try:
            return max(1, int(env_threads))
        except ValueError:
            # Previously an invalid value aborted the whole configuration.
            logger.warning(f"Invalid CPU_THREADS={env_threads!r}; using the detected CPU count instead")
    detected = os.cpu_count() or 2
    return max(1, min(detected, max_threads))


def configure_cpu_threads(max_threads: int = 16):
    """Size the CPU thread pools (OpenMP/MKL/numexpr env vars and PyTorch) for this host.

    The thread count comes from the ``CPU_THREADS`` environment variable when it
    holds a valid integer. Otherwise it is ``os.cpu_count()`` (2 if unknown),
    capped at ``max_threads``. It is always at least 1.

    Args:
        max_threads: Upper bound for the auto-detected count. It does not limit
            an explicit ``CPU_THREADS`` value. Defaults to 16, the previously
            hard-coded cap.

    Returns:
        The thread count that was applied, or None if configuration failed.
        Failures are logged, never raised.
    """
    print(f"๐ [DEBUG] configure_cpu_threads ์์")
    threads = None
    try:
        threads = _resolve_cpu_thread_count(max_threads)
        # OpenMP / MKL thread pools; numexpr and tokenizers keep any value already set.
        os.environ["OMP_NUM_THREADS"] = str(threads)
        os.environ["MKL_NUM_THREADS"] = str(threads)
        os.environ.setdefault("NUMEXPR_NUM_THREADS", str(threads))
        os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
        # PyTorch intra-op threads.
        try:
            torch.set_num_threads(threads)
        except Exception as e:
            logger.debug(f"torch.set_num_threads({threads}) failed: {e}")
        try:
            # 1-2 inter-op threads limit context-switch overhead. PyTorch only allows
            # this before any inter-op parallel work has started, so later calls fail.
            torch.set_num_interop_threads(1 if threads <= 4 else 2)
        except Exception as e:
            logger.debug(f"torch.set_num_interop_threads failed: {e}")
        # Report what torch actually uses, not merely what was requested.
        logger.info(f"๐งต CPU thread config -> OMP/MKL/numexpr={threads}, torch_threads={torch.get_num_threads()}")
    except Exception as e:
        logger.warning(f"โ ๏ธ CPU ์ค๋ ๋ ์ค์ ์คํจ: {e}")
        threads = None
    print(f"๐ [DEBUG] configure_cpu_threads ์ข ๋ฃ")
    return threads
def select_model_interactive():
    """Choose the model to load at start-up and return its ``model_id``.

    The available models are printed as a numbered list. The interactive
    ``input()`` prompt is currently disabled (it was commented out), so the
    choice is made non-interactively:

    * the ``LILY_MODEL_ID`` environment variable, if it names an available model;
    * otherwise the second available model (the previous hard-coded choice);
    * or the first model when only one is available.

    Raises:
        RuntimeError: if ``list_available_models()`` returns no models. The old
            code retried ``available_models[1]`` inside ``while True`` and looped
            forever whenever fewer than two models existed.
    """
    available_models = list_available_models()
    print("\n" + "="*60 + "\n๐ค Lily LLM API v2 - ๋ชจ๋ธ ์ ํ\n" + "="*60)
    for i, model_info in enumerate(available_models, 1):
        print(f"{i:2d}. {model_info['name']} ({model_info['model_id']})")
    if not available_models:
        raise RuntimeError("No models available: list_available_models() returned nothing")

    selected_model = None
    requested_id = os.getenv("LILY_MODEL_ID")
    if requested_id:
        selected_model = next(
            (info for info in available_models if info['model_id'] == requested_id), None
        )
        if selected_model is None:
            logger.warning(f"LILY_MODEL_ID={requested_id!r} is not an available model; using the default selection")
    if selected_model is None:
        # Default: the second model, falling back to the first when it is the only one.
        selected_model = available_models[1] if len(available_models) > 1 else available_models[0]
    print(f"\nโ '{selected_model['name']}' ๋ชจ๋ธ์ ์ ํํ์ต๋๋ค.")
    return selected_model['model_id']
# @app.on_event("startup") is not used: per the original note it does not work
# with the latest FastAPI. The former startup_event logic now lives in `lifespan`.
def shutdown_event():
    """Shut down the shared thread pool, blocking until already-submitted tasks finish."""
    pool = executor
    pool.shutdown(True)
async def load_model_async(model_id: str):
    """Run the blocking ``load_model_sync(model_id)`` on the shared thread pool.

    Awaiting this keeps the event loop responsive while weights load.
    Exceptions raised by ``load_model_sync`` propagate to the caller.
    """
    # Inside a coroutine, the asyncio docs recommend get_running_loop() over get_event_loop().
    loop = asyncio.get_running_loop()
    await loop.run_in_executor(executor, load_model_sync, model_id)
async def load_model_endpoint(model_id: str):
    """HTTP handler: load ``model_id`` and report the outcome without raising.

    Returns:
        ``{"success": True, "message": ...}`` on success, or
        ``{"success": False, "error": str(exc)}`` on failure.

    Side effect:
        The module-level ``model_loaded`` flag is updated so it matches the
        result. Previously only ``lifespan`` ever set it.
    """
    global model_loaded
    try:
        logger.info(f"๐ฅ HTTP ์์ฒญ์ผ๋ก ๋ชจ๋ธ ๋ก๋ ์์: {model_id}")
        await load_model_async(model_id)
        model_loaded = True
        return {"success": True, "message": f"๋ชจ๋ธ '{model_id}' ๋ก๋ ์๋ฃ"}
    except Exception as e:
        logger.error(f"โ HTTP ๋ชจ๋ธ ๋ก๋ ์คํจ: {e}", exc_info=True)
        # A failed load may already have unloaded the previous model, so derive
        # the flag from the actual state rather than leaving it unchanged.
        model_loaded = model is not None
        return {"success": False, "error": str(e)}
def _release_loaded_model():
    """Drop this module's references to the loaded model and reclaim its memory."""
    global model, tokenizer, processor, current_model
    import gc
    # current_model aliases `model`. Leaving it set kept the old weights alive
    # through gc.collect() while the next model was being loaded.
    model = tokenizer = processor = current_model = None
    # setup_lora_for_model stores the main model as lora_manager.base_model; clear
    # that reference too, so the weights can be freed and the next load re-attaches.
    if lora_manager is not None and getattr(lora_manager, 'base_model', None) is not None:
        lora_manager.base_model = None
    gc.collect()
    if torch.cuda.is_available():
        # Return cached CUDA memory blocks once the model is gone.
        torch.cuda.empty_cache()


def load_model_sync(model_id: str):
    """Synchronously load ``model_id`` and its processor into the module globals.

    Steps:
        1. Resolve the profile with ``get_model_profile``. This happens before
           touching the current model, so an unknown id no longer unloads a
           working model.
        2. Unload any existing model.
        3. Call ``profile.load_model()``, which returns ``(model, processor)``.
        4. Derive ``tokenizer`` from the processor.
        5. Attempt automatic LoRA setup.

    Raises:
        Exception: Any error from profile lookup or model loading is logged
            with its traceback and re-raised. If no model is loaded afterwards,
            ``current_profile`` is reset to None.
    """
    global model, tokenizer, processor, current_profile, current_model
    try:
        logger.info(f"๐ฅ '{model_id}' ๋ชจ๋ธ ๋ก๋ฉ ์์...")
        new_profile = get_model_profile(model_id)

        if model is not None:
            logger.info("๐๏ธ ๊ธฐ์กด ๋ชจ๋ธ ์ธ๋ก๋ ์ค...")
            _release_loaded_model()
            logger.info("โ ๊ธฐ์กด ๋ชจ๋ธ ์ธ๋ก๋ ์๋ฃ")

        current_profile = new_profile
        # load_model() returns (model, processor).
        model, processor = current_profile.load_model()
        # Also exposed as current_model, which setup_lora_for_model reads.
        current_model = model
        # Multimodal processors wrap a tokenizer; otherwise the processor itself acts as the tokenizer.
        tokenizer = processor.tokenizer if hasattr(processor, 'tokenizer') else processor
        logger.info(f"โ '{current_profile.display_name}' ๋ชจ๋ธ ๋ก๋ฉ ์๋ฃ!")

        # Automatic LoRA setup (shared helper; it logs and returns False on failure).
        setup_lora_for_model(current_profile, lora_manager)
    except Exception as e:
        logger.error(f"โ load_model_sync ์คํจ: {e}")
        import traceback
        logger.error(f"๐ ์ ์ฒด ์๋ฌ: {traceback.format_exc()}")
        if model is None:
            # Nothing is loaded: don't advertise a profile. generate_sync treats a
            # None profile as "No model loaded".
            current_profile = None
        raise
| def generate_sync(prompt: str, image_data_list: Optional[List[bytes]], max_length: Optional[int] = None, | |
| temperature: Optional[float] = None, top_p: Optional[float] = None, | |
| do_sample: Optional[bool] = None, use_context: bool = True, session_id: str = None, | |
| user_id: str = "anonymous", room_id: str = "default") -> dict: | |
| """[์ต์ ํ] ๋ชจ๋ธ ์์ฑ์ ์ฒ๋ฆฌํ๋ ํตํฉ ๋๊ธฐ ํจ์""" | |
| try: | |
| print(f"๐ [DEBUG] generate_sync ์์ - prompt ๊ธธ์ด: {len(prompt)}") | |
| print(f"๐ [DEBUG] ํ์ฌ ๋ก๋๋ ๋ชจ๋ธ: {current_profile.display_name if current_profile else 'None'}") | |
| print(f"๐ [DEBUG] ๋ชจ๋ธ ํ์ : {type(current_profile) if current_profile else 'None'}") | |
| if current_profile is None: | |
| print("โ [DEBUG] ๋ชจ๋ธ์ด ๋ก๋๋์ง ์์") | |
| return {"error": "No model loaded"} | |
| print(f"๐ [DEBUG] ๋ชจ๋ธ ์ด๋ฆ: {getattr(current_profile, 'model_name', 'Unknown')}") | |
| print(f"๐ [DEBUG] ๋ฉํฐ๋ชจ๋ฌ ์ง์: {getattr(current_profile, 'multimodal', False)}") | |
| print(f"๐ [DEBUG] ์ ๋ ฅ ํ๋กฌํํธ: {prompt}") | |
| print(f"๐ [DEBUG] ์ ๋ ฅ ํ๋กฌํํธ ๊ธธ์ด: {len(prompt)}") | |
| print(f"๐ [DEBUG] ์ด๋ฏธ์ง ๋ฐ์ดํฐ ์กด์ฌ ์ฌ๋ถ: {image_data_list is not None}") | |
| print(f"๐ [DEBUG] ์ด๋ฏธ์ง ๋ฐ์ดํฐ ๊ฐ์: {len(image_data_list) if image_data_list else 0}") | |
| print(f"๐ [DEBUG] ์ค์ ์ด๋ฏธ์ง ๋ฐ์ดํฐ ๊ฐ์: {len([img for img in image_data_list if img]) if image_data_list else 0}") | |
| image_processed = False | |
| all_pixel_values = [] | |
| combined_image_metas = None | |
| # --- 1. ์ด๋ฏธ์ง ์ฒ๋ฆฌ (๊ณต์ ๋ฐฉ์) --- | |
| # ๐ RAG์์ ์ถ์ถ๋ ์ด๋ฏธ์ง ๋ฐ์ดํฐ๋ ํฌํจ | |
| all_image_data = [] | |
| if image_data_list and len([img for img in image_data_list if img]) > 0: | |
| all_image_data.extend(image_data_list) | |
| print(f"๐ [DEBUG] ์ง์ ์ ๋ฌ๋ ์ด๋ฏธ์ง {len(image_data_list)}๊ฐ ์ถ๊ฐ") | |
| # ๐ RAG์์ ์ถ์ถ๋ ์ด๋ฏธ์ง ๋ฐ์ดํฐ๋ ํ์ฌ ๊ตฌํ์์ ์ ๊ฑฐ๋จ (์ ์ญ ๋ณ์ ๋ฌธ์ ํด๊ฒฐ) | |
| if all_image_data and len([img for img in all_image_data if img]) > 0 and getattr(current_profile, 'multimodal', False): | |
| print(f"๐ [DEBUG] ์ด๋ฏธ์ง ์ฒ๋ฆฌ ์์ - ์ด ์ด๋ฏธ์ง ๊ฐ์: {len([img for img in all_image_data if img])}") | |
| # ๐ ๊ณต์ ๋ฐฉ์: ๊ฐ๋จํ ์ด๋ฏธ์ง ์ฒ๋ฆฌ | |
| max_images = min(len(all_image_data), 4) | |
| logger.info(f"๐ผ๏ธ ๋ฉํฐ๋ชจ๋ฌ ์ฒ๋ฆฌ ์์... (์ด๋ฏธ์ง {max_images}๊ฐ)") | |
| try: | |
| metas_list = [] | |
| for idx, image_bytes in enumerate(all_image_data[:max_images]): | |
| if image_bytes: | |
| try: | |
| pil_image = Image.open(io.BytesIO(image_bytes)).convert("RGB") | |
| # ๐ ๊ณต์ ์ด๋ฏธ์ง ํ๋ก์ธ์ ์ฌ์ฉ | |
| if processor and hasattr(processor, 'image_processor'): | |
| processed = processor.image_processor(pil_image) | |
| all_pixel_values.append(processed["pixel_values"]) | |
| metas_list.append(processed.get("image_meta", {})) | |
| else: | |
| logger.warning(f"โ ๏ธ ์ด๋ฏธ์ง ํ๋ก์ธ์๋ฅผ ์ฐพ์ ์ ์์") | |
| except Exception as e: | |
| logger.warning(f"โ ๏ธ ์ด๋ฏธ์ง {idx} ์ฒ๋ฆฌ ์คํจ: {e}") | |
| # ๐ ๋ฉํ๋ฐ์ดํฐ ํตํฉ (๊ณต์ ๋ฐฉ์) | |
| if metas_list: | |
| combined_image_metas = {} | |
| for key in metas_list[0].keys(): | |
| combined_image_metas[key] = [meta[key] for meta in metas_list if key in meta] | |
| print(f"๐ [DEBUG] ์ด๋ฏธ์ง ๋ฉํ๋ฐ์ดํฐ: {combined_image_metas}") | |
| else: | |
| combined_image_metas = {} | |
| except Exception as e: | |
| logger.error(f"โ ์ด๋ฏธ์ง ์ ์ฒ๋ฆฌ ์คํจ: {e}") | |
| combined_image_metas = {} | |
| # --- 2. ํ๋กฌํํธ ๊ตฌ์ฑ --- | |
| print(f"๐ [DEBUG] ํ๋กฌํํธ ๊ตฌ์ฑ ์์") | |
| # ์ปจํ ์คํธ ํตํฉ (๋ํ ๊ธฐ๋ก + RAG ๊ฒ์ ๊ฒฐ๊ณผ ํฌํจ) - ๋ชจ๋ธ๋ณ ์ต์ ํ | |
| context_prompt = "" | |
| if use_context and session_id: | |
| try: | |
| # 1. ๋ํ ๊ธฐ๋ก ์ปจํ ์คํธ | |
| context = context_manager.get_context_for_model( | |
| current_profile.model_name, | |
| session_id | |
| ) | |
| if context and len(context.strip()) > 0: | |
| context_prompt = context + "\n\n" | |
| print(f"๐ [DEBUG] ๋ํ ์ปจํ ์คํธ ํฌํจ๋จ - ๊ธธ์ด: {len(context_prompt)} (์ธ์ : {session_id})") | |
| # 2. RAG ๊ฒ์ ๊ฒฐ๊ณผ ์ปจํ ์คํธ (PDF ๋ด์ฉ ํฌํจ) | |
| try: | |
| # ๏ฟฝ๏ฟฝ ์๋ก์ด ๋ฉ๋ชจ๋ฆฌ ์์คํ ์ ์ฌ์ฉํ RAG ์ปจํ ์คํธ ๋ก๋ | |
| rag_context = "" | |
| # ๐ ์ฌ์ฉ์ ์ค์ ํ์ธ | |
| from lily_llm_core.user_memory_manager import user_memory_manager | |
| keep_memory = user_memory_manager.get_memory_setting(user_id, "keep_memory_on_room_change") | |
| if keep_memory: | |
| # ๋ฉ๋ชจ๋ฆฌ ์ ์ง ๋ชจ๋ - ๊ธฐ์กด ๋ก์ง ์คํ | |
| print(f"๐ [DEBUG] ์ฌ์ฉ์ {user_id} ๋ฉ๋ชจ๋ฆฌ ์ ์ง ๋ชจ๋ - RAG ์ปจํ ์คํธ ๋ก๋") | |
| # ํตํฉ ๋ฉ๋ชจ๋ฆฌ ๊ด๋ฆฌ์์์ AI์ฉ ์ปจํ ์คํธ ์์ฑ | |
| ai_context = integrated_memory_manager.get_context_for_ai( | |
| user_id=user_id, | |
| room_id=room_id, | |
| session_id=session_id, | |
| include_user_memory=True, | |
| include_room_context=True, | |
| include_session_history=False # ํ์ฌ ๋ํ๋ ๋ณ๋๋ก ์ฒ๋ฆฌ | |
| ) | |
| if ai_context: | |
| rag_context += f"\n\n๐ ๋ฉ๋ชจ๋ฆฌ ์ปจํ ์คํธ:\n{ai_context}\n" | |
| print(f"๐ [DEBUG] ๋ฉ๋ชจ๋ฆฌ ์ปจํ ์คํธ ํฌํจ๋จ - ๊ธธ์ด: {len(ai_context)}") | |
| # ๊ธฐ์กด RAG ์์คํ ์์ ๋ฌธ์ ๋ด์ฉ ๊ฐ์ ธ์ค๊ธฐ (room_id ๊ธฐ๋ฐ) | |
| try: | |
| # ์ฑํ ๋ฐฉ๋ณ ๋ฌธ์ ์ปจํ ์คํธ ์กฐํ | |
| room_context = integrated_memory_manager.room_context_manager.get_room_context(room_id) | |
| if room_context and room_context.documents: | |
| rag_context += "\n\n๐ ์ ๋ก๋๋ ๋ฌธ์ ๋ชฉ๋ก:\n" | |
| for doc in room_context.documents[-3:]: # ์ต๊ทผ 3๊ฐ๋ง | |
| # ๋์ ๋๋ฆฌ์ ๊ฐ์ฒด ๋ชจ๋ ์ฒ๋ฆฌ | |
| if isinstance(doc, dict): | |
| filename = doc.get('filename', 'unknown') | |
| doc_type = doc.get('document_type', 'unknown') | |
| page_count = doc.get('page_count', 0) | |
| else: | |
| filename = getattr(doc, 'filename', 'unknown') | |
| doc_type = getattr(doc, 'document_type', 'unknown') | |
| page_count = getattr(doc, 'page_count', 0) | |
| rag_context += f" - {filename} ({doc_type}, {page_count}ํ์ด์ง)\n" | |
| print(f"๐ [DEBUG] ์ฑํ ๋ฐฉ {room_id}์ ๋ฌธ์ {len(room_context.documents)}๊ฐ ๋ฐ๊ฒฌ") | |
| except Exception as e: | |
| print(f"โ ๏ธ ์ฑํ ๋ฐฉ ๋ฌธ์ ์ปจํ ์คํธ ๋ก๋ ์คํจ: {e}") | |
| # ๐ ๋ฌธ์ ๋ด์ฉ ์์ฒด๋ ๋ก๋ํ์ง ์์ (ํด๋ณ ์ด๊ธฐํ) | |
| # ์ด์ ํด์์ ์ฒจ๋ถ๋ ๋ฌธ์์ ์ค์ ๋ด์ฉ์ AI ์ปจํ ์คํธ์ ํฌํจํ์ง ์์ | |
| print(f"๏ฟฝ๏ฟฝ [DEBUG] ๋ฌธ์ ๋ด์ฉ ๋ก๋ ๊ฑด๋๋ฐ๊ธฐ - ํด๋ณ ์ด๊ธฐํ ์ ์ฉ") | |
| # ๏ฟฝ๏ฟฝ ํ์ฌ ํด์์๋ง ๋ฌธ์ ์ ๋ณด ํ์ (์ค์ ๋ด์ฉ์ ๋ก๋ํ์ง ์์) | |
| if rag_context: | |
| context_prompt += rag_context | |
| print(f"๐ [DEBUG] ๋ฌธ์ ๋ชฉ๋ก๋ง ํ์ - ์ค์ ๋ด์ฉ ๋ก๋ ์ํจ (ํด๋ณ ์ด๊ธฐํ)") | |
| except Exception as e: | |
| print(f"โ ๏ธ [DEBUG] RAG ์ปจํ ์คํธ ์ฒ๋ฆฌ ์คํจ: {e}") | |
| if not context_prompt: | |
| print(f"๏ฟฝ๏ฟฝ [DEBUG] ์ปจํ ์คํธ ์์ ๋๋ ๋น์ด์์ (์ธ์ : {session_id})") | |
| except Exception as e: | |
| print(f"โ ๏ธ [DEBUG] ์ปจํ ์คํธ ๋ก๋ ์คํจ: {e} (์ธ์ : {session_id})") | |
| context_prompt = "" | |
| # formatted_prompt ์ด๊ธฐํ | |
| formatted_prompt = None | |
| # ๐ ๋ฉํฐ๋ชจ๋ฌ ํ๋กฌํํธ ๊ตฌ์ฑ (๊ณต์ ๋ฐฉ์) | |
| if all_pixel_values and len(all_pixel_values) > 0: | |
| # ๐ ๊ณต์ Kanana ํ์: Human: <image> ํ ์คํธ | |
| # ์ด๋ฏธ์ง ํ ํฐ์ encode_prompt์์ ์๋์ผ๋ก ์ฒ๋ฆฌ๋จ | |
| formatted_prompt = f"Human: <image>{prompt}" | |
| print(f"๐ [DEBUG] ๋ฉํฐ๋ชจ๋ฌ ํ๋กฌํํธ ๊ตฌ์ฑ (๊ณต์ ํ์): {formatted_prompt}") | |
| image_processed = True | |
| else: | |
| image_tokens = "" | |
| image_processed = False | |
| print(f"๐ [DEBUG] ์ด๋ฏธ์ง ์์ - ํ ์คํธ-only ๋ชจ๋") | |
| # ํ ์คํธ-only ๋ชจ๋ธ์ฉ ํ๋กฌํํธ ๊ตฌ์ฑ (์ปจํ ์คํธ ํฌํจ) | |
| if hasattr(current_profile, 'format_prompt'): | |
| # Polyglot ๋ชจ๋ธ์ผ ๋๋ format_prompt ๋ฉ์๋ ์ฌ์ฉ (์ปจํ ์คํธ ์ง์) | |
| if "polyglot" in current_profile.model_name.lower(): | |
| # ์ปจํ ์คํธ์ ํ๋กฌํํธ๋ฅผ ํจ๊ป ์ ๋ฌ | |
| formatted_prompt = current_profile.format_prompt(prompt, context_prompt) | |
| else: | |
| # ๋ค๋ฅธ ๋ชจ๋ธ์ ๊ธฐ์กด ๋ฐฉ์ ์ฌ์ฉ | |
| base_prompt = current_profile.format_prompt(prompt) | |
| if context_prompt: | |
| formatted_prompt = context_prompt + base_prompt | |
| else: | |
| formatted_prompt = base_prompt | |
| print(f"๐ [DEBUG] ํ๋กํ format_prompt ์ฌ์ฉ (์ปจํ ์คํธ ํฌํจ): {formatted_prompt}") | |
| else: | |
| # ๊ธฐ๋ณธ ํ๋กฌํํธ (fallback) - ์ปจํ ์คํธ ํฌํจ | |
| # Polyglot ๋ชจ๋ธ์ <|im_start|> ํ๊ทธ๋ฅผ ์ ๋๋ก ์ฒ๋ฆฌํ์ง ๋ชปํจ | |
| if "polyglot" in current_profile.model_name.lower(): | |
| base_prompt = f"### ์ฌ์ฉ์:\n{prompt}\n\n### ์ฑ๋ด:\n" | |
| else: | |
| base_prompt = f"<|im_start|>user\n{prompt}<|im_end|>\n<|im_start|>assistant\n" | |
| if context_prompt: | |
| formatted_prompt = context_prompt + base_prompt | |
| else: | |
| formatted_prompt = base_prompt | |
| print(f"๐ [DEBUG] ๊ธฐ๋ณธ ํ๋กฌํํธ ์ฌ์ฉ (์ปจํ ์คํธ ํฌํจ): {formatted_prompt}") | |
| print(f"๐ [DEBUG] ํ๋กฌํํธ ๊ตฌ์ฑ ์๋ฃ - ๊ธธ์ด: {len(formatted_prompt) if formatted_prompt else 0}") | |
| print(f"๐ [DEBUG] ์ต์ข ํ๋กฌํํธ: {formatted_prompt}") | |
| # --- 3. ํ ํฌ๋์ด์ง --- | |
| print(f"๐ [DEBUG] ํ ํฌ๋์ด์ง ์์") | |
| t_tok_start = time.time() | |
| if not all_image_data or len([img for img in all_image_data if img]) == 0: | |
| # ํ ์คํธ-only ๊ณ ์ ๊ฒฝ๋ก (๋ ๋น ๋ฆ) | |
| print(f"๐ [DEBUG] ํ ์คํธ-only ํ ํฌ๋์ด์ง ๊ฒฝ๋ก") | |
| print(f"๐ [DEBUG] ์ฌ์ฉํ ํ๋กฌํํธ: {formatted_prompt}") | |
| inputs = tokenizer( | |
| formatted_prompt, | |
| return_tensors="pt", | |
| padding=True, | |
| truncation=True, | |
| max_length=2048, | |
| ) | |
| if 'token_type_ids' in inputs: | |
| del inputs['token_type_ids'] | |
| print(f"๐ [DEBUG] token_type_ids ์ ๊ฑฐ๋จ") | |
| input_ids = inputs['input_ids'] | |
| attention_mask = inputs['attention_mask'] | |
| print(f"๐ [DEBUG] ํ ํฌ๋์ด์ ์ถ๋ ฅ: {list(inputs.keys())}") | |
| else: | |
| # ๋ฉํฐ๋ชจ๋ฌ(Lite): Kanana ์ ์ฉ encode_prompt๋ก -1 ํ ํฐ ์๋ฆฌ ์์ฑ (ํ์) | |
| print(f"๐ [DEBUG] ๋ฉํฐ๋ชจ๋ฌ ํ ํฌ๋์ด์ง ๊ฒฝ๋ก") | |
| print(f"๐ [DEBUG] combined_image_metas: {combined_image_metas}") | |
| print(f"๐ [DEBUG] ์ด ์ด๋ฏธ์ง ๊ฐ์: {len(all_image_data)}") | |
| if hasattr(tokenizer, 'encode_prompt'): | |
| print(f"๐ [DEBUG] encode_prompt ๋ฉ์๋ ์ฌ์ฉ") | |
| # ๐ ๋ฉํ๋ฐ์ดํฐ ๊ฒ์ฆ ๋ฐ ์์ ํ | |
| safe_image_meta = {} | |
| if combined_image_metas: | |
| # image_token_thw ๋ฐฐ์ด ๊ธธ์ด ๊ฒ์ฆ | |
| if 'image_token_thw' in combined_image_metas: | |
| image_token_thw = combined_image_metas['image_token_thw'] | |
| if isinstance(image_token_thw, list) and len(image_token_thw) > 0: | |
| # ๋ฐฐ์ด ๊ธธ์ด๊ฐ ์ด๋ฏธ์ง ๊ฐ์์ ์ผ์นํ๋์ง ํ์ธ | |
| if len(image_token_thw) == len(all_pixel_values): | |
| # ๐ ์ถ๊ฐ ๊ฒ์ฆ: ๊ฐ ๋ฐฐ์ด ์์๊ฐ ์ ํจํ์ง ํ์ธ | |
| valid_meta = True | |
| for i, thw in enumerate(image_token_thw): | |
| if not isinstance(thw, (list, tuple)) or len(thw) != 3: | |
| print(f"โ ๏ธ [DEBUG] ๋ฉํ๋ฐ์ดํฐ ์์ {i}๊ฐ ์ ํจํ์ง ์์: {thw}") | |
| valid_meta = False | |
| break | |
| if valid_meta: | |
| safe_image_meta = combined_image_metas | |
| print(f"๐ [DEBUG] ๋ฉํ๋ฐ์ดํฐ ๊ฒ์ฆ ํต๊ณผ: {len(image_token_thw)}๊ฐ ์ด๋ฏธ์ง") | |
| else: | |
| print(f"โ ๏ธ [DEBUG] ๋ฉํ๋ฐ์ดํฐ ์์ ๊ฒ์ฆ ์คํจ, ๊ธฐ๋ณธ๊ฐ ์ฌ์ฉ") | |
| safe_image_meta = { | |
| 'image_token_thw': [[1, 1, 1]] * len(all_pixel_values), | |
| 'vision_grid_thw': [[1, 1, 1]] * len(all_pixel_values) | |
| } | |
| else: | |
| print(f"โ ๏ธ [DEBUG] ๋ฉํ๋ฐ์ดํฐ ๋ถ์ผ์น: ์ด๋ฏธ์ง {len(all_pixel_values)}๊ฐ, ๋ฉํ {len(image_token_thw)}๊ฐ") | |
| # ์์ ํ ๊ธฐ๋ณธ๊ฐ ์ฌ์ฉ | |
| safe_image_meta = { | |
| 'image_token_thw': [[1, 1, 1]] * len(all_pixel_values), | |
| 'vision_grid_thw': [[1, 1, 1]] * len(all_pixel_values) | |
| } | |
| else: | |
| print(f"โ ๏ธ [DEBUG] image_token_thw๊ฐ ์ ํจํ์ง ์์, ๊ธฐ๋ณธ๊ฐ ์ฌ์ฉ") | |
| safe_image_meta = { | |
| 'image_token_thw': [[1, 1, 1]] * len(all_pixel_values), | |
| 'vision_grid_thw': [[1, 1, 1]] * len(all_pixel_values) | |
| } | |
| else: | |
| print(f"โ ๏ธ [DEBUG] image_token_thw ์์, ๊ธฐ๋ณธ๊ฐ ์์ฑ") | |
| safe_image_meta = { | |
| 'image_token_thw': [[1, 1, 1]] * len(all_pixel_values), | |
| 'vision_grid_thw': [[1, 1, 1]] * len(all_pixel_values) | |
| } | |
| else: | |
| print(f"โ ๏ธ [DEBUG] combined_image_metas ์์, ๊ธฐ๋ณธ๊ฐ ์์ฑ") | |
| safe_image_meta = { | |
| 'image_token_thw': [[1, 1, 1]] * len(all_pixel_values), | |
| 'vision_grid_thw': [[1, 1, 1]] * len(all_pixel_values) | |
| } | |
| print(f"๐ [DEBUG] ์์ ํ๋ ๋ฉํ๋ฐ์ดํฐ: {safe_image_meta}") | |
| # ๐ ์์ ํ ๋ฉํ๋ฐ์ดํฐ๋ก encode_prompt ํธ์ถ | |
| try: | |
| # ๐ ์ถ๊ฐ ์์ ์ฅ์น: ๋ฉํ๋ฐ์ดํฐ ๋ณต์ฌ๋ณธ ์์ฑ | |
| final_meta = {} | |
| for key, value in safe_image_meta.items(): | |
| if isinstance(value, list): | |
| final_meta[key] = value.copy() # ๋ณต์ฌ๋ณธ ์์ฑ | |
| else: | |
| final_meta[key] = value | |
| print(f"๐ [DEBUG] ์ต์ข ๋ฉํ๋ฐ์ดํฐ: {final_meta}") | |
| # ๐ ๊ณต์ ๋ฐฉ์: max_length ํ๋ผ๋ฏธํฐ ์ถ๊ฐ | |
| inputs = tokenizer.encode_prompt( | |
| prompt=formatted_prompt, | |
| max_length=2048, # ๊ณต์ ์ฝ๋์ ๋์ผ | |
| image_meta=final_meta | |
| ) | |
| print(f"๐ [DEBUG] encode_prompt ์ถ๋ ฅ: {list(inputs.keys())}") | |
| # ๐ encode_prompt ์ถ๋ ฅ ์ ๊ทํ (seq_length ์ ๊ฑฐ) | |
| if 'seq_length' in inputs: | |
| print(f"๐ [DEBUG] seq_length ์ ๊ฑฐ๋จ") | |
| del inputs['seq_length'] | |
| # ๐ input_ids ์์ ํ๊ฒ ์ถ์ถ (๊ณต์ ๋ฐฉ์) | |
| if isinstance(inputs['input_ids'], tuple): | |
| print(f"๐ [DEBUG] input_ids๊ฐ ํํ์: {len(inputs['input_ids'])}๊ฐ ์์") | |
| input_ids = inputs['input_ids'][0] # ์ฒซ ๋ฒ์งธ ์์ ์ฌ์ฉ | |
| print(f"๐ [DEBUG] input_ids ํํ์์ ์ฒซ ๋ฒ์งธ ์์ ์ถ์ถ: {input_ids.shape}") | |
| else: | |
| input_ids = inputs['input_ids'] | |
| # ๐ attention_mask๋ ์์ ํ๊ฒ ์ถ์ถ | |
| if isinstance(inputs['attention_mask'], tuple): | |
| print(f"๐ [DEBUG] attention_mask๊ฐ ํํ์: {len(inputs['attention_mask'])}๊ฐ ์์") | |
| attention_mask = inputs['attention_mask'][0] # ์ฒซ ๋ฒ์งธ ์์ ์ฌ์ฉ | |
| print(f"๐ [DEBUG] attention_mask ํํ์์ ์ฒซ ๋ฒ์งธ ์์ ์ถ์ถ: {attention_mask.shape}") | |
| else: | |
| attention_mask = inputs['attention_mask'] | |
| # ๐ ์ต์ข ๊ฒ์ฆ | |
| print(f"๐ [DEBUG] ์ต์ข input_ids ํ์ : {type(input_ids)}, shape: {input_ids.shape}") | |
| print(f"๐ [DEBUG] ์ต์ข attention_mask ํ์ : {type(attention_mask)}, shape: {attention_mask.shape}") | |
| except Exception as e: | |
| print(f"โ [DEBUG] encode_prompt ์คํจ: {e}, ํด๋ฐฑ ์ฌ์ฉ") | |
| # ํด๋ฐฑ: ๊ธฐ๋ณธ ํ ํฌ๋์ด์ ์ฌ์ฉ | |
| inputs = tokenizer( | |
| formatted_prompt, | |
| return_tensors="pt", | |
| padding=True, | |
| truncation=True, | |
| max_length=2048, | |
| ) | |
| if 'token_type_ids' in inputs: | |
| del inputs['token_type_ids'] | |
| input_ids = inputs['input_ids'] | |
| attention_mask = inputs['attention_mask'] | |
| else: | |
| # ์์ ํด๋ฐฑ | |
| print(f"๐ [DEBUG] ๊ธฐ๋ณธ ํ ํฌ๋์ด์ ์ฌ์ฉ (ํด๋ฐฑ)") | |
| inputs = tokenizer( | |
| formatted_prompt, | |
| return_tensors="pt", | |
| padding=True, | |
| truncation=True, | |
| max_lengt=2048, | |
| ) | |
| if 'token_type_ids' in inputs: | |
| del inputs['token_type_ids'] | |
| print(f"๐ [DEBUG] token_type_ids ์ ๊ฑฐ๋จ (ํด๋ฐฑ)") | |
| input_ids = inputs['input_ids'] | |
| attention_mask = inputs['attention_mask'] | |
| print(f"๐ [DEBUG] ๊ธฐ๋ณธ ํ ํฌ๋์ด์ ์ถ๋ ฅ: {list(inputs.keys())}") | |
| t_tok_end = time.time() | |
| print(f"๐ [DEBUG] ํ ํฌ๋์ด์ง ์๋ฃ - ์์์๊ฐ: {t_tok_end - t_tok_start:.3f}์ด") | |
| # ๐ input_ids ์์ ํ๊ฒ ์ฒ๋ฆฌ | |
| if isinstance(input_ids, tuple): | |
| print(f"๐ [DEBUG] input_ids๊ฐ ํํ์: {len(input_ids)}๊ฐ ์์") | |
| input_ids = input_ids[0] # ์ฒซ ๋ฒ์งธ ์์ ์ฌ์ฉ | |
| print(f"๐ [DEBUG] input_ids ํํ์์ ์ฒซ ๋ฒ์งธ ์์ ์ถ์ถ: {input_ids.shape}") | |
| # ๐ 1์ฐจ์ ํ ์๋ฅผ 2์ฐจ์์ผ๋ก reshape | |
| if len(input_ids.shape) == 1: | |
| print(f"๐ [DEBUG] 1์ฐจ์ ํ ์๋ฅผ 2์ฐจ์์ผ๋ก reshape: {input_ids.shape} -> (1, {input_ids.shape[0]})") | |
| input_ids = input_ids.unsqueeze(0) # (seq_len,) -> (1, seq_len) | |
| # ๐ attention_mask๋ ๋์ผํ๊ฒ ์ฒ๋ฆฌ | |
| if len(attention_mask.shape) == 1: | |
| print(f"๐ [DEBUG] attention_mask 1์ฐจ์์ 2์ฐจ์์ผ๋ก reshape: {attention_mask.shape} -> (1, {attention_mask.shape[0]})") | |
| attention_mask = attention_mask.unsqueeze(0) # (seq_len,) -> (1, seq_len) | |
| print(f"๐ [DEBUG] ์ต์ข input_ids shape: {input_ids.shape}") | |
| print(f"๐ [DEBUG] ์ ๋ ฅ ํ ํฐ ์: {input_ids.shape[1]}") | |
| # --- 4. ์์ฑ ์ค์ --- | |
| print(f"๐ [DEBUG] ์์ฑ ์ค์ ๊ตฌ์ฑ ์์") | |
| gen_config = current_profile.get_generation_config() | |
| # config ํ์ผ์ ๋ช ์๋ eos, pad, bos ํ ํฐ id ๊ธฐ๋ณธ๊ฐ์ผ๋ก ์ฑ์ฐ๊ธฐ | |
| if 'eos_token_id' not in gen_config or gen_config['eos_token_id'] is None: | |
| gen_config['eos_token_id'] = tokenizer.eos_token_id | |
| if 'pad_token_id' not in gen_config or gen_config['pad_token_id'] is None: | |
| gen_config['pad_token_id'] = tokenizer.pad_token_id or tokenizer.eos_token_id | |
| # ํ์ํ ๊ฒฝ์ฐ bos_token_id ๋ ์ค์ (generate ํจ์์ ๋ฐ๋ผ ๋ค๋ฆ) | |
| if 'bos_token_id' not in gen_config and hasattr(tokenizer, 'bos_token_id'): | |
| gen_config['bos_token_id'] = tokenizer.bos_token_id | |
| # max_new_tokens, temperature ๋ฑ API ์ธ์ ๋ฐ์์ ๋ฎ์ด์ฐ๊ธฐ | |
| if max_length is not None: | |
| gen_config['max_new_tokens'] = max_length | |
| if temperature is not None: | |
| gen_config['temperature'] = temperature | |
| if top_p is not None: | |
| gen_config['top_p'] = top_p | |
| if do_sample is not None: | |
| gen_config['do_sample'] = do_sample | |
| print(f"๐ [DEBUG] ์์ฑ ์ค์ : {gen_config}") | |
| # --- 5. ์ค์ ์ถ๋ก ์คํ --- | |
| print(f"๐ [DEBUG] ๋ชจ๋ธ ์ถ๋ก ์์") | |
| t_gen_start = time.time() | |
| try: | |
| # ๋ชจ๋ธ ์ํ ํ์ธ | |
| print(f"๐ [DEBUG] ๋ชจ๋ธ ๋๋ฐ์ด์ค: {model.device}") | |
| print(f"๐ [DEBUG] ์ ๋ ฅ ํ ์ ๋๋ฐ์ด์ค: {input_ids.device}") | |
| print(f"๐ [DEBUG] ๋ชจ๋ธ ํ์ : {type(model)}") | |
| print(f"๐ [DEBUG] ๋ชจ๋ธ ์ํ: {'eval' if model.training == False else 'training'}") | |
| print(f"๐ [DEBUG] ์ ๋ ฅ ํ ์ shape: {input_ids.shape}") | |
| print(f"๐ [DEBUG] attention_mask shape: {attention_mask.shape}") | |
| print(f"๐ [DEBUG] all_pixel_values ์กด์ฌ ์ฌ๋ถ: {all_pixel_values is not None}") | |
| print(f"๐ [DEBUG] all_pixel_values ๊ธธ์ด: {len(all_pixel_values) if all_pixel_values else 0}") | |
| # ์ ๋ ฅ ํ ์๋ฅผ ๋ชจ๋ธ ๋๋ฐ์ด์ค๋ก ์ด๋ | |
| if input_ids.device != model.device: | |
| print(f"๐ [DEBUG] ์ ๋ ฅ ํ ์๋ฅผ ๋ชจ๋ธ ๋๋ฐ์ด์ค๋ก ์ด๋: {input_ids.device} -> {model.device}") | |
| input_ids = input_ids.to(model.device) | |
| attention_mask = attention_mask.to(model.device) | |
| # ๐ torch import ๋ฌธ์ ํด๊ฒฐ | |
| import torch | |
| with torch.no_grad(): | |
| if all_pixel_values and len(all_pixel_values) > 0: | |
| # ๋ฉํฐ๋ชจ๋ฌ: ์ด๋ฏธ์ง์ ํ ์คํธ ํจ๊ป ์ฒ๋ฆฌ | |
| print(f"๐ [DEBUG] ๋ฉํฐ๋ชจ๋ฌ ์ถ๋ก ์คํ") | |
| print(f"๐ [DEBUG] ์ด๋ฏธ์ง ํ ์ ๊ฐ์: {len(all_pixel_values)}") | |
| # ์ด๋ฏธ์ง ํ ์๋ ๋๋ฐ์ด์ค ํ์ธ | |
| pixel_values = torch.cat(all_pixel_values, dim=0) | |
| print(f"๐ [DEBUG] ๊ฒฐํฉ๋ ์ด๋ฏธ์ง ํ ์ shape: {pixel_values.shape}") | |
| print(f"๐ [DEBUG] ์ด๋ฏธ์ง ํ ์ dtype: {pixel_values.dtype}") | |
| # ๐ ๋ชจ๋ธ๊ณผ ๋์ผํ dtype์ผ๋ก ๋ณํ (์ฑ๋ฅ ์ต์ ํ) | |
| if hasattr(model, 'dtype'): | |
| target_dtype = model.dtype | |
| if pixel_values.dtype != target_dtype: | |
| print(f"๐ [DEBUG] ์ด๋ฏธ์ง ํ ์ dtype ๋ณํ: {pixel_values.dtype} -> {target_dtype}") | |
| pixel_values = pixel_values.to(dtype=target_dtype) | |
| else: | |
| # ๐ ๋ชจ๋ธ dtype์ ์ ์ ์๋ ๊ฒฝ์ฐ bfloat16 ์ฌ์ฉ (Kanana ๋ชจ๋ธ ๊ธฐ๋ณธ๊ฐ) | |
| target_dtype = torch.bfloat16 | |
| if pixel_values.dtype != target_dtype: | |
| print(f"๐ [DEBUG] ์ด๋ฏธ์ง ํ ์ dtype ๋ณํ: {pixel_values.dtype} -> {target_dtype}") | |
| pixel_values = pixel_values.to(dtype=target_dtype) | |
| if pixel_values.device != model.device: | |
| print(f"๐ [DEBUG] ์ด๋ฏธ์ง ํ ์๋ฅผ ๋ชจ๋ธ ๋๋ฐ์ด์ค๋ก ์ด๋: {pixel_values.device} -> {model.device}") | |
| pixel_values = pixel_values.to(model.device) | |
| print(f"๐ [DEBUG] ์ต์ข ์ด๋ฏธ์ง ํ ์ ๋๋ฐ์ด์ค: {pixel_values.device}") | |
| print(f"๐ [DEBUG] ์ต์ข ์ด๋ฏธ์ง ํ ์ dtype: {pixel_values.dtype}") | |
| print(f"๐ [DEBUG] ๋ชจ๋ธ ์์ฑ ์์ - ๋ฉํฐ๋ชจ๋ฌ") | |
| # LoRA ์ด๋ํฐ๊ฐ ์ ์ฉ๋ ๋ชจ๋ธ์ธ์ง ํ์ธ | |
| if LORA_AVAILABLE and lora_manager and hasattr(lora_manager, 'current_adapter_name') and lora_manager.current_adapter_name: | |
| print(f"๐ [DEBUG] LoRA ์ด๋ํฐ ์ ์ฉ๋จ (๋ฉํฐ๋ชจ๋ฌ): {lora_manager.current_adapter_name}") | |
| # LoRA๊ฐ ์ ์ฉ๋ ๋ชจ๋ธ ์ฌ์ฉ | |
| lora_model = lora_manager.get_model() | |
| if lora_model: | |
| print(f"๐ [DEBUG] LoRA ๋ชจ๋ธ๋ก ๋ฉํฐ๋ชจ๋ฌ ์์ฑ ์คํ") | |
| # ๐ image_metas ํ๋ผ๋ฏธํฐ ์ถ๊ฐ (๊ณต์ ๋ฐฉ์) | |
| # ๐ ๋ฉํ๋ฐ์ดํฐ๋ฅผ ๊ณต์ ๊ตฌ์กฐ๋ก ๋ณํ (๋ชจ๋ธ ์๊ตฌ์ฌํญ) | |
| import torch | |
| processed_image_metas = {} | |
| # ๐ ๊ณต์ ๋ฐฉ์: vision_grid_thw๋ฅผ ํ ์๋ก ๋ณํ | |
| if 'vision_grid_thw' in combined_image_metas: | |
| vision_grid = combined_image_metas['vision_grid_thw'] | |
| if isinstance(vision_grid, list): | |
| # ๐ Kanana ๋ชจ๋ธ ์๊ตฌ์ฌํญ: (T, H, W) ํํ์ 3์ฐจ์ ํ ์ | |
| if len(vision_grid) == 1 and len(vision_grid[0]) == 3: | |
| # [(1, 34, 52)] -> (1, 34, 52) ํ ์๋ก ๋ณํ | |
| t, h, w = vision_grid[0] | |
| # ๐ 3์ฐจ์ ํ ์๋ก ๋ณํ: (1, H, W) ํํ | |
| processed_image_metas['vision_grid_thw'] = torch.tensor([[t, h, w]], dtype=torch.long) | |
| print(f"๐ [DEBUG] vision_grid_thw ํ ์ ๋ณํ: {vision_grid} -> {processed_image_metas['vision_grid_thw'].shape}") | |
| else: | |
| # ๐ ๋ค๋ฅธ ํํ์ ๊ฒฝ์ฐ ์๋ณธ ์ ์ง | |
| processed_image_metas['vision_grid_thw'] = torch.tensor(vision_grid, dtype=torch.long) | |
| print(f"๐ [DEBUG] vision_grid_thw ํ ์ ๋ณํ (๊ธฐ๋ณธ): {vision_grid} -> {processed_image_metas['vision_grid_thw'].shape}") | |
| else: | |
| processed_image_metas['vision_grid_thw'] = vision_grid | |
| # ๐ ๋ค๋ฅธ ๋ฉํ๋ฐ์ดํฐ๋ ๊ทธ๋๋ก ์ ์ง | |
| for key, value in combined_image_metas.items(): | |
| if key != 'vision_grid_thw': | |
| processed_image_metas[key] = value | |
| generate_kwargs = { | |
| 'input_ids': input_ids, | |
| 'attention_mask': attention_mask, | |
| 'pixel_values': pixel_values, | |
| 'image_metas': processed_image_metas, # ๐ ์ฒ๋ฆฌ๋ ์ด๋ฏธ์ง ๋ฉํ๋ฐ์ดํฐ | |
| **gen_config | |
| } | |
| print(f"๐ [DEBUG] LoRA ๋ชจ๋ธ ์์ฑ ํ๋ผ๋ฏธํฐ: {list(generate_kwargs.keys())}") | |
| print(f"๐ [DEBUG] ์ฒ๋ฆฌ๋ image_metas: {list(processed_image_metas.keys())}") | |
| print(f"๐ [DEBUG] ๋ชจ๋ธ ์์ฑ ์์... (ํ์์์ ์์)") | |
| # ๐ ์์ฑ ์ ์ต์ข ๊ฒ์ฆ | |
| print(f"๐ [DEBUG] ์ต์ข ํ๋ผ๋ฏธํฐ ๊ฒ์ฆ:") | |
| print(f" - input_ids: {input_ids.shape}, dtype: {input_ids.dtype}") | |
| print(f" - attention_mask: {attention_mask.shape}, dtype: {attention_mask.dtype}") | |
| print(f" - pixel_values: {pixel_values.shape}, dtype: {pixel_values.dtype}") | |
| print(f" - vision_grid_thw: {processed_image_metas.get('vision_grid_thw', 'None')}") | |
| generated_ids = lora_model.generate(**generate_kwargs) | |
| else: | |
| print(f"โ ๏ธ [DEBUG] LoRA ๋ชจ๋ธ์ ๊ฐ์ ธ์ฌ ์ ์์, ๊ธฐ๋ณธ ๋ชจ๋ธ ์ฌ์ฉ") | |
| # ๐ image_metas ํ๋ผ๋ฏธํฐ ์ถ๊ฐ (๊ณต์ ๋ฐฉ์) | |
| # ๐ ๋ฉํ๋ฐ์ดํฐ๋ฅผ ๊ณต์ ๊ตฌ์กฐ๋ก ๋ณํ (๋ชจ๋ธ ์๊ตฌ์ฌํญ) | |
| processed_image_metas = {} | |
| # ๐ ๊ณต์ ๋ฐฉ์: vision_grid_thw๋ฅผ ํ ์๋ก ๋ณํ | |
| if 'vision_grid_thw' in combined_image_metas: | |
| vision_grid = combined_image_metas['vision_grid_thw'] | |
| if isinstance(vision_grid, list): | |
| # ๐ Kanana ๋ชจ๋ธ ์๊ตฌ์ฌํญ: (T, H, W) ํํ์ 3์ฐจ์ ํ ์ | |
| if len(vision_grid) == 1 and len(vision_grid[0]) == 3: | |
| # [(1, 34, 52)] -> (1, 34, 52) ํ ์๋ก ๋ณํ | |
| t, h, w = vision_grid[0] | |
| # ๐ 3์ฐจ์ ํ ์๋ก ๋ณํ: (1, H, W) ํํ | |
| processed_image_metas['vision_grid_thw'] = torch.tensor([[t, h, w]], dtype=torch.long) | |
| print(f"๐ [DEBUG] vision_grid_thw ํ ์ ๋ณํ: {vision_grid} -> {processed_image_metas['vision_grid_thw'].shape}") | |
| else: | |
| # ๐ ๋ค๋ฅธ ํํ์ ๊ฒฝ์ฐ ์๋ณธ ์ ์ง | |
| processed_image_metas['vision_grid_thw'] = torch.tensor(vision_grid, dtype=torch.long) | |
| print(f"๐ [DEBUG] vision_grid_thw ํ ์ ๋ณํ (๊ธฐ๋ณธ): {vision_grid} -> {processed_image_metas['vision_grid_thw'].shape}") | |
| else: | |
| processed_image_metas['vision_grid_thw'] = vision_grid | |
| # ๐ ๋ค๋ฅธ ๋ฉํ๋ฐ์ดํฐ๋ ๊ทธ๋๋ก ์ ์ง | |
| for key, value in combined_image_metas.items(): | |
| if key != 'vision_grid_thw': | |
| processed_image_metas[key] = value | |
| generate_kwargs = { | |
| 'input_ids': input_ids, | |
| 'attention_mask': attention_mask, | |
| 'pixel_values': pixel_values, | |
| 'image_metas': processed_image_metas, # ๐ ์ฒ๋ฆฌ๋ ์ด๋ฏธ์ง ๋ฉํ๋ฐ์ดํฐ | |
| **gen_config | |
| } | |
| print(f"๐ [DEBUG] ๊ธฐ๋ณธ ๋ชจ๋ธ ์์ฑ ํ๋ผ๋ฏธํฐ: {list(generate_kwargs.keys())}") | |
| print(f"๐ [DEBUG] ์ฒ๋ฆฌ๋ image_metas: {list(processed_image_metas.keys())}") | |
| generated_ids = model.generate(**generate_kwargs) | |
| else: | |
| print(f"๐ [DEBUG] LoRA ์ด๋ํฐ ์์ (๋ฉํฐ๋ชจ๋ฌ), ๊ธฐ๋ณธ ๋ชจ๋ธ ์ฌ์ฉ") | |
| # ๐ image_metas ํ๋ผ๋ฏธํฐ ์ถ๊ฐ (๊ณต์ ๋ฐฉ์) | |
| # ๐ ๋ฉํ๋ฐ์ดํฐ๋ฅผ ๊ณต์ ๊ตฌ์กฐ๋ก ๋ณํ (๋ชจ๋ธ ์๊ตฌ์ฌํญ) | |
| processed_image_metas = {} | |
| # ๐ ๊ณต์ ๋ฐฉ์: vision_grid_thw๋ฅผ ํ ์๋ก ๋ณํ | |
| if 'vision_grid_thw' in combined_image_metas: | |
| vision_grid = combined_image_metas['vision_grid_thw'] | |
| if isinstance(vision_grid, list): | |
| # ๐ Kanana ๋ชจ๋ธ ์๊ตฌ์ฌํญ: (T, H, W) ํํ์ 3์ฐจ์ ํ ์ | |
| if len(vision_grid) == 1 and len(vision_grid[0]) == 3: | |
| # [(1, 34, 52)] -> (1, 34, 52) ํ ์๋ก ๋ณํ | |
| t, h, w = vision_grid[0] | |
| # ๐ 3์ฐจ์ ํ ์๋ก ๋ณํ: (1, H, W) ํํ | |
| processed_image_metas['vision_grid_thw'] = torch.tensor([[t, h, w]], dtype=torch.long) | |
| print(f"๐ [DEBUG] vision_grid_thw ํ ์ ๋ณํ: {vision_grid} -> {processed_image_metas['vision_grid_thw'].shape}") | |
| else: | |
| # ๐ ๋ค๋ฅธ ํํ์ ๊ฒฝ์ฐ ์๋ณธ ์ ์ง | |
| processed_image_metas['vision_grid_thw'] = torch.tensor(vision_grid, dtype=torch.long) | |
| print(f"๐ [DEBUG] vision_grid_thw ํ ์ ๋ณํ (๊ธฐ๋ณธ): {vision_grid} -> {processed_image_metas['vision_grid_thw'].shape}") | |
| else: | |
| processed_image_metas['vision_grid_thw'] = vision_grid | |
| # ๐ ๋ค๋ฅธ ๋ฉํ๋ฐ์ดํฐ๋ ๊ทธ๋๋ก ์ ์ง | |
| for key, value in combined_image_metas.items(): | |
| if key != 'vision_grid_thw': | |
| processed_image_metas[key] = value | |
| generate_kwargs = { | |
| 'input_ids': input_ids, | |
| 'attention_mask': attention_mask, | |
| 'pixel_values': pixel_values, | |
| 'image_metas': processed_image_metas, # ๐ ์ฒ๋ฆฌ๋ ์ด๋ฏธ์ง ๋ฉํ๋ฐ์ดํฐ | |
| **gen_config | |
| } | |
| print(f"๐ [DEBUG] ๊ธฐ๋ณธ ๋ชจ๋ธ ์์ฑ ํ๋ผ๋ฏธํฐ: {list(generate_kwargs.keys())}") | |
| print(f"๐ [DEBUG] ์ฒ๋ฆฌ๋ image_metas: {list(processed_image_metas.keys())}") | |
| generated_ids = model.generate(**generate_kwargs) | |
| # ํ ํฐ ์ค์ ์ ๋ช ์์ ์ผ๋ก ์ ๋ฌํ์ฌ EOS ํ ํฐ ๋ฌธ์ ํด๊ฒฐ | |
| # generate_kwargs = { | |
| # 'input_ids': input_ids.to(model.device), | |
| # 'attention_mask': attention_mask.to(model.device), | |
| # 'pixel_values': pixel_values.to(model.device), | |
| # 'max_new_tokens': gen_config['max_new_tokens'], | |
| # 'temperature': gen_config['temperature'], | |
| # 'top_p': gen_config['top_p'], | |
| # 'do_sample': gen_config['do_sample'], | |
| # 'repetition_penalty': gen_config.get('repetition_penalty', 1.0), | |
| # 'no_repeat_ngram_size': gen_config.get('no_repeat_ngram_size', 0), | |
| # # 'num_beams': gen_config.get('num_beams', 1), | |
| # 'use_cache': gen_config.get('use_cache', True), | |
| # 'max_time': gen_config.get('max_time', None), | |
| # 'early_stopping': gen_config.get('early_stopping', False), | |
| # 'stopping_criteria': gen_config.get('stopping_criteria', None), | |
| # } | |
| # | |
| # # ํ ํฐ ID ์ค์ (์ค์!) | |
| # if gen_config.get('eos_token_id') is not None: | |
| # generate_kwargs['eos_token_id'] = gen_config['eos_token_id'] | |
| # if gen_config.get('pad_token_id') is not None: | |
| # generate_kwargs['pad_token_id'] = gen_config['pad_token_id'] | |
| # if gen_config.get('bos_token_id') is not None: | |
| # generate_kwargs['bos_token_id'] = gen_config['bos_token_id'] | |
| # | |
| # print(f"๐ [DEBUG] ์ต์ข ์์ฑ ์ค์ : {generate_kwargs}") | |
| # | |
| # generated_ids = model.generate(**generate_kwargs) | |
| else: | |
| # ํ ์คํธ-only: ๊ธฐ์กด ๋ฐฉ์ | |
| print(f"๐ [DEBUG] ํ ์คํธ-only ์ถ๋ก ์คํ") | |
| print(f"๐ [DEBUG] ์์ฑ ์ค์ : {gen_config}") | |
| # ํ์์์ ์ค์ ์ ์ํ ์ถ๊ฐ ์ค์ (๋ ์ ์ ํ ๊ฐ์ผ๋ก ์กฐ์ ) | |
| # if 'max_time' not in gen_config: | |
| # gen_config['max_time'] = 60.0 # 60์ด ํ์์์์ผ๋ก ์กฐ์ | |
| # ์ถ๊ฐ ํ์์์ ์ค์ | |
| # gen_config['max_time'] = 60.0 # ๊ฐ์ 60์ด ํ์์์ | |
| # print(f"๐ [DEBUG] ๊ฐ์ ํ์์์ ์ค์ : {gen_config['max_time']}์ด") | |
| # ์ถ๊ฐ ์ฑ๋ฅ ์ต์ ํ ์ค์ | |
| gen_config['use_cache'] = True # ์บ์ ์ฌ์ฉ์ผ๋ก ์๋ ํฅ์ | |
| # PAD ํ ํฐ ์ค์ - ๋ชจ๋ธ ํ๋กํ ์ค์ ์ฐ์ | |
| if 'pad_token_id' not in gen_config: | |
| # ํ๋กํ์ ์ค์ ์ด ์์ ๋๋ง ๊ธฐ๋ณธ๊ฐ ์ฌ์ฉ | |
| if tokenizer.pad_token_id is not None: | |
| gen_config['pad_token_id'] = tokenizer.pad_token_id | |
| print(f"๐ [DEBUG] PAD ํ ํฐ ์ค์ : ํ ํฌ๋์ด์ ๊ธฐ๋ณธ๊ฐ ์ฌ์ฉ (ID: {tokenizer.pad_token_id})") | |
| else: | |
| gen_config['pad_token_id'] = None | |
| print(f"๐ [DEBUG] PAD ํ ํฐ ์ค์ : None (ํ ํฌ๋์ด์ ์ PAD ํ ํฐ ์์)") | |
| # ํ ํฐ ์ค์ - ํ๋กํ์์ ์ค์ ๋ ๊ฐ ์ฐ์ ์ฌ์ฉ | |
| if 'eos_token_id' not in gen_config or gen_config['eos_token_id'] is None: | |
| if tokenizer.eos_token_id is not None: | |
| gen_config['eos_token_id'] = tokenizer.eos_token_id | |
| print(f"๐ [DEBUG] EOS ํ ํฐ ์ค์ : {tokenizer.eos_token_id}") | |
| else: | |
| gen_config['eos_token_id'] = None | |
| print(f"๐ [DEBUG] EOS ํ ํฐ ์ค์ : None (์๋ ์ฒ๋ฆฌ)") | |
| if 'pad_token_id' not in gen_config or gen_config['pad_token_id'] is None: | |
| if tokenizer.pad_token_id is not None: | |
| gen_config['pad_token_id'] = tokenizer.pad_token_id | |
| else: | |
| gen_config['pad_token_id'] = None | |
| if 'bos_token_id' not in gen_config or gen_config['bos_token_id'] is None: | |
| if hasattr(tokenizer, 'bos_token_id') and tokenizer.bos_token_id is not None: | |
| gen_config['bos_token_id'] = tokenizer.bos_token_id | |
| else: | |
| gen_config['bos_token_id'] = None | |
| print(f"๐ [DEBUG] ์ต์ข ํ ํฐ ์ค์ : EOS={gen_config['eos_token_id']}, PAD={gen_config['pad_token_id']}, BOS={gen_config.get('bos_token_id')}") | |
| # ์์ฑ ์ค์ ์ต์ข ํ์ธ | |
| print(f"๐ [DEBUG] ์ต์ข ์์ฑ ์ค์ : {gen_config}") | |
| print(f"๐ [DEBUG] ๋ชจ๋ธ ์์ฑ ์์ - ํ ์คํธ๋ง") | |
| print(f"๐ [DEBUG] ์ต์ข ์ ๋ ฅ ํ ์ ๋๋ฐ์ด์ค: {input_ids.device}") | |
| print(f"๐ [DEBUG] ์ต์ข attention_mask ๋๋ฐ์ด์ค: {attention_mask.device}") | |
| # ๋ชจ๋ธ ์์ฑ ์งํ ์ํฉ ๋ชจ๋ํฐ๋ง์ ์ํ ์ฝ๋ฐฑ ์ถ๊ฐ | |
| print(f"๐ [DEBUG] ๋ชจ๋ธ ์์ฑ ์์ ์๊ฐ: {time.time()}") | |
| # LoRA ์ด๋ํฐ๊ฐ ์ ์ฉ๋ ๋ชจ๋ธ์ธ์ง ํ์ธ | |
| if LORA_AVAILABLE and lora_manager and hasattr(lora_manager, 'current_adapter_name') and lora_manager.current_adapter_name: | |
| print(f"๐ [DEBUG] LoRA ์ด๋ํฐ ์ ์ฉ๋จ: {lora_manager.current_adapter_name}") | |
| # LoRA๊ฐ ์ ์ฉ๋ ๋ชจ๋ธ ์ฌ์ฉ | |
| lora_model = lora_manager.get_model() | |
| if lora_model: | |
| print(f"๐ [DEBUG] LoRA ๋ชจ๋ธ๋ก ์์ฑ ์คํ") | |
| # LoRA ๋ชจ๋ธ์ฉ ์ ๋ ฅ ์ฒ๋ฆฌ (token_type_ids ์ ๊ฑฐ) | |
| lora_inputs = { | |
| 'input_ids': input_ids, | |
| 'attention_mask': attention_mask | |
| } | |
| # token_type_ids๊ฐ ์๋ค๋ฉด ์ ๊ฑฐ | |
| # if 'token_type_ids' in locals() and token_type_ids is not None: | |
| # print(f"๐ [DEBUG] token_type_ids ์ ๊ฑฐ๋จ (LoRA ๋ชจ๋ธ ํธํ์ฑ)") | |
| generated_ids = lora_model.generate( | |
| **lora_inputs, | |
| **gen_config | |
| ) | |
| else: | |
| print(f"โ ๏ธ [DEBUG] LoRA ๋ชจ๋ธ์ ๊ฐ์ ธ์ฌ ์ ์์, ๊ธฐ๋ณธ ๋ชจ๋ธ ์ฌ์ฉ") | |
| generated_ids = model.generate( | |
| input_ids=input_ids, | |
| attention_mask=attention_mask, | |
| **gen_config | |
| ) | |
| else: | |
| print(f"๐ [DEBUG] LoRA ์ด๋ํฐ ์์, ๊ธฐ๋ณธ ๋ชจ๋ธ ์ฌ์ฉ") | |
| # LoRA ์ํ ๋๋ฒ๊น | |
| if LORA_AVAILABLE: | |
| if lora_manager: | |
| print(f"๐ [DEBUG] LoRA ๋งค๋์ ์กด์ฌ: {type(lora_manager)}") | |
| if hasattr(lora_manager, 'current_adapter_name'): | |
| print(f"๐ [DEBUG] ํ์ฌ ์ด๋ํฐ: {lora_manager.current_adapter_name}") | |
| if hasattr(lora_manager, 'base_model'): | |
| print(f"๐ [DEBUG] ๊ธฐ๋ณธ ๋ชจ๋ธ ๋ก๋๋จ: {lora_manager.base_model is not None}") | |
| else: | |
| print(f"๐ [DEBUG] LoRA ๋งค๋์ ๊ฐ None") | |
| else: | |
| print(f"๐ [DEBUG] LoRA ์ง์ ์๋จ") | |
| generated_ids = model.generate( | |
| input_ids=input_ids, | |
| attention_mask=attention_mask, | |
| **gen_config | |
| ) | |
| # ํ ํฐ ์ค์ ์ ๋ช ์์ ์ผ๋ก ์ ๋ฌํ์ฌ EOS ํ ํฐ ๋ฌธ์ ํด๊ฒฐ | |
| # generate_kwargs = { | |
| # 'input_ids': input_ids.to(model.device), | |
| # 'attention_mask': attention_mask.to(model.device), | |
| # 'max_new_tokens': gen_config['max_new_tokens'], | |
| # 'temperature': gen_config['temperature'], | |
| # 'top_p': gen_config['top_p'], | |
| # 'do_sample': gen_config['do_sample'], | |
| # 'repetition_penalty': gen_config.get('repetition_penalty', 1.0), | |
| # 'no_repeat_ngram_size': gen_config.get('no_repeat_ngram_size', 0), | |
| # # 'num_beams': gen_config.get('num_beams', 1), | |
| # 'use_cache': gen_config.get('use_cache', True), | |
| # 'max_time': gen_config.get('max_time', None), | |
| # 'early_stopping': gen_config.get('early_stopping', False), | |
| # 'stopping_criteria': gen_config.get('stopping_criteria', None), | |
| # } | |
| # | |
| # # ํ ํฐ ID ์ค์ (์ค์!) | |
| # if gen_config.get('eos_token_id') is not None: | |
| # generate_kwargs['eos_token_id'] = gen_config['eos_token_id'] | |
| # if gen_config.get('pad_token_id') is not None: | |
| # generate_kwargs['pad_token_id'] = gen_config['pad_token_id'] | |
| # if gen_config.get('bos_token_id') is not None: | |
| # generate_kwargs['bos_token_id'] = gen_config['bos_token_id'] | |
| # print(f"๐ [DEBUG] ์ต์ข ์์ฑ ์ค์ : {generate_kwargs}") | |
| # generated_ids = model.generate(**generate_kwargs) | |
| print(f"๐ [DEBUG] ๋ชจ๋ธ ์์ฑ ์๋ฃ ์๊ฐ: {time.time()}") | |
| t_gen_end = time.time() | |
| print(f"๐ [DEBUG] ๋ชจ๋ธ ์ถ๋ก ์๋ฃ - ์์์๊ฐ: {t_gen_end - t_gen_start:.3f}์ด") | |
| print(f"๐ [DEBUG] ์์ฑ๋ ํ ํฐ ์: {generated_ids.shape[1] - input_ids.shape[1]}") | |
| print(f"๐ [DEBUG] ์ต์ข generated_ids shape: {generated_ids.shape}") | |
| print(f"๐ [DEBUG] ์ต์ข generated_ids ๋๋ฐ์ด์ค: {generated_ids.device}") | |
| print(f"๐ [DEBUG] ์ต์ข generated_ids dtype: {generated_ids.dtype}") | |
| except Exception as e: | |
| print(f"โ [DEBUG] ๋ชจ๋ธ ์ถ๋ก ์ค ์๋ฌ ๋ฐ์: {str(e)}") | |
| print(f"โ [DEBUG] ์๋ฌ ํ์ : {type(e).__name__}") | |
| print(f"โ [DEBUG] ์๋ฌ ์์ธ: {str(e)}") | |
| import traceback | |
| traceback.print_exc() | |
| return {"error": f"Generation failed: {str(e)}"} | |
| # --- 6. ์๋ต ์ถ์ถ --- | |
| print(f"๐ [DEBUG] ์๋ต ์ถ์ถ ์์") | |
| t_decode_start = time.time() | |
| try: | |
| # ์์ฑ๋ ํ ์คํธ ๋์ฝ๋ฉ | |
| full_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True) | |
| print(f"๐ [DEBUG] ์ ์ฒด ํ ์คํธ ๊ธธ์ด: {len(full_text)}") | |
| print(f"๐ [DEBUG] ์ ์ฒด ์์ฑ ํ ์คํธ (Raw): \n---\n{full_text}\n---") | |
| print(f"๐ [DEBUG] ์ฌ์ฉ๋ ํ๋กฌํํธ: {formatted_prompt}") | |
| # ํ๋กํ๋ณ ์๋ต ์ถ์ถ (์์ ํ ๋ฐฉ์) | |
| if hasattr(current_profile, 'extract_response'): | |
| try: | |
| response = current_profile.extract_response(full_text, formatted_prompt) | |
| print(f"๐ [DEBUG] ํ๋กํ extract_response ์ฌ์ฉ ์ฑ๊ณต") | |
| except Exception as extract_error: | |
| print(f"โ ๏ธ [DEBUG] ํ๋กํ extract_response ์คํจ: {extract_error}") | |
| # ํด๋ฐฑ: ๊ธฐ๋ณธ ์๋ต ์ถ์ถ | |
| response = full_text.replace(formatted_prompt, "").strip() if formatted_prompt else full_text | |
| print(f"๐ [DEBUG] ๊ธฐ๋ณธ ์๋ต ์ถ์ถ ์ฌ์ฉ (ํด๋ฐฑ)") | |
| else: | |
| # ๊ธฐ๋ณธ ์๋ต ์ถ์ถ | |
| response = full_text.replace(formatted_prompt, "").strip() if formatted_prompt else full_text | |
| print(f"๐ [DEBUG] ๊ธฐ๋ณธ ์๋ต ์ถ์ถ ์ฌ์ฉ") | |
| print(f"๐ [DEBUG] ์ถ์ถ๋ ์๋ต ๊ธธ์ด: {len(response)}") | |
| print(f"๐ [DEBUG] ์ต์ข ์๋ต: {response}") | |
| t_decode_end = time.time() | |
| print(f"๐ [DEBUG] ์๋ต ์ถ์ถ ์๋ฃ - ์์์๊ฐ: {t_decode_end - t_decode_start:.3f}์ด") | |
| except Exception as e: | |
| print(f"โ [DEBUG] ์๋ต ์ถ์ถ ์ค ์๋ฌ ๋ฐ์: {str(e)}") | |
| import traceback | |
| traceback.print_exc() | |
| return {"error": f"Response extraction failed: {str(e)}"} | |
| # --- 7. ๊ฒฐ๊ณผ ๋ฐํ --- | |
| total_time = time.time() - t_tok_start | |
| print(f"๐ [DEBUG] ์ ์ฒด ์ฒ๋ฆฌ ์๋ฃ - ์ด ์์์๊ฐ: {total_time:.3f}์ด") | |
| # ๐ ์ด๋ฏธ์ง ์ฒ๋ฆฌ ์๋ฃ (์ ์ญ ๋ณ์ ์ด๊ธฐํ๋ ์ ๊ฑฐ๋จ) | |
| return { | |
| "generated_text": response, | |
| "processing_time": total_time, | |
| "model_name": current_profile.display_name, | |
| "image_processed": image_processed, | |
| "tokens_generated": generated_ids.shape[1] - input_ids.shape[1], | |
| "total_tokens": generated_ids.shape[1] | |
| } | |
| except Exception as e: | |
| print(f"โ [DEBUG] generate_sync ์ ์ฒด ์๋ฌ: {str(e)}") | |
| import traceback | |
| traceback.print_exc() | |
| return {"error": str(e)} | |
async def get_lora_status():
    """Report whether LoRA support is usable and which adapter is active.

    Returns:
        ``{"status": "success", ...}`` with the active adapter name (or None),
        whether the LoRA manager holds a non-None ``base_model``, and its
        ``device`` (``'unknown'`` if the manager has no such attribute).
        ``{"status": "error", "message": ...}`` when LoRA is disabled, no
        manager exists, or reading the manager raises.
    """
    try:
        lora_ready = LORA_AVAILABLE and lora_manager is not None
        if not lora_ready:
            return {"status": "error", "message": "LoRA ๊ธฐ๋ฅ์ด ์ฌ์ฉ ๋ถ๊ฐ๋ฅํฉ๋๋ค"}
        # Every attribute is optional on the manager, so fall back to defaults.
        adapter_name = getattr(lora_manager, 'current_adapter_name', None)
        base = getattr(lora_manager, 'base_model', None)
        return {
            "status": "success",
            "lora_available": True,
            "current_adapter": adapter_name,
            "base_model_loaded": base is not None,
            "device": getattr(lora_manager, 'device', 'unknown'),
        }
    except Exception as exc:
        return {"status": "error", "message": str(exc)}
async def get_context_status():
    """Summarise the context manager: per-session turn counts and limits.

    Returns:
        ``{"status": "success", ...}`` with, for every session, the total
        number of turns and how many have role ``"user"`` / ``"assistant"``,
        plus the manager's ``max_tokens``, ``max_turns`` and ``strategy``.
        ``{"status": "error", "message": ...}`` when there is no context
        manager or reading it raises.
    """
    try:
        if not context_manager:
            return {"status": "error", "message": "Context manager not available"}

        def _count_role(turns, role):
            # Number of turns in this conversation spoken by ``role``.
            return sum(1 for turn in turns if turn.role == role)

        conversations = context_manager.session_conversations
        per_session = {
            sid: {
                "turns": len(turns),
                "user_messages": _count_role(turns, "user"),
                "assistant_messages": _count_role(turns, "assistant"),
            }
            for sid, turns in conversations.items()
        }
        return {
            "status": "success",
            "context_manager_available": True,
            "total_sessions": len(conversations),
            "sessions": per_session,
            "max_tokens": context_manager.max_tokens,
            "max_turns": context_manager.max_turns,
            "strategy": context_manager.strategy,
        }
    except Exception as exc:
        return {"status": "error", "message": str(exc)}
async def get_context_history(session_id: str = None):
    """Return conversation context for one session, or for all sessions.

    Args:
        session_id: When truthy, only that session's context (requested with
            ``max_length=4000``) and its summary are returned. When empty or
            None, the combined context is returned together with the length
            of ``context_manager.conversation_history``.

    Returns:
        A ``status: "success"`` payload, or ``status: "error"`` with a message
        if the context manager is missing or one of its calls raises.
    """
    try:
        if not context_manager:
            return {"status": "error", "message": "Context manager not available"}
        if not session_id:
            # No session requested: report on the combined history.
            combined = context_manager.get_context(include_system=True, max_length=4000)
            return {
                "status": "success",
                "context": combined,
                "history_length": len(context_manager.conversation_history),
                "all_sessions": True,
            }
        scoped = context_manager.get_context(include_system=True, max_length=4000, session_id=session_id)
        summary = context_manager.get_context_summary(session_id)
        return {
            "status": "success",
            "session_id": session_id,
            "context": scoped,
            "history_length": summary.get("total_turns", 0),
            "session_summary": summary,
        }
    except Exception as exc:
        return {"status": "error", "message": str(exc)}
async def get_auto_cleanup_config():
    """Return the context manager's automatic-cleanup configuration.

    Returns:
        ``{"status": "success", "auto_cleanup_config": ...}`` with whatever
        ``context_manager.get_auto_cleanup_config()`` returns, or
        ``{"status": "error", "message": ...}`` when there is no context
        manager or the call raises.
    """
    try:
        if context_manager:
            current = context_manager.get_auto_cleanup_config()
            return {"status": "success", "auto_cleanup_config": current}
        return {"status": "error", "message": "Context manager not available"}
    except Exception as exc:
        return {"status": "error", "message": str(exc)}
async def set_auto_cleanup_config(
    enabled: bool = Form(True),
    interval_turns: int = Form(8),
    interval_time: int = Form(300),
    strategy: str = Form("smart")
):
    """Update the context manager's automatic-cleanup settings from form fields.

    Args:
        enabled: Whether automatic cleanup is on.
        interval_turns: Forwarded as ``interval_turns``.
        interval_time: Forwarded as ``interval_time``.
        strategy: Forwarded as ``strategy``.

    Returns:
        ``{"status": "success", ...}`` including the configuration read back
        after the update, or ``{"status": "error", "message": ...}`` when
        there is no context manager or either call raises.
    """
    try:
        if not context_manager:
            return {"status": "error", "message": "Context manager not available"}
        new_settings = dict(
            enabled=enabled,
            interval_turns=interval_turns,
            interval_time=interval_time,
            strategy=strategy,
        )
        context_manager.set_auto_cleanup_config(**new_settings)
        return {
            "status": "success",
            "message": "์๋ ์ ๋ฆฌ ์ค์ ์ด ์ ๋ฐ์ดํธ๋์์ต๋๋ค",
            "new_config": context_manager.get_auto_cleanup_config(),
        }
    except Exception as exc:
        return {"status": "error", "message": str(exc)}
async def manual_cleanup_session(session_id: str):
    """Run the context manager's cleanup routine for a single session.

    Args:
        session_id: Session passed to ``context_manager._execute_auto_cleanup``.

    Returns:
        ``{"status": "success", "message": ..., "session_id": session_id}``, or
        ``{"status": "error", "message": ...}`` when there is no context
        manager or the cleanup raises.
    """
    try:
        if context_manager:
            context_manager._execute_auto_cleanup(session_id)
            return {
                "status": "success",
                "message": f"์ธ์ {session_id} ์๋ ์ ๋ฆฌ ์๋ฃ",
                "session_id": session_id,
            }
        return {"status": "error", "message": "Context manager not available"}
    except Exception as exc:
        return {"status": "error", "message": str(exc)}
async def manual_cleanup_all_sessions():
    """Run the context manager's cleanup routine for every known session.

    The session ids are snapshotted before the loop. ``_execute_auto_cleanup``
    is defined on the context manager, not here. If it adds or removes
    entries in ``session_conversations``, iterating the live dict view would
    raise ``RuntimeError: dictionary changed size during iteration`` part-way
    through and leave the remaining sessions uncleaned.

    Returns:
        ``{"status": "success", "message": ...}`` once every session has been
        processed, or ``{"status": "error", "message": ...}`` when there is no
        context manager or a cleanup call raises.
    """
    try:
        if not context_manager:
            return {"status": "error", "message": "Context manager not available"}
        # Snapshot the keys so the cleanup routine may safely mutate the mapping.
        for session_id in list(context_manager.session_conversations):
            context_manager._execute_auto_cleanup(session_id)
        return {
            "status": "success",
            "message": "๋ชจ๋ ์ธ์ ์๋ ์ ๋ฆฌ ์๋ฃ"
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}
async def generate(request: Request,
                   prompt: str = Form(...),
                   image1: UploadFile = File(None),
                   image2: UploadFile = File(None),
                   image3: UploadFile = File(None),
                   image4: UploadFile = File(None),
                   user_id: str = Form("anonymous"),
                   room_id: str = Form("default"),
                   use_context: bool = Form(True),
                   session_id: str = Form(None)):
    """Generate a reply to ``prompt`` with up to four optional images.

    When ``session_id`` is empty, one is derived from the room, the user and
    the current Unix time. With ``use_context``, the prompt (before
    generation) and the reply (after success) are recorded in
    ``context_manager`` under that session. Images that fail to read are
    logged and skipped.

    Raises:
        HTTPException: 503 if no model is loaded; 500 if ``generate_sync``
            reports an error or anything in the generation step raises.
    """
    if not model_loaded:
        raise HTTPException(status_code=503, detail="๋ชจ๋ธ์ด ๋ก๋๋์ง ์์์ต๋๋ค.")
    # Derive a per-room/per-user session id when the client did not send one.
    if not session_id:
        timestamp = int(time.time())
        session_id = f"room_{room_id}_user_{user_id}_{timestamp}"
        print(f"๐ [DEBUG] ์๋ ์ธ์ ID ์์ฑ: {session_id} (์ฑํ ๋ฐฉ: {room_id}, ์ฌ์ฉ์: {user_id})")
    if use_context:
        context_manager.add_user_message(prompt, metadata={"session_id": session_id})
        print(f"๐ [DEBUG] ์ฌ์ฉ์ ๋ฉ์์ง ์ถ๊ฐ๋จ (์ธ์ : {session_id})")
    # Read the raw bytes of every uploaded image; unreadable uploads are skipped.
    image_data_list = []
    for img_file in [image1, image2, image3, image4]:
        if img_file:
            try:
                data = await img_file.read()
                image_data_list.append(data)
            except Exception as e:
                logger.warning(f"์ด๋ฏธ์ง ๋ก๋ ์คํจ: {e}")
    try:
        # NOTE(review): generate_sync is a plain synchronous call, so it blocks
        # the event loop for the whole generation.
        result = generate_sync(prompt, image_data_list, use_context=use_context, session_id=session_id, user_id=user_id, room_id=room_id)
        if "error" in result:
            raise HTTPException(status_code=500, detail=result["error"])
        if use_context:
            context_manager.add_assistant_message(result["generated_text"], metadata={"session_id": session_id})
        return GenerateResponse(
            generated_text=result["generated_text"],
            processing_time=result["processing_time"],
            model_name=result["model_name"],
            image_processed=result["image_processed"]
        )
    except HTTPException:
        # Already a well-formed HTTP error (e.g. the generate_sync error
        # above): propagate it unchanged instead of re-wrapping it below.
        raise
    except Exception as e:
        logger.error(f"โ ์์ฑ ์ค ์ค๋ฅ ๋ฐ์: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"๋ชจ๋ธ ์์ฑ ์ค ์ค๋ฅ ๋ฐ์: {str(e)}")
async def generate_multimodal(prompt: str = Form(...),
                              image: UploadFile = File(None),
                              model_id: Optional[str] = Form(None),
                              max_length: Optional[int] = Form(None),
                              temperature: Optional[float] = Form(None),
                              top_p: Optional[float] = Form(None),
                              do_sample: Optional[bool] = Form(None)):
    """Generate a reply to ``prompt`` (plus an optional image) by greedy decoding.

    The conversation is one ``<image>`` placeholder turn (only when an image
    was uploaded and decoded) followed by the prompt. It is encoded with
    ``processor.batch_encode_collate`` and decoded token by token with a
    manual greedy loop that stops at ``<|eot_id|>`` or after the token budget.

    Args:
        prompt: User text.
        image: Optional image upload. If it cannot be decoded, the error is
            logged and generation continues text-only.
        model_id: Echoed back in the response; it does not switch models.
        max_length: Maximum number of new tokens to generate (64 if omitted).
        temperature, top_p, do_sample: Accepted for API compatibility but
            ignored: decoding is always greedy.

    Returns:
        MultimodalGenerateResponse whose ``generated_text`` contains only the
        newly generated tokens (the prompt is not echoed back), and whose
        ``image_processed`` is True only if the image was actually encoded.

    Raises:
        HTTPException: 500 if no model is loaded or encoding/generation fails.
    """
    global model_loaded, current_profile, model, tokenizer, processor
    if not model_loaded:
        raise HTTPException(status_code=500, detail="๋ชจ๋ธ์ด ๋ก๋๋์ง ์์์ต๋๋ค")
    start_time = time.time()
    pil_image = None
    if image:
        try:
            data = await image.read()
            pil_image = Image.open(io.BytesIO(data)).convert("RGB")
        except Exception as e:
            logger.error(f"์ด๋ฏธ์ง ์ฒ๋ฆฌ ์คํจ: {e}")
    try:
        # No function-local ``import torch`` anywhere in this function: such an
        # import makes ``torch`` local to the whole function, so the
        # ``torch.Tensor`` checks below would raise UnboundLocalError. The
        # module-level import is used instead.
        image_list = [pil_image] if pil_image is not None else []
        conv = []
        if image_list:
            # One "<image>" placeholder per image, as a user turn before the text.
            conv.append({"role": "user", "content": " ".join(["<image>"] * len(image_list))})
        conv.append({"role": "user", "content": prompt})
        logger.info("=== STEP 1: building sample ===")
        sample = {"image": image_list, "conv": conv}
        logger.info("=== STEP 2: calling processor ===")
        inputs = processor.batch_encode_collate([sample], padding_side='left', add_generation_prompt=True)
        logger.info("=== STEP 3: processor returned ===")
        for k, v in inputs.items():
            if isinstance(v, torch.Tensor):
                logger.info(f"Key {k}: tensor shape {v.shape}, dtype {v.dtype}, device {v.device}")
            else:
                logger.info(f"Key {k}: {type(v)}")
        logger.info("=== STEP 4: moving to device ===")
        inputs = {k: (v.to(model.device) if isinstance(v, torch.Tensor) else v) for k, v in inputs.items()}
        pixel_values = inputs.get("pixel_values")
        model_dtype = getattr(model, "dtype", None)
        if isinstance(pixel_values, torch.Tensor) and model_dtype is not None and pixel_values.is_floating_point():
            # Match the model's dtype, as generate_sync does for image tensors.
            inputs["pixel_values"] = pixel_values.to(dtype=model_dtype)
        logger.info("=== STEP 5: moved to device ===")
        eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
        max_new_tokens = 64 if max_length is None else max_length
        prompt_len = inputs["input_ids"].shape[1]
        # Manual greedy decoding loop (no KV cache: the full sequence is re-run each step).
        generated = inputs["input_ids"].clone()
        with torch.no_grad():
            for _ in range(max_new_tokens):
                out = model(**inputs)
                next_token = out.logits[:, -1, :].argmax(dim=-1, keepdim=True)
                generated = torch.cat([generated, next_token], dim=-1)
                token_id = next_token.item()
                logger.info(f"Step token: {token_id}")
                if token_id == eot_id:
                    break
                inputs["input_ids"] = generated
                mask = inputs.get("attention_mask")
                if isinstance(mask, torch.Tensor):
                    # Keep the mask as long as input_ids; without this it is one
                    # position short per generated token.
                    inputs["attention_mask"] = torch.cat([mask, mask.new_ones((mask.shape[0], 1))], dim=-1)
        logger.info(f"Final Generated IDs: {generated[0].tolist()}")
        # Decode only the newly generated tokens; decoding generated[0] whole
        # would echo the prompt back into the response.
        generated_text = tokenizer.decode(generated[0, prompt_len:], skip_special_tokens=True)
        if "<|im_start|>assistant" in generated_text:
            response = generated_text.split("<|im_start|>assistant")[-1].split("<|im_end|>")[0].strip()
        else:
            response = generated_text.strip()
        processing_time = time.time() - start_time
        return MultimodalGenerateResponse(generated_text=response,
                                          processing_time=processing_time,
                                          model_name=current_profile.display_name,
                                          model_id=model_id or current_profile.get_model_info().get("model_name"),
                                          image_processed=bool(image_list))
    except Exception as e:
        logger.error(f"โ ๋ฉํฐ๋ชจ๋ฌ ์์ฑ ์ค๋ฅ: {e}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"๋ฉํฐ๋ชจ๋ฌ ์์ฑ ์คํจ: {str(e)}")
async def list_models():
    """List every registered model and describe the active one.

    Returns:
        ``{"models": ..., "current_model": ...}`` where ``current_model`` is
        the active profile's ``get_model_info()`` or None when no profile is
        loaded.
    """
    catalog = list_available_models()
    active_info = None
    if current_profile:
        active_info = current_profile.get_model_info()
    return {"models": catalog, "current_model": active_info}
async def switch_model(model_id: str):
    """Load ``model_id`` and report the display name of the now-active model.

    Raises:
        HTTPException: 500 wrapping any exception raised while loading the
            model or building the reply.
    """
    try:
        await load_model_async(model_id)
        outcome = {
            "message": f"๋ชจ๋ธ ๋ณ๊ฒฝ ์ฑ๊ณต: {model_id}",
            "current_model": current_profile.display_name,
        }
        return outcome
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"๋ชจ๋ธ ๋ณ๊ฒฝ ์คํจ: {str(exc)}")
async def root():
    """Describe the API: banner message, version, active model and docs path.

    ``current_model`` is the active profile's display name, or the string
    ``"None"`` when no profile is loaded.
    """
    if current_profile:
        active = current_profile.display_name
    else:
        active = "None"
    info = {
        "message": "Lily LLM API v2 ์๋ฒ",
        "version": "2.0.0",
        "current_model": active,
        "docs": "/docs",
    }
    return info
async def health_check():
    """Liveness probe reporting model-load state and the model catalogue."""
    catalog = list_available_models()
    if current_profile:
        active_model = current_profile.display_name
    else:
        active_model = "None"
    return HealthResponse(
        status="healthy",
        model_loaded=model_loaded,
        current_model=active_model,
        available_models=catalog,
    )
async def upload_document(
    file: UploadFile = File(...),
    user_id: str = Form("default_user"),  # uploading user's ID
    room_id: str = Form("default"),  # chat room the document belongs to
    document_id: Optional[str] = Form(None)  # document ID (generated when omitted)
):
    """Upload a document, index it for RAG, and register it with a chat room.

    The upload is written to a temporary file in the current working directory,
    passed to ``rag_processor.process_and_store_document`` and then deleted --
    always, including when processing raises.  When processing succeeds the
    document metadata is also recorded through ``integrated_memory_manager``; a
    failure in that bookkeeping step is only logged and does not fail the upload.

    No AI summary is generated here (to save model resources): ``auto_response``
    carries a fixed acknowledgement on success and a fixed failure notice
    otherwise.

    Returns:
        DocumentUploadResponse.  Errors are reported with ``success=False``
        instead of being raised.
    """
    start_time = time.time()
    temp_file_path = None
    try:
        # Generate a short document ID when the client did not supply one.
        if not document_id:
            import uuid
            document_id = str(uuid.uuid4())[:8]

        # Both the file name and the document ID are client-controlled; strip
        # directory components so the temp file cannot be written outside the
        # working directory and "a/b.pdf"-style names do not break open().
        # The base name keeps the extension, which the processor may rely on.
        safe_name = os.path.basename(file.filename or "upload")
        safe_doc_id = os.path.basename(document_id)
        temp_file_path = f"./temp_{safe_doc_id}_{safe_name}"
        with open(temp_file_path, "wb") as f:
            content = await file.read()
            f.write(content)

        # Parse the document and store its chunks in the vector store.
        result = rag_processor.process_and_store_document(
            user_id, document_id, temp_file_path
        )

        processing_time = time.time() - start_time
        logger.info(f"๐ ๋ฌธ์ ์ ๋ก๋ ์๋ฃ ({processing_time:.2f}์ด): {file.filename}")

        if result["success"]:
            try:
                # Record the document in the room context of the memory system.
                chunks = result.get("chunks", [])
                chunk_count = len(chunks) if isinstance(chunks, list) else 0
                filename = file.filename or ""
                document_info = {
                    "document_id": document_id,
                    "filename": file.filename,
                    "uploaded_by": user_id,
                    "document_type": filename.split('.')[-1].lower() if '.' in filename else "unknown",
                    "page_count": result.get("page_count", 0),
                    "chunk_count": chunk_count,
                    "summary": result.get("message", "")
                }
                integrated_memory_manager.add_document_to_room(room_id, document_info)
                # Update the user's conversation statistics.
                integrated_memory_manager.record_conversation(
                    user_id, room_id,
                    topic=f"๋ฌธ์ ์ ๋ก๋: {file.filename}"
                )
                logger.info(f"โ ๋ฉ๋ชจ๋ฆฌ ์์คํ ์ ๋ฌธ์ ์ ๋ณด ์ถ๊ฐ ์๋ฃ: {room_id} - {file.filename}")
            except Exception as e:
                logger.warning(f"โ ๏ธ ๋ฉ๋ชจ๋ฆฌ ์์คํ ์ ๋ฐ์ดํธ ์คํจ: {e}")

            # Automatic AI summaries are disabled to save model resources; the
            # model only runs when the user actually asks a question.
            result["auto_response"] = f"๋ฌธ์ '{file.filename}' ์ ๋ก๋ ์๋ฃ! ์ด์ ์ง๋ฌธํด์ฃผ์ธ์."
            logger.info(f"๐ ์๋ AI ์๋ต ์์ฑ ๊ฑด๋๋ฐ๊ธฐ - AI ๋ฆฌ์์ค ์ ์ฝ (์ฌ์ฉ์ ์ง๋ฌธ ์์๋ง AI ์๋ต)")
        else:
            result["auto_response"] = "๋ฌธ์ ์ ๋ก๋์ ์คํจํ์ต๋๋ค."

        return DocumentUploadResponse(
            success=result["success"],
            document_id=document_id,
            message=result.get("message", ""),
            chunks=result.get("chunks"),
            latex_count=result.get("latex_count"),
            error=result.get("error"),
            auto_response=result.get("auto_response", "")
        )
    except Exception as e:
        logger.error(f"โ ๋ฌธ์ ์ ๋ก๋ ์คํจ: {e}")
        return DocumentUploadResponse(
            success=False,
            document_id=document_id or "unknown",
            message="๋ฌธ์ ์ ๋ก๋ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค.",
            error=str(e)
        )
    finally:
        # Always remove the temporary upload, even when processing raised.
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
            except OSError as cleanup_error:
                logger.warning(f"Failed to remove temp file {temp_file_path}: {cleanup_error}")
async def summarize_conversation(
    room_id: str = Form("default"),
    user_id: str = Form("anonymous"),
    max_length: int = Form(300)
):
    """Create a smart summary of a room's conversation using the summarizers library.

    ``user_id`` is accepted for API compatibility but is not used.

    Returns:
        A dict with ``success`` and ``message``.  On success it also includes
        ``summary``, ``key_topics`` (both empty when the room context is missing)
        and ``room_id``.
    """
    try:
        if not text_summarizer.is_available():
            return {
                "success": False,
                "message": "summarizers ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค."
            }

        created = integrated_memory_manager.create_smart_conversation_summary(
            room_id, max_length
        )
        if not created:
            return {"success": False, "message": "๋ํ ์์ฝ ์์ฑ ์คํจ"}

        # Re-read the room context to pick up the freshly written summary.
        ctx = integrated_memory_manager.room_context_manager.get_room_context(room_id)
        summary_text = ctx.conversation_summary if ctx else ""
        topics = ctx.key_topics if ctx else []
        return {
            "success": True,
            "message": "๋ํ ์์ฝ ์์ฑ ์๋ฃ",
            "summary": summary_text,
            "key_topics": topics,
            "room_id": room_id,
        }
    except Exception as exc:
        logger.error(f"โ ๋ํ ์์ฝ ์์ฑ ์คํจ: {exc}")
        return {
            "success": False,
            "message": f"๋ํ ์์ฝ ์์ฑ ์ค ์ค๋ฅ ๋ฐ์: {str(exc)}"
        }
async def summarize_text(
    text: str = Form(...),
    max_length: int = Form(200),
    model_name: str = Form("kobart")
):
    """Summarize arbitrary text with the summarizers library.

    Texts shorter than 50 characters (after stripping) are rejected.  Sampling
    is disabled (``do_sample=False``) and ``min_length`` is half of
    ``max_length``.

    Returns:
        A dict with ``success`` and ``message``.  On success it also includes
        the original and summary lengths, their ratio (rounded to 2 decimals),
        the summary and the model name used.
    """
    try:
        if not text_summarizer.is_available():
            return {
                "success": False,
                "message": "summarizers ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค."
            }
        if not text or len(text.strip()) < 50:
            return {
                "success": False,
                "message": "์์ฝํ ํ ์คํธ๊ฐ ๋๋ฌด ์งง์ต๋๋ค (์ต์ 50์ ํ์)"
            }

        cfg = SummaryConfig(
            max_length=max_length,
            min_length=max_length // 2,
            do_sample=False,
            temperature=0.7,
            top_p=0.9
        )
        summary = text_summarizer.summarize_text(text, model_name, cfg)
        if not summary:
            return {"success": False, "message": "์์ฝ ์์ฑ ์คํจ"}

        source_len = len(text)
        summary_len = len(summary)
        return {
            "success": True,
            "message": "ํ ์คํธ ์์ฝ ์๋ฃ",
            "original_length": source_len,
            "summary_length": summary_len,
            "compression_ratio": round(summary_len / source_len, 2),
            "summary": summary,
            "model_used": model_name,
        }
    except Exception as exc:
        logger.error(f"โ ํ ์คํธ ์์ฝ ์คํจ: {exc}")
        return {
            "success": False,
            "message": f"ํ ์คํธ ์์ฝ ์ค ์ค๋ฅ ๋ฐ์: {str(exc)}"
        }
async def compress_context(
    room_id: str = Form("default"),
    target_length: int = Form(800)
):
    """Compress a chat room's stored context towards ``target_length``.

    Returns:
        A dict with ``success`` and ``message``; on success it also echoes
        ``room_id`` and ``target_length``.
    """
    try:
        if not text_summarizer.is_available():
            return {
                "success": False,
                "message": "summarizers ๋ผ์ด๋ธ๋ฌ๋ฆฌ๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค."
            }

        compressed = integrated_memory_manager.compress_room_context(room_id, target_length)
        if not compressed:
            return {"success": False, "message": "์ปจํ ์คํธ ์์ถ ์คํจ"}

        return {
            "success": True,
            "message": "์ปจํ ์คํธ ์์ถ ์๋ฃ",
            "room_id": room_id,
            "target_length": target_length,
        }
    except Exception as exc:
        logger.error(f"โ ์ปจํ ์คํธ ์์ถ ์คํจ: {exc}")
        return {
            "success": False,
            "message": f"์ปจํ ์คํธ ์์ถ ์ค ์ค๋ฅ ๋ฐ์: {str(exc)}"
        }
async def get_summarizer_status():
    """Report whether the summarizers library is usable and which models it offers."""
    try:
        ready = text_summarizer.is_available()
        if ready:
            model_list = text_summarizer.get_available_models()
            default_model = "hyunwoongko/kobart"
        else:
            model_list, default_model = [], None
        return {
            "success": True,
            "summarizers_available": ready,
            "available_models": model_list,
            "default_model": default_model,
        }
    except Exception as exc:
        logger.error(f"โ summarizer ์ํ ํ์ธ ์คํจ: {exc}")
        return {
            "success": False,
            "message": f"์ํ ํ์ธ ์ค ์ค๋ฅ ๋ฐ์: {str(exc)}"
        }
async def generate_rag_response(
    query: str = Form(...),
    user_id: str = Form("default_user"),
    document_id: str = Form(...),
    max_length: Optional[int] = Form(None),
    temperature: Optional[float] = Form(None),
    top_p: Optional[float] = Form(None),
    do_sample: Optional[bool] = Form(None)
):
    """Answer ``query`` from one of the user's indexed documents (RAG).

    The global ``model`` is passed to the RAG processor only when it exposes a
    ``generate_text`` method.  Otherwise ``llm_model=None`` is passed, and the
    log describes this as a text-only answer.

    ``max_length``, ``temperature``, ``top_p`` and ``do_sample`` are accepted for
    API compatibility but are currently not forwarded to the processor.

    Returns:
        RAGResponse.  Fields missing from the processor's result fall back to
        empty defaults.  Unexpected errors produce ``success=False`` with the
        elapsed time, matching the other RAG endpoints.
    """
    start_time = time.time()
    try:
        llm_model = None
        if model is not None and hasattr(model, 'generate_text'):
            llm_model = model
            logger.info("โ ๋ก๋๋ ๋ชจ๋ธ์ ์ฌ์ฉํ์ฌ RAG ์๋ต ์์ฑ")
        else:
            logger.warning("โ ๏ธ ๋ชจ๋ธ์ด ๋ก๋๋์ง ์์ ํ ์คํธ ๊ธฐ๋ฐ ์๋ต๋ง ์์ฑ")

        result = rag_processor.generate_rag_response(
            user_id, document_id, query, llm_model=llm_model
        )

        processing_time = time.time() - start_time
        logger.info(f"๐ RAG ์๋ต ์์ฑ ์๋ฃ ({processing_time:.2f}์ด)")

        # A failure result from the processor may omit some keys.  Tolerate
        # that instead of raising KeyError, which would hide the real response.
        return RAGResponse(
            success=result.get("success", False),
            response=result.get("response", ""),
            context=result.get("context", ""),
            sources=result.get("sources", []),
            search_results=result.get("search_results", 0),
            processing_time=processing_time
        )
    except Exception as e:
        logger.error(f"โ RAG ์๋ต ์์ฑ ์คํจ: {e}")
        return RAGResponse(
            success=False,
            response=f"RAG ์๋ต ์์ฑ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค: {str(e)}",
            context="",
            sources=[],
            search_results=0,
            processing_time=time.time() - start_time
        )
async def generate_hybrid_rag_response(
    query: str = Form(...),
    user_id: str = Form("default_user"),
    document_id: str = Form(...),
    image1: UploadFile = File(None),
    image2: UploadFile = File(None),
    image3: UploadFile = File(None),
    image4: UploadFile = File(None),
    image5: UploadFile = File(None),
    max_length: Optional[int] = Form(None),
    temperature: Optional[float] = Form(None),
    top_p: Optional[float] = Form(None),
    do_sample: Optional[bool] = Form(None)
):
    """Hybrid RAG answer combining a stored document with up to five images.

    Each uploaded image is written to its own temporary file (prefix
    ``hybrid_image_``).  The successfully written paths are passed to
    ``rag_processor.generate_rag_response`` together with the global ``model``.
    An image that fails to save is logged and skipped.  Every temporary file
    created is deleted afterwards, including when the RAG call raises.

    ``max_length``, ``temperature``, ``top_p`` and ``do_sample`` are accepted for
    API compatibility but are currently not forwarded.

    Returns:
        RAGResponse.  Unexpected errors produce ``success=False``.
    """
    import tempfile

    start_time = time.time()
    image_files = []  # images written successfully; handed to the RAG processor
    temp_paths = []  # every temp file created, so all of them get cleaned up
    try:
        uploaded_images = [image1, image2, image3, image4, image5]
        for i, img in enumerate(uploaded_images):
            if not img:
                continue
            try:
                with tempfile.NamedTemporaryFile(
                    suffix=f"_{i}.png",
                    delete=False,
                    prefix="hybrid_image_"
                ) as temp_file:
                    # Track the path before writing: with delete=False a failed
                    # write would otherwise leave an orphaned file behind.
                    temp_paths.append(temp_file.name)
                    image_data = await img.read()
                    temp_file.write(image_data)
                image_files.append(temp_file.name)
                logger.info(f"๐ธ ์ด๋ฏธ์ง ์ ๋ก๋: {img.filename} -> {temp_file.name}")
            except Exception as e:
                logger.error(f"โ ์ด๋ฏธ์ง ์ฒ๋ฆฌ ์คํจ: {e}")

        result = rag_processor.generate_rag_response(
            user_id, document_id, query,
            llm_model=model,  # the loaded model instance
            image_files=image_files if image_files else None
        )

        processing_time = time.time() - start_time
        logger.info(f"๐ ํ์ด๋ธ๋ฆฌ๋ RAG ์๋ต ์์ฑ ์๋ฃ ({processing_time:.2f}์ด)")

        # Tolerate failure results that omit some keys.
        return RAGResponse(
            success=result.get("success", False),
            response=result.get("response", ""),
            context=result.get("context", ""),
            sources=result.get("sources", []),
            search_results=result.get("search_results", 0),
            processing_time=processing_time
        )
    except Exception as e:
        logger.error(f"โ ํ์ด๋ธ๋ฆฌ๋ RAG ์๋ต ์์ฑ ์คํจ: {e}")
        return RAGResponse(
            success=False,
            response=f"์๋ต ์์ฑ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค: {str(e)}",
            context="",
            sources=[],
            search_results=0,
            processing_time=time.time() - start_time
        )
    finally:
        # Remove every temp image, whether or not the RAG call succeeded.
        for temp_path in temp_paths:
            try:
                if os.path.exists(temp_path):
                    os.remove(temp_path)
                    logger.info(f"๐๏ธ ์์ ์ด๋ฏธ์ง ํ์ผ ์ญ์ : {temp_path}")
            except Exception as e:
                logger.warning(f"โ ๏ธ ์์ ํ์ผ ์ญ์ ์คํจ: {e}")
async def list_user_documents(user_id: str):
    """List the documents held in the vector store for ``user_id``.

    On any failure (including the lazy import) an empty listing is returned
    with the error message attached.
    """
    try:
        from lily_llm_core.vector_store_manager import vector_store_manager as store
        return store.get_all_documents(user_id)
    except Exception as exc:
        logger.error(f"โ ๋ฌธ์ ๋ชฉ๋ก ์กฐํ ์คํจ: {exc}")
        return {"documents": [], "total_docs": 0, "error": str(exc)}
async def delete_document(user_id: str, document_id: str):
    """Delete one of the user's documents through the RAG processor.

    Returns the processor's result unchanged, or
    ``{"success": False, "error": ...}`` when the call raises.
    """
    try:
        outcome = rag_processor.delete_document(user_id, document_id)
    except Exception as exc:
        logger.error(f"โ ๋ฌธ์ ์ญ์ ์คํจ: {exc}")
        return {"success": False, "error": str(exc)}
    return outcome
# User management endpoints
async def create_user(
    user_id: str = Form(...),
    username: Optional[str] = Form(None),
    email: Optional[str] = Form(None)
):
    """Create a user record and echo back what the database stored.

    Returns:
        UserResponse.  ``success=False`` with an ``error`` when the insert is
        rejected or raises.
    """
    try:
        created = db_manager.add_user(user_id, username, email)
        if not created:
            return UserResponse(success=False, user_id=user_id, error="์ฌ์ฉ์ ์์ฑ ์คํจ")

        # Fall back to an empty record so every field defaults to None.
        stored = db_manager.get_user(user_id) or {}
        return UserResponse(
            success=True,
            user_id=user_id,
            username=stored.get('username'),
            email=stored.get('email'),
            created_at=stored.get('created_at')
        )
    except Exception as exc:
        logger.error(f"โ ์ฌ์ฉ์ ์์ฑ ์ค๋ฅ: {exc}")
        return UserResponse(success=False, user_id=user_id, error=str(exc))
async def get_user_info(user_id: str):
    """Look up a single user by ID.

    Returns:
        UserResponse.  ``success=False`` when the user does not exist or the
        lookup raises.
    """
    try:
        record = db_manager.get_user(user_id)
        if not record:
            return UserResponse(success=False, user_id=user_id, error="์ฌ์ฉ์๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค")
        return UserResponse(
            success=True,
            user_id=user_id,
            username=record.get('username'),
            email=record.get('email'),
            created_at=record.get('created_at')
        )
    except Exception as exc:
        logger.error(f"โ ์ฌ์ฉ์ ์กฐํ ์ค๋ฅ: {exc}")
        return UserResponse(success=False, user_id=user_id, error=str(exc))
# Session management endpoints
async def create_session(
    user_id: str = Form(...),
    session_name: Optional[str] = Form(None)
):
    """Open a new chat session for ``user_id``.

    Returns:
        SessionResponse.  ``session_id`` is ``""`` on failure.
    """
    try:
        new_session = db_manager.create_chat_session(user_id, session_name)
        if not new_session:
            return SessionResponse(success=False, session_id="", error="์ธ์ ์์ฑ ์คํจ")
        return SessionResponse(
            success=True,
            session_id=new_session,
            session_name=session_name
        )
    except Exception as exc:
        logger.error(f"โ ์ธ์ ์์ฑ ์ค๋ฅ: {exc}")
        return SessionResponse(success=False, session_id="", error=str(exc))
async def list_user_sessions(user_id: str):
    """Return all chat sessions owned by ``user_id`` together with their count."""
    try:
        found = db_manager.get_user_sessions(user_id)
        payload = {
            "success": True,
            "user_id": user_id,
            "sessions": found,
            "total_sessions": len(found),
        }
    except Exception as exc:
        logger.error(f"โ ์ธ์ ๋ชฉ๋ก ์กฐํ ์ค๋ฅ: {exc}")
        return {"success": False, "error": str(exc)}
    return payload
# Chat message endpoints
async def add_chat_message(
    session_id: str = Form(...),
    user_id: str = Form(...),
    message_type: str = Form(...),
    content: str = Form(...)
):
    """Append a message to a chat session.

    ``message_id`` in the response is always 0.  The real ID is assigned by the
    database and is not read back here.
    """
    def _failure(reason):
        # Empty response shape shared by the "rejected" and "raised" paths.
        return ChatMessageResponse(
            success=False,
            message_id=0,
            content="",
            message_type="",
            timestamp="",
            error=reason
        )

    try:
        stored = db_manager.add_chat_message(session_id, user_id, message_type, content)
        if not stored:
            return _failure("๋ฉ์์ง ์ถ๊ฐ ์คํจ")
        return ChatMessageResponse(
            success=True,
            message_id=0,
            content=content,
            message_type=message_type,
            timestamp=datetime.now().isoformat()
        )
    except Exception as exc:
        logger.error(f"โ ๋ฉ์์ง ์ถ๊ฐ ์ค๋ฅ: {exc}")
        return _failure(str(exc))
async def get_chat_history(session_id: str, limit: int = 50):
    """Return up to ``limit`` messages from a chat session."""
    try:
        history = db_manager.get_chat_history(session_id, limit)
        return {
            "success": True,
            "session_id": session_id,
            "messages": history,
            "total_messages": len(history),
        }
    except Exception as exc:
        logger.error(f"โ ์ฑํ ํ์คํ ๋ฆฌ ์กฐํ ์ค๋ฅ: {exc}")
        return {"success": False, "error": str(exc)}
# Document management endpoints (database-backed)
async def list_user_documents_db(user_id: str):
    """List the user's documents as recorded in the database."""
    try:
        docs = db_manager.get_user_documents(user_id)
        return {
            "success": True,
            "user_id": user_id,
            "documents": docs,
            "total_documents": len(docs),
        }
    except Exception as exc:
        logger.error(f"โ ๋ฌธ์ ๋ชฉ๋ก ์กฐํ ์ค๋ฅ: {exc}")
        return {"success": False, "error": str(exc)}
# Authentication endpoints
async def login(
    user_id: str = Form(...),
    password: str = Form(...)
):
    """Authenticate a user and issue an access/refresh token pair.

    SECURITY(review): the password is checked against the literal placeholder
    ``"dummy_hash"``, not against a hash stored for this user.  ``register``
    does not persist a hash either, so real credentials are never actually
    compared.  A per-user stored hash must be wired in before this can be
    relied on.
    """
    try:
        account = db_manager.get_user(user_id)
        if not account:
            return LoginResponse(success=False, error="์ฌ์ฉ์๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค")

        # Placeholder verification; see the SECURITY note in the docstring.
        password_ok = auth_manager.verify_password(password, "dummy_hash")
        if not password_ok:
            return LoginResponse(success=False, error="๋น๋ฐ๋ฒํธ๊ฐ ์ฌ๋ฐ๋ฅด์ง ์์ต๋๋ค")

        display_name = account.get('username')
        tokens = auth_manager.create_user_tokens(user_id, display_name)
        return LoginResponse(
            success=True,
            access_token=tokens['access_token'],
            refresh_token=tokens['refresh_token'],
            token_type=tokens['token_type'],
            user_id=user_id,
            username=display_name
        )
    except Exception as exc:
        logger.error(f"โ ๋ก๊ทธ์ธ ์ค๋ฅ: {exc}")
        return LoginResponse(success=False, error=str(exc))
async def refresh_token(refresh_token: str = Form(...)):
    """Exchange a refresh token for a new bearer access token."""
    try:
        fresh_token = auth_manager.refresh_access_token(refresh_token)
        if not fresh_token:
            return TokenResponse(success=False, error="์ ํจํ์ง ์์ ๋ฆฌํ๋ ์ ํ ํฐ์ ๋๋ค")
        return TokenResponse(
            success=True,
            access_token=fresh_token,
            token_type="bearer"
        )
    except Exception as exc:
        logger.error(f"โ ํ ํฐ ๊ฐฑ์ ์ค๋ฅ: {exc}")
        return TokenResponse(success=False, error=str(exc))
async def register(
    user_id: str = Form(...),
    username: str = Form(...),
    password: str = Form(...),
    email: Optional[str] = Form(None)
):
    """Register a new user and immediately issue login tokens.

    NOTE(review): the password hash is computed but never stored.
    ``db_manager.add_user`` only receives ``user_id``, ``username`` and
    ``email``, so the credential is effectively discarded.
    """
    try:
        if db_manager.get_user(user_id):
            return LoginResponse(success=False, error="์ด๋ฏธ ์กด์ฌํ๋ ์ฌ์ฉ์ ID์ ๋๋ค")

        # Currently unused; see the NOTE in the docstring.
        _password_hash = auth_manager.hash_password(password)

        if not db_manager.add_user(user_id, username, email):
            return LoginResponse(success=False, error="์ฌ์ฉ์ ๋ฑ๋ก์ ์คํจํ์ต๋๋ค")

        tokens = auth_manager.create_user_tokens(user_id, username)
        return LoginResponse(
            success=True,
            access_token=tokens['access_token'],
            refresh_token=tokens['refresh_token'],
            token_type=tokens['token_type'],
            user_id=user_id,
            username=username
        )
    except Exception as exc:
        logger.error(f"โ ์ฌ์ฉ์ ๋ฑ๋ก ์ค๋ฅ: {exc}")
        return LoginResponse(success=False, error=str(exc))
async def get_current_user_info(credentials: HTTPAuthorizationCredentials = Depends(auth_manager.security)):
    """Decode the bearer token and return the identity claims it carries.

    Any failure, including an invalid token, is reported as
    ``{"success": False, "error": ...}`` rather than raised.
    """
    try:
        claims = auth_manager.get_current_user(credentials)
        identity = {"success": True}
        # (response key, token claim) pairs, in response order.
        for out_key, claim_name in (("user_id", "sub"), ("username", "username"), ("token_type", "type")):
            identity[out_key] = claims.get(claim_name)
        return identity
    except Exception as exc:
        logger.error(f"โ ์ฌ์ฉ์ ์ ๋ณด ์กฐํ ์ค๋ฅ: {exc}")
        return {"success": False, "error": str(exc)}
# WebSocket real-time chat endpoint
async def websocket_endpoint(websocket: WebSocket, user_id: str, session_id: str = None):
    """Real-time chat over a WebSocket connection.

    Registers the connection with ``connection_manager``, announces it to other
    users, then processes JSON messages until the client disconnects.

    Supported ``type`` values:

    * ``"chat"`` (default): persist the message when a session is known, relay
      it to the rest of the session and, if ``generate_ai_response`` is truthy,
      generate, persist and relay an AI reply.
    * ``"typing"``: relay a typing indicator to the session.
    * ``"join_session"``: move this user from their previous session to
      ``session_id`` and notify the new session.

    A message that omits ``session_id`` is routed to the connection's current
    session, which is the one given at connect time or the one last joined.

    Malformed payloads (invalid JSON, or JSON that is not an object) and
    per-message failures are reported to this client as ``{"type": "error"}``
    frames without closing the socket.  On exit the user is unregistered and a
    ``user_disconnected`` event is broadcast.
    """

    async def _send_error(text: str) -> None:
        # Report a per-message problem to this client only.
        await websocket.send_text(json.dumps({
            "type": "error",
            "message": text
        }))

    try:
        # Accept the connection and announce it to everyone else.
        await connection_manager.connect(websocket, user_id, session_id)
        await connection_manager.broadcast_message({
            "type": "user_connected",
            "user_id": user_id,
            "session_id": session_id,
            "timestamp": datetime.now().isoformat()
        }, exclude_user=user_id)

        # Message receive loop.
        while True:
            try:
                data = await websocket.receive_text()
                message_data = json.loads(data)
                if not isinstance(message_data, dict):
                    # Valid JSON that is not an object (list, number, string):
                    # treat it as a malformed message instead of failing on .get().
                    logger.warning(f"โ ๏ธ ์๋ชป๋ JSON ํ์: {user_id}")
                    await _send_error("์๋ชป๋ ๋ฉ์์ง ํ์์ ๋๋ค.")
                    continue

                message_type = message_data.get("type", "chat")

                if message_type == "chat":
                    content = message_data.get("content", "")
                    # Fall back to the current session rather than clobbering
                    # it with None when the client omits session_id.
                    session_id = message_data.get("session_id") or session_id

                    # Persist the user's message.
                    if session_id:
                        db_manager.add_chat_message(
                            session_id=session_id,
                            user_id=user_id,
                            message_type="user",
                            content=content
                        )

                    # Relay to the other members of the session.
                    await connection_manager.send_session_message({
                        "type": "chat_message",
                        "user_id": user_id,
                        "content": content,
                        "session_id": session_id,
                        "timestamp": datetime.now().isoformat()
                    }, session_id, exclude_user=user_id)

                    # Optional AI reply, persisted and sent to the whole session.
                    if message_data.get("generate_ai_response", False):
                        ai_response = await generate_ai_response(content, user_id)
                        if session_id:
                            db_manager.add_chat_message(
                                session_id=session_id,
                                user_id="ai_assistant",
                                message_type="assistant",
                                content=ai_response
                            )
                        await connection_manager.send_session_message({
                            "type": "ai_response",
                            "user_id": "ai_assistant",
                            "content": ai_response,
                            "session_id": session_id,
                            "timestamp": datetime.now().isoformat()
                        }, session_id)

                elif message_type == "typing":
                    typing_session = message_data.get("session_id") or session_id
                    await connection_manager.send_session_message({
                        "type": "user_typing",
                        "user_id": user_id,
                        "session_id": typing_session,
                        "timestamp": datetime.now().isoformat()
                    }, typing_session, exclude_user=user_id)

                elif message_type == "join_session":
                    new_session_id = message_data.get("session_id")
                    if new_session_id:
                        sessions = connection_manager.session_connections
                        info = connection_manager.connection_info

                        # Leave the previous session, if any.
                        if user_id in info:
                            old_session_id = info[user_id].get("session_id")
                            if old_session_id and old_session_id in sessions:
                                sessions[old_session_id].discard(user_id)

                        # Join the new session and update the connection record.
                        if new_session_id not in sessions:
                            sessions[new_session_id] = set()
                        sessions[new_session_id].add(user_id)
                        if user_id in info:
                            info[user_id]["session_id"] = new_session_id

                        # Later messages without an explicit session_id go here.
                        session_id = new_session_id

                        await connection_manager.send_session_message({
                            "type": "user_joined_session",
                            "user_id": user_id,
                            "session_id": new_session_id,
                            "timestamp": datetime.now().isoformat()
                        }, new_session_id, exclude_user=user_id)

                logger.info(f"๐จ WebSocket ๋ฉ์์ง ์ฒ๋ฆฌ: {user_id} - {message_type}")

            except WebSocketDisconnect:
                logger.info(f"๐ WebSocket ์ฐ๊ฒฐ ๋๊น: {user_id}")
                break
            except json.JSONDecodeError:
                logger.warning(f"โ ๏ธ ์๋ชป๋ JSON ํ์: {user_id}")
                await _send_error("์๋ชป๋ ๋ฉ์์ง ํ์์ ๋๋ค.")
            except Exception as e:
                logger.error(f"โ WebSocket ๋ฉ์์ง ์ฒ๋ฆฌ ์ค๋ฅ: {e}")
                await _send_error("๋ฉ์์ง ์ฒ๋ฆฌ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค.")

    except WebSocketDisconnect:
        logger.info(f"๐ WebSocket ์ฐ๊ฒฐ ๋๊น: {user_id}")
    except Exception as e:
        logger.error(f"โ WebSocket ์๋ํฌ์ธํธ ์ค๋ฅ: {e}")
    finally:
        # Unregister and tell everyone else the user left.
        connection_manager.disconnect(user_id)
        await connection_manager.broadcast_message({
            "type": "user_disconnected",
            "user_id": user_id,
            "timestamp": datetime.now().isoformat()
        }, exclude_user=user_id)
async def generate_ai_response(content: str, user_id: str) -> str:
    """Produce an AI reply for a chat message (simple example implementation).

    Delegates to ``generate_sync`` and reads the ``"response"`` key from its
    result.  The code assumes the result is dict-like; confirm this against
    ``generate_sync``.  A fixed apology string is returned when the key is
    missing or the call fails.
    """
    fallback = "์ฃ์กํฉ๋๋ค. ์๋ต์ ์์ฑํ ์ ์์ต๋๋ค."
    try:
        reply = await generate_sync(content, user_id)
        return reply.get("response", fallback)
    except Exception as exc:
        logger.error(f"โ AI ์๋ต ์์ฑ ์คํจ: {exc}")
        return fallback
# WebSocket status endpoint
async def get_websocket_status():
    """Summarise live WebSocket connections, active users and known session IDs."""
    connection_count = connection_manager.get_connection_count()
    users = connection_manager.get_active_users()
    session_ids = list(connection_manager.session_connections.keys())
    return {
        "active_connections": connection_count,
        "active_users": users,
        "sessions": session_ids,
    }
# Celery background task endpoints
async def start_document_processing(
    user_id: str = Form(...),
    document_id: str = Form(...),
    file_path: str = Form(...),
    file_type: str = Form(...)
):
    """Queue a Celery task that processes a document already on the server.

    NOTE(review): ``file_path`` is taken verbatim from the client and handed to
    the worker.  Confirm that the task restricts which paths it will read.
    """
    try:
        job = process_document_async.delay(user_id, document_id, file_path, file_type)
        return {
            "success": True,
            "task_id": job.id,
            "status": "started",
            "message": "๋ฌธ์ ์ฒ๋ฆฌ ์์ ์ด ์์๋์์ต๋๋ค.",
        }
    except Exception as exc:
        logger.error(f"โ ๋ฌธ์ ์ฒ๋ฆฌ ์์ ์์ ์คํจ: {exc}")
        return {"success": False, "error": str(exc)}
async def start_ai_generation(
    user_id: str = Form(...),
    session_id: str = Form(...),
    prompt: str = Form(...),
    model_id: Optional[str] = Form(None)
):
    """Queue a Celery task that generates an AI reply for ``prompt``."""
    try:
        job = generate_ai_response_async.delay(user_id, session_id, prompt, model_id)
        return {
            "success": True,
            "task_id": job.id,
            "status": "started",
            "message": "AI ์๋ต ์์ฑ ์์ ์ด ์์๋์์ต๋๋ค.",
        }
    except Exception as exc:
        logger.error(f"โ AI ์๋ต ์์ฑ ์์ ์์ ์คํจ: {exc}")
        return {"success": False, "error": str(exc)}
async def start_rag_query(
    user_id: str = Form(...),
    query: str = Form(...),
    document_id: str = Form(...)
):
    """Queue a Celery task that runs a RAG query against ``document_id``."""
    try:
        job = rag_query_async.delay(user_id, query, document_id)
        return {
            "success": True,
            "task_id": job.id,
            "status": "started",
            "message": "RAG ์ฟผ๋ฆฌ ์์ ์ด ์์๋์์ต๋๋ค.",
        }
    except Exception as exc:
        logger.error(f"โ RAG ์ฟผ๋ฆฌ ์์ ์์ ์คํจ: {exc}")
        return {"success": False, "error": str(exc)}
async def start_batch_processing(
    user_id: str = Form(...),
    document_ids: str = Form(...)  # JSON array string, e.g. '["doc1", "doc2"]'
):
    """Queue a Celery task that processes several documents in one batch.

    Args:
        user_id: Owner of the documents.
        document_ids: JSON-encoded array of document IDs.

    Returns:
        ``{"success": True, "task_id", "status", "message"}`` once queued, or
        ``{"success": False, "error"}`` when ``document_ids`` is not a JSON
        array or the task cannot be queued.
    """
    try:
        doc_ids = json.loads(document_ids)
        # json.loads also accepts strings, numbers and objects.  Reject them
        # instead of queuing a bogus batch; a string would otherwise be
        # "counted" character by character in the message below.
        if not isinstance(doc_ids, list):
            raise ValueError("document_ids must be a JSON array of document IDs")
        task = batch_process_documents_async.delay(user_id, doc_ids)
        return {
            "success": True,
            "task_id": task.id,
            "status": "started",
            "message": f"๋ฌธ์ ์ผ๊ด ์ฒ๋ฆฌ ์์ ์ด ์์๋์์ต๋๋ค. ({len(doc_ids)}๊ฐ ๋ฌธ์)"
        }
    except Exception as e:
        logger.error(f"โ ๋ฌธ์ ์ผ๊ด ์ฒ๋ฆฌ ์์ ์์ ์คํจ: {e}")
        return {
            "success": False,
            "error": str(e)
        }
async def get_task_status_endpoint(task_id: str):
    """Report the state, result and info of a background task."""
    try:
        snapshot = get_task_status(task_id)
        if not snapshot:
            return {"success": False, "error": "์์ ์ ์ฐพ์ ์ ์์ต๋๋ค."}
        response = {"success": True, "task_id": task_id}
        for field in ("status", "result", "info"):
            response[field] = snapshot[field]
        return response
    except Exception as exc:
        logger.error(f"โ ์์ ์ํ ์กฐํ ์คํจ: {exc}")
        return {"success": False, "error": str(exc)}
async def cancel_task_endpoint(task_id: str):
    """Request cancellation of a background task."""
    try:
        cancelled = cancel_task(task_id)
        if not cancelled:
            return {"success": False, "error": "์์ ์ทจ์์ ์คํจํ์ต๋๋ค."}
        return {
            "success": True,
            "task_id": task_id,
            "message": "์์ ์ด ์ทจ์๋์์ต๋๋ค.",
        }
    except Exception as exc:
        logger.error(f"โ ์์ ์ทจ์ ์คํจ: {exc}")
        return {"success": False, "error": str(exc)}
# Performance monitoring endpoints
async def start_performance_monitoring():
    """Start the background performance monitor.

    Raises:
        HTTPException: 500 if the monitor fails to start.
    """
    try:
        performance_monitor.start_monitoring()
    except Exception as exc:
        logger.error(f"๋ชจ๋ํฐ๋ง ์์ ์คํจ: {exc}")
        raise HTTPException(status_code=500, detail=f"๋ชจ๋ํฐ๋ง ์์ ์คํจ: {str(exc)}")
    return {"message": "์ฑ๋ฅ ๋ชจ๋ํฐ๋ง์ด ์์๋์์ต๋๋ค."}
async def stop_performance_monitoring():
    """Stop the background performance monitor.

    Raises:
        HTTPException: 500 if the monitor fails to stop.
    """
    try:
        performance_monitor.stop_monitoring()
    except Exception as exc:
        logger.error(f"๋ชจ๋ํฐ๋ง ์ค์ง ์คํจ: {exc}")
        raise HTTPException(status_code=500, detail=f"๋ชจ๋ํฐ๋ง ์ค์ง ์คํจ: {str(exc)}")
    return {"message": "์ฑ๋ฅ ๋ชจ๋ํฐ๋ง์ด ์ค์ง๋์์ต๋๋ค."}
async def get_monitoring_status():
    """Return the performance monitor's summary unchanged.

    Raises:
        HTTPException: 500 if the summary cannot be produced.
    """
    try:
        return performance_monitor.get_performance_summary()
    except Exception as exc:
        logger.error(f"๋ชจ๋ํฐ๋ง ์ํ ์กฐํ ์คํจ: {exc}")
        raise HTTPException(status_code=500, detail=f"๋ชจ๋ํฐ๋ง ์ํ ์กฐํ ์คํจ: {str(exc)}")
async def get_system_health():
    """Return the overall system health verdict and per-resource breakdown.

    Raises:
        HTTPException: 500 if the health check fails.
    """
    try:
        health = performance_monitor.get_system_health()
        fields = (
            "status",
            "cpu_health",
            "memory_health",
            "disk_health",
            "network_health",
            "recommendations",
        )
        return {name: getattr(health, name) for name in fields}
    except Exception as exc:
        logger.error(f"์์คํ ๊ฑด๊ฐ ์ํ ์กฐํ ์คํจ: {exc}")
        raise HTTPException(status_code=500, detail=f"์์คํ ๊ฑด๊ฐ ์ํ ์กฐํ ์คํจ: {str(exc)}")
async def export_performance_metrics(file_path: str = "performance_metrics.json"):
    """Write the collected performance metrics to ``file_path``.

    ``file_path`` comes straight from the HTTP request.  It is resolved
    against the server's working directory and must stay inside it;
    subdirectories are allowed, while absolute paths elsewhere and ``..``
    traversal are rejected.  Without this check any client could make the
    server write or overwrite arbitrary files.

    Raises:
        HTTPException: 400 if ``file_path`` escapes the working directory,
            500 if exporting fails.
    """
    base_dir = os.path.realpath(os.getcwd())
    target = os.path.realpath(os.path.join(base_dir, file_path))
    try:
        inside_base = os.path.commonpath([base_dir, target]) == base_dir
    except ValueError:  # e.g. paths on different drives on Windows
        inside_base = False
    if not inside_base:
        raise HTTPException(
            status_code=400,
            detail="file_path must stay inside the server working directory"
        )

    try:
        performance_monitor.export_metrics(file_path)
        return {"message": f"์ฑ๋ฅ ๋ฉํธ๋ฆญ์ด {file_path}์ ์ ์ฅ๋์์ต๋๋ค."}
    except Exception as e:
        logger.error(f"๋ฉํธ๋ฆญ ๋ด๋ณด๋ด๊ธฐ ์คํจ: {e}")
        raise HTTPException(status_code=500, detail=f"๋ฉํธ๋ฆญ ๋ด๋ณด๋ด๊ธฐ ์คํจ: {str(e)}")
# ============================================================================
# Image-OCR-only API endpoints (kept fully separate from the text-based system)
# ============================================================================
async def upload_image_document(
    file: UploadFile = File(...),
    user_id: str = Form("default_user"),
    document_id: Optional[str] = Form(None)
):
    """Upload an image document, OCR it and store the result for RAG.

    The upload is written to a temporary file in the current working directory,
    passed to ``image_rag_processor.process_and_store_image_document`` and then
    deleted.  Deletion always happens, including when processing raises.

    Returns:
        DocumentUploadResponse.  Errors are reported with ``success=False``
        instead of being raised.
    """
    start_time = time.time()
    temp_file_path = None
    try:
        # Generate a short document ID when the client did not supply one.
        if not document_id:
            import uuid
            document_id = str(uuid.uuid4())[:8]

        # File name and document ID are client-controlled; strip directory
        # components so the temp file stays in the working directory.
        safe_name = os.path.basename(file.filename or "upload")
        safe_doc_id = os.path.basename(document_id)
        temp_file_path = f"./temp_image_{safe_doc_id}_{safe_name}"
        with open(temp_file_path, "wb") as f:
            content = await file.read()
            f.write(content)

        # OCR the image and store the extracted text in the vector store.
        result = image_rag_processor.process_and_store_image_document(
            user_id, document_id, temp_file_path
        )

        processing_time = time.time() - start_time
        logger.info(f"๐ผ๏ธ ์ด๋ฏธ์ง OCR ๋ฌธ์ ์ ๋ก๋ ์๋ฃ ({processing_time:.2f}์ด): {file.filename}")

        return DocumentUploadResponse(
            success=result["success"],
            document_id=document_id,
            message=result.get("message", ""),
            chunks=result.get("chunks"),
            latex_count=result.get("latex_count"),
            error=result.get("error"),
            auto_response=result.get("auto_response", "")
        )
    except Exception as e:
        logger.error(f"โ ์ด๋ฏธ์ง OCR ๋ฌธ์ ์ ๋ก๋ ์คํจ: {e}")
        return DocumentUploadResponse(
            success=False,
            document_id=document_id or "unknown",
            message="์ด๋ฏธ์ง OCR ๋ฌธ์ ์ ๋ก๋ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค.",
            error=str(e)
        )
    finally:
        # Always remove the temporary upload, even when processing raised.
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
            except OSError as cleanup_error:
                logger.warning(f"Failed to remove temp file {temp_file_path}: {cleanup_error}")
async def generate_image_ocr_response(
    query: str = Form(...),
    user_id: str = Form("default_user"),
    document_id: str = Form(...)
):
    """Answer ``query`` from the OCR text stored for an image document.

    Returns:
        The dict produced by ``image_rag_processor.generate_image_rag_response``,
        augmented in place with ``processing_time`` (seconds). If anything raises,
        a ``RAGResponse`` with ``success=False`` and the error text is returned.
    """
    started_at = time.time()
    try:
        answer = image_rag_processor.generate_image_rag_response(user_id, document_id, query)
        elapsed = time.time() - started_at
        answer["processing_time"] = elapsed
        logger.info(f"๐ผ๏ธ ์ด๋ฏธ์ง OCR RAG ์๋ต ์์ฑ ์๋ฃ ({elapsed:.2f}์ด)")
        return answer
    except Exception as failure:
        logger.error(f"โ ์ด๋ฏธ์ง OCR RAG ์๋ต ์์ฑ ์คํจ: {failure}")
        return RAGResponse(
            success=False,
            response=f"์ด๋ฏธ์ง OCR RAG ์๋ต ์์ฑ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค: {str(failure)}",
            context="",
            sources=[],
            search_results=0,
            processing_time=time.time() - started_at,
        )
async def get_image_document_info(user_id: str, document_id: str):
    """Return stored information about an image OCR document.

    Delegates to ``image_rag_processor.get_image_document_info``. Any exception is
    logged and converted to ``{"success": False, "error": <message>}``.
    """
    try:
        return image_rag_processor.get_image_document_info(user_id, document_id)
    except Exception as lookup_error:
        logger.error(f"โ ์ด๋ฏธ์ง OCR ๋ฌธ์ ์ ๋ณด ์กฐํ ์คํจ: {lookup_error}")
        return {"success": False, "error": str(lookup_error)}
async def delete_image_document(user_id: str, document_id: str):
    """Remove an image OCR document from the vector store.

    Returns:
        ``{"success": True, "message": ...}`` when
        ``vector_store_manager.delete_document`` returns a truthy value, otherwise
        ``{"success": False, "error": ...}``. Exceptions are logged and reported
        the same way.

    NOTE(review): deletion goes through the generic vector store manager; nothing
    here checks that the document was created by the image OCR pipeline.
    """
    try:
        if vector_store_manager.delete_document(user_id, document_id):
            return {"success": True, "message": "์ด๋ฏธ์ง OCR ๋ฌธ์๊ฐ ์ญ์ ๋์์ต๋๋ค."}
        return {"success": False, "error": "์ด๋ฏธ์ง OCR ๋ฌธ์ ์ญ์ ์ ์คํจํ์ต๋๋ค."}
    except Exception as delete_error:
        logger.error(f"โ ์ด๋ฏธ์ง OCR ๋ฌธ์ ์ญ์ ์คํจ: {delete_error}")
        return {"success": False, "error": str(delete_error)}
# ============================================================================
# LaTeX-OCR 전용 API 엔드포인트 (수학 수식 인식 기능 포함)
# ============================================================================
async def upload_latex_document(
    file: UploadFile = File(...),
    user_id: str = Form("default_user"),
    document_id: Optional[str] = Form(None)
):
    """Upload a document, run LaTeX-OCR on it and store the result via the LaTeX RAG processor.

    The upload is written to a temporary file in the current working directory,
    passed to ``latex_rag_processor.process_and_store_latex_document`` and then
    removed again -- also when processing raises.

    Args:
        file: The uploaded document.
        user_id: Owner of the document.
        document_id: Optional document id; when empty, the first 8 characters of a
            random UUID4 are used.

    Returns:
        DocumentUploadResponse. Failures are reported with ``success=False`` and
        ``error`` set instead of being raised.
    """
    start_time = time.time()
    temp_file_path = None
    try:
        # Generate a short document id when the client did not supply one.
        if not document_id:
            import uuid
            document_id = str(uuid.uuid4())[:8]

        # document_id and the filename are client-controlled; keep only their last
        # path component so the temp file cannot be written outside the working dir.
        safe_document_id = Path(document_id).name
        safe_filename = Path(file.filename or "upload").name
        temp_file_path = f"./temp_latex_{safe_document_id}_{safe_filename}"

        content = await file.read()
        with open(temp_file_path, "wb") as f:
            f.write(content)

        # Run LaTeX-OCR and store the extracted content in the vector store.
        result = latex_rag_processor.process_and_store_latex_document(
            user_id, document_id, temp_file_path
        )

        processing_time = time.time() - start_time
        logger.info(f"๐งฎ LaTeX-OCR ๋ฌธ์ ์ ๋ก๋ ์๋ฃ ({processing_time:.2f}์ด): {file.filename}")
        return DocumentUploadResponse(
            success=result.get("success", False),
            document_id=document_id,
            message=result.get("message", ""),
            chunks=result.get("chunks"),
            latex_count=result.get("latex_count"),
            error=result.get("error"),
            auto_response=result.get("auto_response", "")
        )
    except Exception as e:
        logger.error(f"โ LaTeX-OCR ๋ฌธ์ ์ ๋ก๋ ์คํจ: {e}")
        return DocumentUploadResponse(
            success=False,
            document_id=document_id or "unknown",
            message="LaTeX-OCR ๋ฌธ์ ์ ๋ก๋ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค.",
            error=str(e)
        )
    finally:
        # Runs on success and on failure, so a crashing OCR step no longer leaves
        # the uploaded bytes behind on disk.
        if temp_file_path and os.path.exists(temp_file_path):
            try:
                os.remove(temp_file_path)
            except OSError as cleanup_error:
                logger.warning(f"Failed to remove temporary file {temp_file_path}: {cleanup_error}")
async def generate_latex_ocr_response(
    query: str = Form(...),
    user_id: str = Form("default_user"),
    document_id: str = Form(...)
):
    """Answer ``query`` from the LaTeX-OCR content stored for a document.

    Returns:
        The dict produced by ``latex_rag_processor.generate_latex_rag_response``,
        augmented in place with ``processing_time`` (seconds). If anything raises,
        a ``RAGResponse`` with ``success=False`` and the error text is returned.
    """
    started_at = time.time()
    try:
        answer = latex_rag_processor.generate_latex_rag_response(user_id, document_id, query)
        elapsed = time.time() - started_at
        answer["processing_time"] = elapsed
        logger.info(f"๐งฎ LaTeX-OCR RAG ์๋ต ์์ฑ ์๋ฃ ({elapsed:.2f}์ด)")
        return answer
    except Exception as failure:
        logger.error(f"โ LaTeX-OCR RAG ์๋ต ์์ฑ ์คํจ: {failure}")
        return RAGResponse(
            success=False,
            response=f"LaTeX-OCR RAG ์๋ต ์์ฑ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค: {str(failure)}",
            context="",
            sources=[],
            search_results=0,
            processing_time=time.time() - started_at,
        )
async def get_latex_document_info(user_id: str, document_id: str):
    """Return stored information about a LaTeX-OCR document.

    Delegates to ``latex_rag_processor.get_latex_document_info``. Any exception is
    logged and converted to ``{"success": False, "error": <message>}``.
    """
    try:
        return latex_rag_processor.get_latex_document_info(user_id, document_id)
    except Exception as lookup_error:
        logger.error(f"โ LaTeX-OCR ๋ฌธ์ ์ ๋ณด ์กฐํ ์คํจ: {lookup_error}")
        return {"success": False, "error": str(lookup_error)}
async def delete_latex_document(user_id: str, document_id: str):
    """Remove a LaTeX-OCR document from the vector store.

    Returns:
        ``{"success": True, "message": ...}`` when
        ``vector_store_manager.delete_document`` returns a truthy value, otherwise
        ``{"success": False, "error": ...}``. Exceptions are logged and reported
        the same way.

    NOTE(review): deletion goes through the generic vector store manager; nothing
    here checks that the document was created by the LaTeX-OCR pipeline.
    """
    try:
        if vector_store_manager.delete_document(user_id, document_id):
            return {"success": True, "message": "LaTeX-OCR ๋ฌธ์๊ฐ ์ญ์ ๋์์ต๋๋ค."}
        return {"success": False, "error": "LaTeX-OCR ๋ฌธ์ ์ญ์ ์ ์คํจํ์ต๋๋ค."}
    except Exception as delete_error:
        logger.error(f"โ LaTeX-OCR ๋ฌธ์ ์ญ์ ์คํจ: {delete_error}")
        return {"success": False, "error": str(delete_error)}
# ============================================================================
# LaTeX-OCR + FAISS 통합 시스템 엔드포인트
# ============================================================================
# # LaTeX-OCR + FAISS 시스템 초기화
# latex_ocr_faiss_simple = None
# latex_ocr_faiss_integrated = None
# def init_latex_ocr_faiss_systems():
#     """LaTeX-OCR + FAISS 시스템 초기화"""
#     global latex_ocr_faiss_simple, latex_ocr_faiss_integrated
#     try:
#         latex_ocr_faiss_simple = LatexOCRFAISSSimple()
#         latex_ocr_faiss_integrated = LatexOCRFAISSIntegrated()
#         logger.info("✅ LaTeX-OCR + FAISS 시스템 초기화 완료")
#     except Exception as e:
#         logger.error(f"❌ LaTeX-OCR + FAISS 시스템 초기화 실패: {e}")
async def process_pdf_with_latex_faiss(
    file: UploadFile = File(...),
    user_id: str = Form("default_user"),
    system_type: str = Form("simple")  # "simple" or "integrated"
):
    """Extract LaTeX formulas from a PDF and index them in FAISS -- currently disabled.

    The LaTeX-OCR + FAISS backend modules were removed, so this endpoint always
    answers with ``success=False``. The parameters are kept so the request schema
    stays unchanged. The upload is deliberately not saved: nothing would ever read
    it, and the previous code wrote it to a path built from the unsanitized
    ``user_id`` and client filename. The former unreachable processing code (which
    referenced an undefined ``result``) has been dropped.

    Returns:
        DocumentUploadResponse with ``success=False`` explaining the feature is off.
    """
    return DocumentUploadResponse(
        success=False,
        document_id="",
        message="LaTeX-OCR + FAISS ๊ธฐ๋ฅ์ด ํ์ฌ ๋นํ์ฑํ๋์ด ์์ต๋๋ค",
        error="์ญ์ ๋ ๋ชจ๋๋ก ์ธํด ๊ธฐ๋ฅ์ด ๋นํ์ฑํ๋จ"
    )
async def search_latex_formulas(
    query: str = Form(...),
    user_id: str = Form("default_user"),
    document_path: Optional[str] = Form(None),
    system_type: str = Form("simple"),
    k: int = Form(5)
):
    """Search stored LaTeX formulas -- currently disabled.

    The LaTeX-OCR + FAISS backend modules were removed, so this endpoint always
    returns an unsuccessful ``RAGResponse``. The parameters are kept so the request
    schema stays unchanged. The former unreachable search code (which referenced an
    undefined ``search_result``) has been dropped.

    Returns:
        RAGResponse with ``success=False``, no context/sources and an error message.
    """
    return RAGResponse(
        success=False,
        response="LaTeX-OCR + FAISS ๊ฒ์ ๊ธฐ๋ฅ์ด ํ์ฌ ๋นํ์ฑํ๋์ด ์์ต๋๋ค",
        context="",
        sources=[],
        search_results=0,
        processing_time=0.0,
        error="์ญ์ ๋ ๋ชจ๋๋ก ์ธํด ๊ธฐ๋ฅ์ด ๋นํ์ฑํ๋จ"
    )
async def get_latex_ocr_faiss_status():
    """Report the status of the LaTeX-OCR + FAISS subsystem.

    The backing modules have been removed, so both systems are always reported as
    not initialised and the overall status is ``"disabled"``.
    """
    disabled_status = {
        "simple_system_initialized": False,
        "integrated_system_initialized": False,
        "status": "disabled",
        "message": "LaTeX-OCR + FAISS ๊ธฐ๋ฅ์ด ํ์ฌ ๋นํ์ฑํ๋์ด ์์ต๋๋ค",
    }
    return disabled_status
| # ============================================================================ | |
| # ์ปจํ ์คํธ ๊ด๋ฆฌ ์์คํ ์๋ํฌ์ธํธ | |
| # ============================================================================ | |
async def set_system_prompt(prompt: str = Form(...)):
    """Replace the system prompt held by ``context_manager``.

    Returns ``{"success": True, "message", "prompt_length"}`` on success, otherwise
    the logged error as ``{"success": False, "error": ...}``.
    """
    try:
        context_manager.set_system_prompt(prompt)
    except Exception as set_error:
        logger.error(f"โ ์์คํ ํ๋กฌํํธ ์ค์ ์คํจ: {set_error}")
        return {"success": False, "error": str(set_error)}
    else:
        return {
            "success": True,
            "message": "์์คํ ํ๋กฌํํธ๊ฐ ์ค์ ๋์์ต๋๋ค.",
            "prompt_length": len(prompt),
        }
async def add_context_message(
    role: str = Form(...),  # 'user' or 'assistant'
    content: str = Form(...),
    message_id: Optional[str] = Form(None),
    metadata: str = Form("{}")  # JSON-encoded object
):
    """Append a user or assistant message to the conversation context.

    Args:
        role: ``"user"`` or ``"assistant"``; any other value is rejected.
        content: Message text.
        message_id: Optional caller-supplied id, forwarded to the context manager.
        metadata: JSON-encoded object attached to the message. An empty string or
            the JSON literal ``null`` means "no metadata".

    Returns:
        ``{"success": True, "message", "message_id", "context_summary"}`` on
        success, otherwise ``{"success": False, "error": ...}`` (invalid JSON,
        non-object metadata, unknown role, or a context-manager failure).
    """
    try:
        metadata_dict = json.loads(metadata) if metadata else {}
        if metadata_dict is None:
            metadata_dict = {}
        if not isinstance(metadata_dict, dict):
            # A JSON array/number/string would otherwise be forwarded as metadata
            # and fail (or be stored) somewhere downstream in a confusing way.
            return {"success": False, "error": "metadata must be a JSON object."}

        if role == "user":
            msg_id = context_manager.add_user_message(content, message_id, metadata_dict)
        elif role == "assistant":
            msg_id = context_manager.add_assistant_message(content, message_id, metadata_dict)
        else:
            return {"success": False, "error": "์๋ชป๋ ์ญํ ์ ๋๋ค. 'user' ๋๋ 'assistant'๋ฅผ ์ฌ์ฉํ์ธ์."}

        return {
            "success": True,
            "message": "๋ฉ์์ง๊ฐ ์ปจํ ์คํธ์ ์ถ๊ฐ๋์์ต๋๋ค.",
            "message_id": msg_id,
            "context_summary": context_manager.get_context_summary()
        }
    except Exception as e:
        logger.error(f"โ ์ปจํ ์คํธ ๋ฉ์์ง ์ถ๊ฐ ์คํจ: {e}")
        return {"success": False, "error": str(e)}
async def get_context(
    include_system: bool = True,
    max_length: Optional[int] = None,
    recent_turns: Optional[int] = None
):
    """Return the current conversation context plus summary statistics.

    When ``recent_turns`` is truthy, only that many recent turns are fetched via
    ``context_manager.get_recent_context``. Otherwise the full context is fetched
    with ``include_system`` and ``max_length``. Because the check is on
    truthiness, ``recent_turns=0`` behaves like ``None``.

    Returns:
        ``{"success": True, "context", "context_summary", "memory_efficiency"}``
        or ``{"success": False, "error": ...}`` on failure.
    """
    try:
        context = (
            context_manager.get_recent_context(recent_turns)
            if recent_turns
            else context_manager.get_context(include_system, max_length)
        )
        return {
            "success": True,
            "context": context,
            "context_summary": context_manager.get_context_summary(),
            "memory_efficiency": context_manager.get_memory_efficiency(),
        }
    except Exception as fetch_error:
        logger.error(f"โ ์ปจํ ์คํธ ์กฐํ ์คํจ: {fetch_error}")
        return {"success": False, "error": str(fetch_error)}
async def get_context_summary():
    """Return the context manager's summary and memory-efficiency figures.

    Returns ``{"success": True, "summary", "memory_efficiency"}`` or
    ``{"success": False, "error": ...}`` if either call raises.
    """
    try:
        summary = context_manager.get_context_summary()
        efficiency = context_manager.get_memory_efficiency()
        return {"success": True, "summary": summary, "memory_efficiency": efficiency}
    except Exception as summary_error:
        logger.error(f"โ ์ปจํ ์คํธ ์์ฝ ์กฐํ ์คํจ: {summary_error}")
        return {"success": False, "error": str(summary_error)}
async def clear_context():
    """Reset the conversation context held by ``context_manager``.

    Returns ``{"success": True, "message": ...}`` or ``{"success": False, "error": ...}``.
    """
    try:
        context_manager.clear_context()
    except Exception as clear_error:
        logger.error(f"โ ์ปจํ ์คํธ ์ด๊ธฐํ ์คํจ: {clear_error}")
        return {"success": False, "error": str(clear_error)}
    else:
        return {"success": True, "message": "์ปจํ ์คํธ๊ฐ ์ด๊ธฐํ๋์์ต๋๋ค."}
async def remove_context_message(message_id: str):
    """Delete the message identified by ``message_id`` from the context.

    Returns ``{"success": True, "message", "context_summary"}`` when
    ``context_manager.remove_message`` reports success, otherwise
    ``{"success": False, "error": ...}``. Exceptions are logged and reported likewise.
    """
    try:
        if not context_manager.remove_message(message_id):
            return {"success": False, "error": "๋ฉ์์ง๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค."}
        return {
            "success": True,
            "message": "๋ฉ์์ง๊ฐ ์ ๊ฑฐ๋์์ต๋๋ค.",
            "context_summary": context_manager.get_context_summary(),
        }
    except Exception as remove_error:
        logger.error(f"โ ๋ฉ์์ง ์ ๊ฑฐ ์คํจ: {remove_error}")
        return {"success": False, "error": str(remove_error)}
async def edit_context_message(
    message_id: str,
    new_content: str = Form(...)
):
    """Replace the content of an existing context message.

    Returns ``{"success": True, "message", "context_summary"}`` when
    ``context_manager.edit_message`` reports success, otherwise
    ``{"success": False, "error": ...}``. Exceptions are logged and reported likewise.
    """
    try:
        if not context_manager.edit_message(message_id, new_content):
            return {"success": False, "error": "๋ฉ์์ง๋ฅผ ์ฐพ์ ์ ์์ต๋๋ค."}
        return {
            "success": True,
            "message": "๋ฉ์์ง๊ฐ ์์ ๋์์ต๋๋ค.",
            "context_summary": context_manager.get_context_summary(),
        }
    except Exception as edit_error:
        logger.error(f"โ ๋ฉ์์ง ์์ ์คํจ: {edit_error}")
        return {"success": False, "error": str(edit_error)}
async def search_context(query: str, max_results: int = 5):
    """Search the conversation context for ``query``.

    Returns ``{"success": True, "query", "results", "total_results"}`` where
    ``total_results`` is ``len(results)``, or ``{"success": False, "error": ...}``.
    """
    try:
        hits = context_manager.search_context(query, max_results)
        return {
            "success": True,
            "query": query,
            "results": hits,
            "total_results": len(hits),
        }
    except Exception as search_error:
        logger.error(f"โ ์ปจํ ์คํธ ๊ฒ์ ์คํจ: {search_error}")
        return {"success": False, "error": str(search_error)}
async def export_context(file_path: str = Form(None)):
    """Export the conversation context to a file.

    ``file_path`` may be ``None``; the path actually written is whatever
    ``context_manager.export_context`` returns.

    NOTE(review): ``file_path`` comes straight from the request, so the client
    chooses where the server writes -- confirm this is intended.

    Returns ``{"success": True, "message", "file_path"}`` or
    ``{"success": False, "error": ...}``.
    """
    try:
        written_to = context_manager.export_context(file_path)
        return {
            "success": True,
            "message": "์ปจํ ์คํธ๊ฐ ๋ด๋ณด๋ด์ก์ต๋๋ค.",
            "file_path": written_to,
        }
    except Exception as export_error:
        logger.error(f"โ ์ปจํ ์คํธ ๋ด๋ณด๋ด๊ธฐ ์คํจ: {export_error}")
        return {"success": False, "error": str(export_error)}
async def import_context(file_path: str = Form(...)):
    """Load a previously exported conversation context from ``file_path``.

    NOTE(review): ``file_path`` comes straight from the request, so the client
    chooses which server-side file is read -- confirm this is intended.

    Returns ``{"success": True, "message", "context_summary"}`` when
    ``context_manager.import_context`` reports success, otherwise
    ``{"success": False, "error": ...}``.
    """
    try:
        if not context_manager.import_context(file_path):
            return {"success": False, "error": "์ปจํ ์คํธ ๊ฐ์ ธ์ค๊ธฐ์ ์คํจํ์ต๋๋ค."}
        return {
            "success": True,
            "message": "์ปจํ ์คํธ๊ฐ ๊ฐ์ ธ์์ก์ต๋๋ค.",
            "context_summary": context_manager.get_context_summary(),
        }
    except Exception as import_error:
        logger.error(f"โ ์ปจํ ์คํธ ๊ฐ์ ธ์ค๊ธฐ ์คํจ: {import_error}")
        return {"success": False, "error": str(import_error)}
| # ============================================================================ | |
| # LoRA/QLoRA ๊ด๋ฆฌ ์์คํ ์๋ํฌ์ธํธ | |
| # ============================================================================ | |
async def load_lora_base_model(
    model_path: str = Form(...),
    model_type: str = Form("causal_lm")
):
    """Load the base model that LoRA adapters are attached to.

    Args:
        model_path: Forwarded to ``lora_manager.load_base_model``.
        model_type: Forwarded unchanged (default ``"causal_lm"``).

    Returns:
        On success ``{"success": True, "message", "model_path", "device"}`` where
        ``device`` is ``lora_manager.device``; otherwise
        ``{"success": False, "error": ...}``, including when LoRA support is
        unavailable (``LORA_AVAILABLE`` false or no ``lora_manager``).
    """
    lora_ready = LORA_AVAILABLE and lora_manager is not None
    if not lora_ready:
        return {
            "success": False,
            "error": "LoRA ๊ธฐ๋ฅ์ด ์ฌ์ฉ ๋ถ๊ฐ๋ฅํฉ๋๋ค. PEFT ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์ง ์์์ต๋๋ค."
        }
    try:
        if not lora_manager.load_base_model(model_path, model_type):
            return {"success": False, "error": "๋ชจ๋ธ ๋ก๋์ ์คํจํ์ต๋๋ค."}
        return {
            "success": True,
            "message": "๊ธฐ๋ณธ ๋ชจ๋ธ์ด ๋ก๋๋์์ต๋๋ค.",
            "model_path": model_path,
            "device": lora_manager.device,
        }
    except Exception as load_error:
        logger.error(f"โ LoRA ๊ธฐ๋ณธ ๋ชจ๋ธ ๋ก๋ ์คํจ: {load_error}")
        return {"success": False, "error": str(load_error)}
async def create_lora_config(
    r: int = Form(16),
    lora_alpha: int = Form(32),
    target_modules: str = Form("q_proj,v_proj,k_proj,o_proj,gate_proj,up_proj,down_proj"),
    lora_dropout: float = Form(0.1),
    bias: str = Form("none"),
    task_type: str = Form("CAUSAL_LM")
):
    """Build a LoRA configuration through ``lora_manager.create_lora_config``.

    Args:
        r: LoRA rank.
        lora_alpha: LoRA scaling factor.
        target_modules: Comma-separated module names. Surrounding whitespace and
            empty entries are ignored. An effectively empty value is passed on as
            ``None``.
        lora_dropout: Dropout probability for the LoRA layers.
        bias: Bias handling mode, passed through unchanged.
        task_type: Task type string, passed through unchanged.

    Returns:
        ``{"success": True, "message", "config"}`` with ``config.to_dict()``, or
        ``{"success": False, "error": ...}`` on failure or when LoRA support is
        unavailable.
    """
    if not LORA_AVAILABLE or lora_manager is None:
        return {
            "success": False,
            "error": "LoRA ๊ธฐ๋ฅ์ด ์ฌ์ฉ ๋ถ๊ฐ๋ฅํฉ๋๋ค. PEFT ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์ง ์์์ต๋๋ค."
        }
    try:
        # "q_proj, v_proj" must yield "v_proj", not " v_proj": a name with stray
        # whitespace would not match any module. Stray commas are dropped too.
        target_modules_list = [
            name.strip() for name in (target_modules or "").split(",") if name.strip()
        ] or None
        config = lora_manager.create_lora_config(
            r=r,
            lora_alpha=lora_alpha,
            target_modules=target_modules_list,
            lora_dropout=lora_dropout,
            bias=bias,
            task_type=task_type
        )
        return {
            "success": True,
            "message": "LoRA ์ค์ ์ด ์์ฑ๋์์ต๋๋ค.",
            "config": config.to_dict()
        }
    except Exception as e:
        logger.error(f"โ LoRA ์ค์ ์์ฑ ์คํจ: {e}")
        return {"success": False, "error": str(e)}
async def apply_lora_adapter(adapter_name: str = Form("default")):
    """Attach the LoRA adapter ``adapter_name`` to the loaded base model.

    Returns ``{"success": True, "message", "adapter_name", "stats"}`` (stats from
    ``lora_manager.get_adapter_stats()``) when ``apply_lora_to_model`` succeeds,
    otherwise ``{"success": False, "error": ...}``, including when LoRA support is
    unavailable.
    """
    lora_ready = LORA_AVAILABLE and lora_manager is not None
    if not lora_ready:
        return {
            "success": False,
            "error": "LoRA ๊ธฐ๋ฅ์ด ์ฌ์ฉ ๋ถ๊ฐ๋ฅํฉ๋๋ค. PEFT ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์ง ์์์ต๋๋ค."
        }
    try:
        if not lora_manager.apply_lora_to_model(adapter_name):
            return {"success": False, "error": "LoRA ์ด๋ํฐ ์ ์ฉ์ ์คํจํ์ต๋๋ค."}
        return {
            "success": True,
            "message": "LoRA ์ด๋ํฐ๊ฐ ์ ์ฉ๋์์ต๋๋ค.",
            "adapter_name": adapter_name,
            "stats": lora_manager.get_adapter_stats(),
        }
    except Exception as apply_error:
        logger.error(f"โ LoRA ์ด๋ํฐ ์ ์ฉ ์คํจ: {apply_error}")
        return {"success": False, "error": str(apply_error)}
async def load_lora_adapter(
    adapter_path: str = Form(...),
    adapter_name: str = Form(None)
):
    """Load a saved LoRA adapter from ``adapter_path``.

    Returns ``{"success": True, "message", "adapter_name", "stats"}`` where
    ``adapter_name`` is ``lora_manager.current_adapter_name`` after loading, or
    ``{"success": False, "error": ...}`` on failure or when LoRA support is
    unavailable.
    """
    lora_ready = LORA_AVAILABLE and lora_manager is not None
    if not lora_ready:
        return {
            "success": False,
            "error": "LoRA ๊ธฐ๋ฅ์ด ์ฌ์ฉ ๋ถ๊ฐ๋ฅํฉ๋๋ค. PEFT ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์ง ์์์ต๋๋ค."
        }
    try:
        if not lora_manager.load_lora_adapter(adapter_path, adapter_name):
            return {"success": False, "error": "LoRA ์ด๋ํฐ ๋ก๋์ ์คํจํ์ต๋๋ค."}
        return {
            "success": True,
            "message": "LoRA ์ด๋ํฐ๊ฐ ๋ก๋๋์์ต๋๋ค.",
            "adapter_name": lora_manager.current_adapter_name,
            "stats": lora_manager.get_adapter_stats(),
        }
    except Exception as load_error:
        logger.error(f"โ LoRA ์ด๋ํฐ ๋ก๋ ์คํจ: {load_error}")
        return {"success": False, "error": str(load_error)}
async def save_lora_adapter(
    adapter_name: str = Form(None),
    output_dir: str = Form(None)
):
    """Persist a LoRA adapter via ``lora_manager.save_lora_adapter``.

    Both arguments may be ``None`` and are passed through unchanged. Returns
    ``{"success": True, "message", "adapter_name"}`` (the manager's current adapter
    name) or ``{"success": False, "error": ...}`` on failure or when LoRA support
    is unavailable.
    """
    lora_ready = LORA_AVAILABLE and lora_manager is not None
    if not lora_ready:
        return {
            "success": False,
            "error": "LoRA ๊ธฐ๋ฅ์ด ์ฌ์ฉ ๋ถ๊ฐ๋ฅํฉ๋๋ค. PEFT ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์ง ์์์ต๋๋ค."
        }
    try:
        if not lora_manager.save_lora_adapter(adapter_name, output_dir):
            return {"success": False, "error": "LoRA ์ด๋ํฐ ์ ์ฅ์ ์คํจํ์ต๋๋ค."}
        return {
            "success": True,
            "message": "LoRA ์ด๋ํฐ๊ฐ ์ ์ฅ๋์์ต๋๋ค.",
            "adapter_name": lora_manager.current_adapter_name,
        }
    except Exception as save_error:
        logger.error(f"โ LoRA ์ด๋ํฐ ์ ์ฅ ์คํจ: {save_error}")
        return {"success": False, "error": str(save_error)}
async def list_lora_adapters():
    """List the LoRA adapters reported by ``lora_manager.list_available_adapters``.

    Returns ``{"success": True, "adapters": ...}`` or
    ``{"success": False, "error": ...}`` on failure or when LoRA support is
    unavailable.
    """
    lora_ready = LORA_AVAILABLE and lora_manager is not None
    if not lora_ready:
        return {
            "success": False,
            "error": "LoRA ๊ธฐ๋ฅ์ด ์ฌ์ฉ ๋ถ๊ฐ๋ฅํฉ๋๋ค. PEFT ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์ง ์์์ต๋๋ค."
        }
    try:
        available = lora_manager.list_available_adapters()
        return {"success": True, "adapters": available}
    except Exception as list_error:
        logger.error(f"โ LoRA ์ด๋ํฐ ๋ชฉ๋ก ์กฐํ ์คํจ: {list_error}")
        return {"success": False, "error": str(list_error)}
async def get_lora_stats():
    """Return ``lora_manager.get_adapter_stats()`` for the active adapter.

    Returns ``{"success": True, "stats": ...}`` or
    ``{"success": False, "error": ...}`` on failure or when LoRA support is
    unavailable.
    """
    lora_ready = LORA_AVAILABLE and lora_manager is not None
    if not lora_ready:
        return {
            "success": False,
            "error": "LoRA ๊ธฐ๋ฅ์ด ์ฌ์ฉ ๋ถ๊ฐ๋ฅํฉ๋๋ค. PEFT ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์ง ์์์ต๋๋ค."
        }
    try:
        adapter_stats = lora_manager.get_adapter_stats()
        return {"success": True, "stats": adapter_stats}
    except Exception as stats_error:
        logger.error(f"โ LoRA ํต๊ณ ์กฐํ ์คํจ: {stats_error}")
        return {"success": False, "error": str(stats_error)}
async def switch_lora_adapter(adapter_name: str = Form(...)):
    """Make ``adapter_name`` the active LoRA adapter.

    Returns ``{"success": True, "message", "adapter_name", "stats"}`` when
    ``lora_manager.switch_adapter`` succeeds, otherwise
    ``{"success": False, "error": ...}``, including when LoRA support is
    unavailable.
    """
    lora_ready = LORA_AVAILABLE and lora_manager is not None
    if not lora_ready:
        return {
            "success": False,
            "error": "LoRA ๊ธฐ๋ฅ์ด ์ฌ์ฉ ๋ถ๊ฐ๋ฅํฉ๋๋ค. PEFT ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์ง ์์์ต๋๋ค."
        }
    try:
        if not lora_manager.switch_adapter(adapter_name):
            return {"success": False, "error": "LoRA ์ด๋ํฐ ์ ํ์ ์คํจํ์ต๋๋ค."}
        return {
            "success": True,
            "message": f"LoRA ์ด๋ํฐ๊ฐ {adapter_name}์ผ๋ก ์ ํ๋์์ต๋๋ค.",
            "adapter_name": adapter_name,
            "stats": lora_manager.get_adapter_stats(),
        }
    except Exception as switch_error:
        logger.error(f"โ LoRA ์ด๋ํฐ ์ ํ ์คํจ: {switch_error}")
        return {"success": False, "error": str(switch_error)}
async def unload_lora_adapter():
    """Detach the active LoRA adapter via ``lora_manager.unload_adapter``.

    Returns ``{"success": True, "message": ...}`` on success, otherwise
    ``{"success": False, "error": ...}``, including when LoRA support is
    unavailable.
    """
    lora_ready = LORA_AVAILABLE and lora_manager is not None
    if not lora_ready:
        return {
            "success": False,
            "error": "LoRA ๊ธฐ๋ฅ์ด ์ฌ์ฉ ๋ถ๊ฐ๋ฅํฉ๋๋ค. PEFT ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์ง ์์์ต๋๋ค."
        }
    try:
        if not lora_manager.unload_adapter():
            return {"success": False, "error": "LoRA ์ด๋ํฐ ์ธ๋ก๋์ ์คํจํ์ต๋๋ค."}
        return {"success": True, "message": "LoRA ์ด๋ํฐ๊ฐ ์ธ๋ก๋๋์์ต๋๋ค."}
    except Exception as unload_error:
        logger.error(f"โ LoRA ์ด๋ํฐ ์ธ๋ก๋ ์คํจ: {unload_error}")
        return {"success": False, "error": str(unload_error)}
async def generate_with_lora(
    prompt: str = Form(...),
    max_length: int = Form(100),
    temperature: float = Form(0.7)
):
    """Generate text with the LoRA-adapted model.

    Args:
        prompt: Input text.
        max_length: Forwarded to ``lora_manager.generate_text`` (default 100).
        temperature: Sampling temperature forwarded unchanged (default 0.7).

    Returns:
        ``{"success": True, "response", "adapter_name"}`` or
        ``{"success": False, "error": ...}`` on failure or when LoRA support is
        unavailable.
    """
    lora_ready = LORA_AVAILABLE and lora_manager is not None
    if not lora_ready:
        return {
            "success": False,
            "error": "LoRA ๊ธฐ๋ฅ์ด ์ฌ์ฉ ๋ถ๊ฐ๋ฅํฉ๋๋ค. PEFT ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์ง ์์์ต๋๋ค."
        }
    try:
        generated = lora_manager.generate_text(prompt, max_length, temperature)
        return {
            "success": True,
            "response": generated,
            "adapter_name": lora_manager.current_adapter_name,
        }
    except Exception as generation_error:
        logger.error(f"โ LoRA ํ ์คํธ ์์ฑ ์คํจ: {generation_error}")
        return {"success": False, "error": str(generation_error)}
async def merge_lora_with_base(output_path: str = Form(None)):
    """Merge the active LoRA adapter into the base model weights.

    Returns ``{"success": True, "message", "output_path"}`` on success. When
    ``output_path`` is empty, the reported path is
    ``f"{lora_manager.base_model_path}_merged"``. Otherwise returns
    ``{"success": False, "error": ...}``, including when LoRA support is
    unavailable.
    """
    lora_ready = LORA_AVAILABLE and lora_manager is not None
    if not lora_ready:
        return {
            "success": False,
            "error": "LoRA ๊ธฐ๋ฅ์ด ์ฌ์ฉ ๋ถ๊ฐ๋ฅํฉ๋๋ค. PEFT ๋ผ์ด๋ธ๋ฌ๋ฆฌ๊ฐ ์ค์น๋์ง ์์์ต๋๋ค."
        }
    try:
        if not lora_manager.merge_lora_with_base(output_path):
            return {"success": False, "error": "LoRA ์ด๋ํฐ ๋ณํฉ์ ์คํจํ์ต๋๋ค."}
        return {
            "success": True,
            "message": "LoRA ์ด๋ํฐ๊ฐ ๊ธฐ๋ณธ ๋ชจ๋ธ๊ณผ ๋ณํฉ๋์์ต๋๋ค.",
            "output_path": output_path or f"{lora_manager.base_model_path}_merged",
        }
    except Exception as merge_error:
        logger.error(f"โ LoRA ์ด๋ํฐ ๋ณํฉ ์คํจ: {merge_error}")
        return {"success": False, "error": str(merge_error)}
| # ============================================================================ | |
| # ๋ฉํฐ๋ชจ๋ฌ RAG ์์คํ ์๋ํฌ์ธํธ | |
| # ============================================================================ | |
async def upload_hybrid_document(
    file: UploadFile = File(...),
    user_id: str = Form("default_user"),
    document_id: Optional[str] = Form(None)
):
    """Save an uploaded document and process it with every multimodal RAG subsystem.

    The file is stored under ``uploads/hybrid_rag`` and handed to
    ``hybrid_rag_processor.process_document_hybrid``. Every ``*_processing`` entry of
    the result that is a dict with a truthy ``success`` counts as a subsystem that
    handled the document.

    Args:
        file: The uploaded document.
        user_id: Owner of the document.
        document_id: Optional id; defaults to ``"{user_id}_{unix_time}_{filename}"``.

    Returns:
        DocumentUploadResponse. On success ``chunks`` is the number of subsystems
        that succeeded. Failures are reported with ``success=False``.
    """
    try:
        upload_dir = Path("uploads/hybrid_rag")
        upload_dir.mkdir(parents=True, exist_ok=True)

        if not document_id:
            document_id = f"{user_id}_{int(time.time())}_{file.filename}"

        # document_id (and the user_id/filename embedded in the default id) come
        # from the client; keep only the last path component so the write stays
        # inside upload_dir.
        file_path = upload_dir / Path(document_id).name
        content = await file.read()
        with open(file_path, "wb") as buffer:
            buffer.write(content)

        result = hybrid_rag_processor.process_document_hybrid(str(file_path), user_id, document_id)

        if not result.get("success"):
            error_text = result.get("error", "๋ฉํฐ๋ชจ๋ฌ ์ฒ๋ฆฌ ์คํจ")
            # document_id/message are supplied as at every other
            # DocumentUploadResponse call site in this module.
            return DocumentUploadResponse(
                success=False,
                document_id=document_id,
                message=error_text,
                error=error_text
            )

        # Collect display names of the subsystems that reported success; non-dict
        # values are skipped instead of crashing on ``.get``.
        success_systems = [
            key.replace('_processing', '').replace('_', ' ').title()
            for key, value in result.items()
            if key.endswith('_processing')
            and isinstance(value, dict)
            and value.get('success', False)
        ]
        return DocumentUploadResponse(
            success=True,
            document_id=document_id,
            message=f"๋ฉํฐ๋ชจ๋ฌ ์ฒ๋ฆฌ ์๋ฃ: {', '.join(success_systems)} ์์คํ ์์ ์ฒ๋ฆฌ๋จ",
            chunks=len(success_systems)
        )
    except Exception as e:
        logger.error(f"๋ฉํฐ๋ชจ๋ฌ RAG ๋ฌธ์ ์ ๋ก๋ ์ค๋ฅ: {e}")
        error_text = f"์ ๋ก๋ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค: {str(e)}"
        return DocumentUploadResponse(
            success=False,
            document_id=document_id or "",
            message=error_text,
            error=error_text
        )
async def generate_hybrid_rag_response(
    query: str = Form(...),
    user_id: str = Form("default_user"),
    document_id: str = Form(...),
    use_text: bool = Form(True),
    use_image: bool = Form(True),
    use_latex: bool = Form(True),
    use_latex_ocr: bool = Form(False),  # LaTeX-OCR feature is disabled
    max_length: Optional[int] = Form(None),
    temperature: Optional[float] = Form(None),
    top_p: Optional[float] = Form(None),
    do_sample: Optional[bool] = Form(None)
):
    """Answer ``query`` with the multimodal (hybrid) RAG processor.

    The ``use_*`` flags select which retrieval sources are consulted. The
    generation parameters are forwarded unchanged (``None`` leaves the choice to
    the processor).

    Returns:
        RAGResponse built from the processor result. If that result lacks some
        fields (e.g. a bare ``{"success": False, "error": ...}``), neutral defaults
        are used and the error text becomes the response instead of raising
        KeyError. If anything raises, ``success=False`` with the error text.
    """
    try:
        result = hybrid_rag_processor.generate_hybrid_response(
            query, user_id, document_id,
            use_text, use_image, use_latex, use_latex_ocr,
            max_length, temperature, top_p, do_sample
        )
        return RAGResponse(
            success=result.get("success", False),
            response=result.get("response", result.get("error", "")),
            context=result.get("context", ""),
            sources=result.get("sources", []),
            search_results=result.get("search_results", 0),
            processing_time=result.get("processing_time", 0.0)
        )
    except Exception as e:
        logger.error(f"๋ฉํฐ๋ชจ๋ฌ RAG ์๋ต ์์ฑ ์ค๋ฅ: {e}")
        return RAGResponse(
            success=False,
            response=f"๋ฉํฐ๋ชจ๋ฌ RAG ์๋ต ์์ฑ ์ค ์ค๋ฅ๊ฐ ๋ฐ์ํ์ต๋๋ค: {str(e)}",
            context="",
            sources=[],
            search_results=0,
            processing_time=0.0
        )
async def get_hybrid_document_info(user_id: str, document_id: str):
    """Return stored information about a multimodal RAG document.

    Delegates to ``hybrid_rag_processor.get_document_info``. Exceptions are logged
    and converted to ``{"success": False, "error": <message>}``.
    """
    try:
        info = hybrid_rag_processor.get_document_info(user_id, document_id)
    except Exception as lookup_error:
        logger.error(f"๋ฉํฐ๋ชจ๋ฌ RAG ๋ฌธ์ ์ ๋ณด ์กฐํ ์ค๋ฅ: {lookup_error}")
        return {"success": False, "error": str(lookup_error)}
    return info
async def get_hybrid_rag_status():
    """Report which multimodal RAG subsystems are available.

    The flags are static: text, image and LaTeX RAG are reported available and the
    LaTeX-OCR + FAISS path is reported unavailable because that feature is disabled.
    """
    availability = {
        "text_rag_available": True,
        "image_rag_available": True,
        "latex_rag_available": True,
        "latex_ocr_faiss_available": False,
        "status": "ready",
    }
    return availability
| # ============================================================================ | |
| # ๐ RAG ์์คํ ๊ณผ ๊ณ ๊ธ ์ปจํ ์คํธ ๊ด๋ฆฌ์ ํตํฉ API | |
| # ============================================================================ | |
async def rag_query_with_context_integration(
    user_id: str = Form(...),
    document_id: str = Form(...),
    query: str = Form(...),
    session_id: str = Form(...),
    max_results: int = Form(5),
    enable_context_integration: bool = Form(True)
):
    """Run a RAG query and optionally record a summary of it in the session context.

    Args:
        user_id: Owner of the indexed document.
        document_id: Document to query.
        query: The user question.
        session_id: Context-manager session the result is associated with.
        max_results: Accepted for API compatibility; not forwarded by this code.
        enable_context_integration: When true, ``session_id`` and
            ``context_manager`` are passed to ``rag_processor.generate_rag_response``
            and a system message summarising the search is appended to the context.
            Failure to append that message is only logged as a warning.

    Returns:
        The raw RAG result when it reports failure. Otherwise
        ``{"status": "success", "rag_response", "context_integration", "session_id",
        "context_summary"}``, where the summary is ``None`` when integration is off.
        Unexpected errors yield ``{"status": "error", "message": ...}``.
    """
    try:
        logger.info(f"๐ RAG + ์ปจํ ์คํธ ํตํฉ ์ฟผ๋ฆฌ ์์: ์ฌ์ฉ์ {user_id}, ๋ฌธ์ {document_id}, ์ธ์ {session_id}")

        if not context_manager:
            return {"status": "error", "message": "์ปจํ ์คํธ ๊ด๋ฆฌ์๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค."}

        integrate = enable_context_integration
        rag_result = rag_processor.generate_rag_response(
            user_id=user_id,
            document_id=document_id,
            query=query,
            session_id=session_id if integrate else None,
            context_manager=context_manager if integrate else None,
        )
        if not rag_result["success"]:
            return rag_result

        if integrate:
            # Best effort: a failure here is logged but does not fail the query.
            try:
                rag_summary = f"RAG ๊ฒ์ ๊ฒฐ๊ณผ: {query}์ ๋ํ {rag_result.get('search_results', 0)}๊ฐ ๊ด๋ จ ๋ฌธ์ ๋ฐ๊ฒฌ"
                context_manager.add_system_message(
                    rag_summary,
                    metadata={"session_id": session_id, "type": "rag_integration", "query": query},
                )
                logger.info(f"๐ RAG ๊ฒฐ๊ณผ๋ฅผ ์ปจํ ์คํธ์ ํตํฉ ์๋ฃ (์ธ์ : {session_id})")
            except Exception as merge_error:
                logger.warning(f"โ ๏ธ ์ปจํ ์คํธ ํตํฉ ์คํจ: {merge_error}")

        session_summary = context_manager.get_context_summary(session_id) if integrate else None
        payload = {
            "status": "success",
            "rag_response": rag_result,
            "context_integration": enable_context_integration,
            "session_id": session_id,
            "context_summary": session_summary,
        }
        logger.info("โ RAG + ์ปจํ ์คํธ ํตํฉ ์ฟผ๋ฆฌ ์๋ฃ")
        return payload
    except Exception as query_error:
        logger.error(f"โ RAG + ์ปจํ ์คํธ ํตํฉ ์ฟผ๋ฆฌ ์คํจ: {query_error}")
        return {"status": "error", "message": str(query_error)}
async def get_rag_context_summary(session_id: str):
    """Return a session's context summary together with its RAG-integration turns.

    A turn is reported when it has a truthy ``metadata`` attribute whose
    ``type`` is ``"rag_integration"``. Each reported entry carries that turn's
    query, content and timestamp.

    Args:
        session_id: Session whose context is inspected.

    Returns:
        ``{"status": "success", ...}`` with ``context_summary``, ``rag_contexts``
        and ``rag_context_count``. Returns ``{"status": "error", "message": ...}``
        when the context manager is unavailable or any call raises.
    """
    try:
        if not context_manager:
            return {"status": "error", "message": "์ปจํ ์คํธ ๊ด๋ฆฌ์๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค."}

        summary = context_manager.get_context_summary(session_id)

        conversations = context_manager.session_conversations
        session_turns = conversations[session_id] if session_id in conversations else []
        rag_entries = [
            {
                "query": turn.metadata.get('query', ''),
                "content": turn.content,
                "timestamp": turn.timestamp,
            }
            for turn in session_turns
            if hasattr(turn, 'metadata') and turn.metadata
            and turn.metadata.get('type') == 'rag_integration'
        ]

        return {
            "status": "success",
            "session_id": session_id,
            "context_summary": summary,
            "rag_contexts": rag_entries,
            "rag_context_count": len(rag_entries),
        }
    except Exception as e:
        logger.error(f"โ RAG ์ปจํ ์คํธ ์์ฝ ์กฐํ ์คํจ: {e}")
        return {"status": "error", "message": str(e)}
async def clear_rag_context(session_id: str):
    """Remove every RAG-integration turn from a session's conversation history.

    Matching turns have a truthy ``metadata`` with ``type == "rag_integration"``.
    They are collected into a separate list first and then removed one at a
    time through ``context_manager.remove_message(message_id, session_id)``,
    so the history is never modified while it is being scanned.

    Args:
        session_id: Session to clean up.

    Returns:
        ``{"status": "success", ...}`` with ``removed_rag_turns`` (0 when the
        session is unknown), or ``{"status": "error", "message": ...}``.
    """
    try:
        if not context_manager:
            return {"status": "error", "message": "์ปจํ ์คํธ ๊ด๋ฆฌ์๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค."}

        if session_id not in context_manager.session_conversations:
            return {
                "status": "success",
                "session_id": session_id,
                "removed_rag_turns": 0,
                "message": "์ ๊ฑฐํ RAG ์ปจํ ์คํธ๊ฐ ์์ต๋๋ค."
            }

        history = context_manager.session_conversations[session_id]
        targets = [
            turn for turn in history
            if hasattr(turn, 'metadata') and turn.metadata
            and turn.metadata.get('type') == 'rag_integration'
        ]
        for turn in targets:
            context_manager.remove_message(turn.message_id, session_id)

        removed = len(targets)
        logger.info(f"๐๏ธ RAG ์ปจํ ์คํธ ์ ๋ฆฌ ์๋ฃ: {removed}๊ฐ ํด ์ ๊ฑฐ (์ธ์ : {session_id})")
        return {
            "status": "success",
            "session_id": session_id,
            "removed_rag_turns": removed,
            "message": f"RAG ์ปจํ ์คํธ {removed}๊ฐ ํด์ด ์ ๊ฑฐ๋์์ต๋๋ค."
        }
    except Exception as e:
        logger.error(f"โ RAG ์ปจํ ์คํธ ์ ๋ฆฌ ์คํจ: {e}")
        return {"status": "error", "message": str(e)}
async def get_rag_performance_stats():
    """Combine the RAG processor's and the vector store's performance statistics.

    The ``overall`` section sums the two operation counters and takes the
    plain (unweighted) mean of the two success rates and the two average
    timings. Missing keys count as 0.

    Returns:
        ``{"status": "success", "performance_stats": {...}}`` or
        ``{"status": "error", "message": ...}`` if either stats call raises.
    """
    try:
        rag_stats = rag_processor.get_performance_stats()
        vector_stats = vector_store_manager.get_performance_stats()

        total_ops = rag_stats.get("total_requests", 0) + vector_stats.get("total_operations", 0)
        mean_success = (rag_stats.get("success_rate", 0.0) + vector_stats.get("success_rate", 0.0)) / 2
        mean_time = (rag_stats.get("avg_processing_time", 0.0) + vector_stats.get("avg_operation_time", 0.0)) / 2

        return {
            "status": "success",
            "performance_stats": {
                "rag_processor": rag_stats,
                "vector_store": vector_stats,
                "overall": {
                    "total_operations": total_ops,
                    "success_rate": mean_success,
                    "avg_processing_time": mean_time,
                },
                "timestamp": time.time(),
            },
        }
    except Exception as e:
        logger.error(f"โ RAG ์ฑ๋ฅ ํต๊ณ ์กฐํ ์คํจ: {e}")
        return {"status": "error", "message": str(e)}
async def reset_rag_performance_stats():
    """Reset the performance counters of the RAG processor and the vector store.

    The RAG processor is reset first, then the vector store manager. If either
    call raises, the error is logged and returned as
    ``{"status": "error", "message": ...}``.
    """
    try:
        rag_processor.reset_stats()
        vector_store_manager.reset_stats()
        logger.info("๐ RAG ์์คํ ์ฑ๋ฅ ํต๊ณ ์ด๊ธฐํ ์๋ฃ")
    except Exception as e:
        logger.error(f"โ RAG ์ฑ๋ฅ ํต๊ณ ์ด๊ธฐํ ์คํจ: {e}")
        return {"status": "error", "message": str(e)}
    else:
        return {
            "status": "success",
            "message": "RAG ์์คํ ์ฑ๋ฅ ํต๊ณ๊ฐ ์ด๊ธฐํ๋์์ต๋๋ค."
        }
async def rag_health_check():
    """Report the health of the RAG processor, vector store and document processor.

    ``overall_status`` is ``"healthy"`` only when the vector store's
    ``health_check()`` result has ``status == "healthy"``. Otherwise it is
    ``"degraded"``. Any exception produces ``overall_status="unhealthy"``
    together with the error text.
    """
    try:
        processor_info = {
            "rag_processor": "healthy",
            "enable_context_integration": rag_processor.enable_context_integration,
            "max_context_length": rag_processor.max_context_length,
            "max_search_results": rag_processor.max_search_results
        }

        store_info = vector_store_manager.health_check()

        # Both attributes are optional on the document processor; fall back
        # to an empty format list and "no OCR" when they are absent.
        doc_info = {
            "status": "healthy",
            "supported_formats": getattr(document_processor, 'supported_formats', []),
            "ocr_available": getattr(document_processor, 'ocr_reader', None) is not None
        }

        overall = "healthy" if store_info.get("status") == "healthy" else "degraded"

        return {
            "status": "success",
            "overall_status": overall,
            "rag_processor": processor_info,
            "vector_store": store_info,
            "document_processor": doc_info,
            "timestamp": time.time()
        }
    except Exception as e:
        logger.error(f"โ RAG ์์คํ ๊ฑด๊ฐ ์ํ ํ์ธ ์คํจ: {e}")
        return {
            "status": "error",
            "overall_status": "unhealthy",
            "error": str(e),
            "timestamp": time.time()
        }
async def batch_process_with_context_integration(
    user_id: str = Form(...),
    session_id: str = Form(...),
    documents: List[UploadFile] = File(...),
    enable_context_integration: bool = Form(True)
):
    """Store several uploaded documents through the RAG processor.

    Each upload is written to a private temporary file, passed to
    ``rag_processor.process_and_store_document``, and then deleted. A failure
    on one document is recorded in its result entry and does not stop the
    batch. When ``enable_context_integration`` is true, each successful
    document also produces a system message in the context manager.

    Args:
        user_id: Owner of the documents (client supplied).
        session_id: Session the documents belong to (client supplied).
        documents: Uploaded files.
        enable_context_integration: Whether to add per-document system messages.

    Returns:
        ``{"status": "success", ...}`` with per-document ``results`` and
        ``success_count`` / ``error_count``, or
        ``{"status": "error", "message": ...}`` if the batch itself fails.
    """
    import shutil
    import tempfile

    try:
        logger.info(f"๐ ๋ฐฐ์น ๋ฌธ์ ์ฒ๋ฆฌ + ์ปจํ ์คํธ ํตํฉ ์์: ์ฌ์ฉ์ {user_id}, ์ธ์ {session_id}, ๋ฌธ์ {len(documents)}๊ฐ")
        results = []

        for i, doc in enumerate(documents):
            temp_path = None
            try:
                # tempfile picks the file name, so client-supplied user_id /
                # session_id can no longer steer the write outside the temp
                # directory (e.g. "../x"). The upload's extension is kept in case
                # the document processor dispatches on it.
                suffix = Path(doc.filename or "").suffix
                fd, temp_path = tempfile.mkstemp(prefix="lily_batch_", suffix=suffix)
                with os.fdopen(fd, "wb") as f:
                    # Stream the upload instead of reading it all into memory.
                    shutil.copyfileobj(doc.file, f)

                # NOTE(review): second-resolution timestamps can collide for two
                # requests on the same session within one second.
                document_id = f"batch_{session_id}_{i}_{int(time.time())}"

                rag_result = rag_processor.process_and_store_document(
                    user_id=user_id,
                    document_id=document_id,
                    file_path=temp_path
                )

                # .get() so a result without "success" is treated as a failure.
                if enable_context_integration and rag_result.get("success"):
                    try:
                        context_manager.add_system_message(
                            f"๋ฐฐ์น ๋ฌธ์ ์ฒ๋ฆฌ ์๋ฃ: {doc.filename} ({rag_result.get('chunks', 0)}๊ฐ ์ฒญํฌ)",
                            metadata={"session_id": session_id, "type": "batch_rag", "filename": doc.filename}
                        )
                    except Exception as e:
                        # Best effort: the document is already stored.
                        logger.warning(f"โ ๏ธ ์ปจํ ์คํธ ํตํฉ ์คํจ: {e}")

                results.append({
                    "filename": doc.filename,
                    "document_id": document_id,
                    "rag_result": rag_result,
                    "context_integration": enable_context_integration
                })
            except Exception as e:
                logger.error(f"โ ๋ฌธ์ {doc.filename} ์ฒ๋ฆฌ ์คํจ: {e}")
                results.append({
                    "filename": doc.filename,
                    "error": str(e),
                    "context_integration": enable_context_integration
                })
            finally:
                # Always remove the temp file, including when processing raised
                # (previously it leaked in that case).
                if temp_path is not None:
                    try:
                        os.remove(temp_path)
                    except OSError as cleanup_error:
                        logger.warning("Failed to remove temp file %s: %s", temp_path, cleanup_error)

        success_count = sum(1 for r in results if r.get("rag_result", {}).get("success", False))
        error_count = len(results) - success_count
        logger.info(f"โ ๋ฐฐ์น ๋ฌธ์ ์ฒ๋ฆฌ ์๋ฃ: {success_count}๊ฐ ์ฑ๊ณต, {error_count}๊ฐ ์คํจ")

        return {
            "status": "success",
            "user_id": user_id,
            "session_id": session_id,
            "total_documents": len(documents),
            "success_count": success_count,
            "error_count": error_count,
            "results": results,
            "context_integration": enable_context_integration
        }
    except Exception as e:
        logger.error(f"โ ๋ฐฐ์น ๋ฌธ์ ์ฒ๋ฆฌ + ์ปจํ ์คํธ ํตํฉ ์คํจ: {e}")
        return {"status": "error", "message": str(e)}
async def get_rag_search_history(session_id: str, limit: int = 10):
    """List a session's RAG-related turns, newest first, truncated to ``limit``.

    Turns qualify when their truthy ``metadata`` has a ``type`` of
    ``rag_integration``, ``rag_context`` or ``batch_rag``. Sorting is by
    ``timestamp`` in descending order, and the sort is stable for equal
    timestamps.

    Args:
        session_id: Session to inspect.
        limit: Slice bound applied to the sorted history (``[:limit]``).

    Returns:
        ``{"status": "success", ...}`` with ``search_history``, ``total_count``
        and ``limited_count``, or ``{"status": "error", "message": ...}``.
    """
    try:
        if not context_manager:
            return {"status": "error", "message": "์ปจํ ์คํธ ๊ด๋ฆฌ์๋ฅผ ์ฌ์ฉํ ์ ์์ต๋๋ค."}

        tracked_types = ('rag_integration', 'rag_context', 'batch_rag')
        conversations = context_manager.session_conversations
        entries = []
        if session_id in conversations:
            entries = [
                {
                    "timestamp": turn.timestamp,
                    "type": turn.metadata.get('type'),
                    "query": turn.metadata.get('query', ''),
                    "filename": turn.metadata.get('filename', ''),
                    "content": turn.content
                }
                for turn in conversations[session_id]
                if hasattr(turn, 'metadata') and turn.metadata
                and turn.metadata.get('type') in tracked_types
            ]

        newest_first = sorted(entries, key=lambda entry: entry['timestamp'], reverse=True)
        page = newest_first[:limit]

        return {
            "status": "success",
            "session_id": session_id,
            "search_history": page,
            "total_count": len(newest_first),
            "limited_count": len(page)
        }
    except Exception as e:
        logger.error(f"โ RAG ๊ฒ์ ํ์คํ ๋ฆฌ ์กฐํ ์คํจ: {e}")
        return {"status": "error", "message": str(e)}
# ============================================================================
# Production advanced context manager API endpoints
# ============================================================================
async def get_summary_method():
    """Report the active summarisation method and the names of all available ones.

    Returns:
        ``{"status": "success", "current_method": ..., "available_methods": [...]}``
        or ``{"status": "error", "message": ...}``.
    """
    try:
        manager = context_manager
        if not manager:
            return {"status": "error", "message": "Context manager not available"}
        return {
            "status": "success",
            "current_method": manager.current_summary_method,
            "available_methods": list(manager.summary_models.keys()),
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}
async def set_summary_method(method: str = Form(...)):
    """Switch the context manager to another summarisation method.

    Args:
        method: Name passed straight to ``context_manager.set_summary_method``.

    Returns:
        A success dict that echoes the method now in effect, or
        ``{"status": "error", "message": ...}``.
    """
    try:
        manager = context_manager
        if not manager:
            return {"status": "error", "message": "Context manager not available"}

        manager.set_summary_method(method)
        active = manager.current_summary_method
        return {
            "status": "success",
            "message": f"์์ฝ ๋ฐฉ๋ฒ์ด {method}๋ก ๋ณ๊ฒฝ๋์์ต๋๋ค",
            "current_method": active,
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}
async def get_advanced_summary_stats(session_id: str):
    """Return the context manager's summary statistics for one session.

    Args:
        session_id: Session to query.

    Returns:
        ``{"status": "success", "session_id": ..., "summary_stats": ...}`` or
        ``{"status": "error", "message": ...}``.
    """
    try:
        manager = context_manager
        if not manager:
            return {"status": "error", "message": "Context manager not available"}
        stats = manager.get_summary_stats(session_id)
    except Exception as e:
        return {"status": "error", "message": str(e)}
    return {
        "status": "success",
        "session_id": session_id,
        "summary_stats": stats,
    }
async def get_compressed_context(session_id: str, max_tokens: Optional[int] = None):
    """Return the session's compressed context (including summaries) and its size.

    Args:
        session_id: Session to compress.
        max_tokens: Optional budget forwarded to
            ``context_manager.get_compressed_context``.

    Returns:
        A success dict with the context text, its estimated token count and
        its character length, or ``{"status": "error", "message": ...}``.
    """
    try:
        if not context_manager:
            return {"status": "error", "message": "Context manager not available"}

        text = context_manager.get_compressed_context(session_id, max_tokens)
        token_estimate = context_manager._estimate_tokens(text)
        return {
            "status": "success",
            "session_id": session_id,
            "compressed_context": text,
            "estimated_tokens": token_estimate,
            "context_length": len(text),
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}
async def force_compression(session_id: str):
    """Force a compression pass on a session and report its effect.

    Summary statistics are captured before and after
    ``context_manager.force_compression``. The reported reductions are
    ``before - after`` for ``total_summaries`` and ``total_tokens``, with
    missing keys counted as 0.

    Args:
        session_id: Session to compress.

    Returns:
        A success dict with both snapshots and ``compression_effect``, or
        ``{"status": "error", "message": ...}``.
    """
    try:
        if not context_manager:
            return {"status": "error", "message": "Context manager not available"}

        stats_before = context_manager.get_summary_stats(session_id)
        context_manager.force_compression(session_id)
        stats_after = context_manager.get_summary_stats(session_id)

        def _reduction(key):
            return stats_before.get(key, 0) - stats_after.get(key, 0)

        return {
            "status": "success",
            "message": f"์ธ์ {session_id} ๊ฐ์ ์์ถ ์๋ฃ",
            "session_id": session_id,
            "before_compression": stats_before,
            "after_compression": stats_after,
            "compression_effect": {
                "summary_reduction": _reduction("total_summaries"),
                "token_reduction": _reduction("total_tokens"),
            },
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}
async def get_turn_summaries(session_id: str, limit: int = 10):
    """Return the most recent turn summaries of a session as plain dicts.

    Args:
        session_id: Session to inspect.
        limit: Number of most recent summaries to return. A value of 0 or
            less returns all of them.

    Returns:
        A success dict with ``turn_summaries`` and counts. An unknown session
        yields an empty list. Errors are returned as
        ``{"status": "error", "message": ...}``.
    """
    try:
        if not context_manager:
            return {"status": "error", "message": "Context manager not available"}

        per_session = context_manager.turn_summaries
        if session_id not in per_session:
            return {
                "status": "success",
                "session_id": session_id,
                "turn_summaries": [],
                "total_count": 0
            }

        all_summaries = per_session[session_id]
        selected = all_summaries[-limit:] if limit > 0 else all_summaries

        # Convert TurnSummary objects into JSON-friendly dicts.
        serialized = [
            {
                "turn_id": item.turn_id,
                "user_message": item.user_message,
                "assistant_message": item.assistant_message,
                "summary": item.summary,
                "timestamp": item.timestamp,
                "tokens_estimated": item.tokens_estimated,
                "key_topics": item.key_topics
            }
            for item in selected
        ]

        return {
            "status": "success",
            "session_id": session_id,
            "turn_summaries": serialized,
            "total_count": len(all_summaries),
            "limited_count": len(selected)
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}
async def get_compression_history(session_id: str):
    """Return the recorded compression events of a session.

    An unknown session yields an empty history. The membership check avoids
    inserting a new key into the history mapping.

    Args:
        session_id: Session to inspect.

    Returns:
        ``{"status": "success", "compression_history": [...], "total_compressions": n}``
        or ``{"status": "error", "message": ...}``.
    """
    try:
        if not context_manager:
            return {"status": "error", "message": "Context manager not available"}

        store = context_manager.compression_history
        history = store[session_id] if session_id in store else []
        return {
            "status": "success",
            "session_id": session_id,
            "compression_history": history,
            "total_compressions": len(history)
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}
async def get_optimized_context(session_id: str, model_name: str = "default"):
    """Return the session context tailored to a model, plus summary information.

    Args:
        session_id: Session to read.
        model_name: Model identifier forwarded to
            ``context_manager.get_context_for_model``.

    Returns:
        A success dict with the optimised context, its token and character
        size, the context summary and the summary stats, or
        ``{"status": "error", "message": ...}``.
    """
    try:
        if not context_manager:
            return {"status": "error", "message": "Context manager not available"}

        tailored = context_manager.get_context_for_model(model_name, session_id)
        token_estimate = context_manager._estimate_tokens(tailored)
        overview = context_manager.get_context_summary(session_id)
        stats = context_manager.get_summary_stats(session_id)

        return {
            "status": "success",
            "session_id": session_id,
            "model_name": model_name,
            "optimized_context": tailored,
            "estimated_tokens": token_estimate,
            "context_length": len(tailored),
            "context_summary": overview,
            "summary_stats": stats,
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}
async def export_enhanced_context(session_id: str, file_path: str = Form(None)):
    """Export a session's context (including summary data) to a file.

    NOTE(review): ``file_path`` comes from the client and is passed unchanged
    to ``context_manager.export_context``. Unless that method restricts the
    destination, a caller can choose where the server writes. Consider
    confining exports to a fixed directory.

    Args:
        session_id: Session to export.
        file_path: Optional destination, forwarded unchanged.

    Returns:
        A success dict containing the path reported by ``export_context``, or
        an error dict when that path is falsy or a call raises.
    """
    try:
        if not context_manager:
            return {"status": "error", "message": "Context manager not available"}

        exported_path = context_manager.export_context(file_path, session_id)
        if not exported_path:
            return {"status": "error", "message": "๋ด๋ณด๋ด๊ธฐ ์คํจ"}

        return {
            "status": "success",
            "message": f"์ธ์ {session_id} ํฅ์๋ ์ปจํ ์คํธ ๋ด๋ณด๋ด๊ธฐ ์๋ฃ",
            "file_path": exported_path,
            "session_id": session_id
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}
async def import_enhanced_context(file_path: str = Form(...)):
    """Load a previously exported context file into the context manager.

    NOTE(review): ``file_path`` is client supplied and read unchanged by
    ``context_manager.import_context``, so callers can make the server read
    any file it can access. Also, the returned summary is always that of the
    ``"default"`` session, which may not be the session that was imported.

    Args:
        file_path: Path of the file to import.

    Returns:
        A success dict with the ``"default"`` session's summary, or an error
        dict when the import reports failure or raises.
    """
    try:
        if not context_manager:
            return {"status": "error", "message": "Context manager not available"}

        if not context_manager.import_context(file_path):
            return {"status": "error", "message": "๊ฐ์ ธ์ค๊ธฐ ์คํจ"}

        return {
            "status": "success",
            "message": "ํฅ์๋ ์ปจํ ์คํธ ๊ฐ์ ธ์ค๊ธฐ ์๋ฃ",
            "file_path": file_path,
            "context_summary": context_manager.get_context_summary("default")
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}
async def advanced_context_health_check():
    """Report configuration and per-session state of the advanced context manager.

    Returns:
        ``{"status": "success", ...}`` with ``basic_status``,
        ``summary_status``, ``cleanup_status`` (from
        ``get_auto_cleanup_config()``), per-session ``session_details`` and a
        timestamp. Returns ``{"status": "error", "message": ...}`` if the
        manager is unavailable or any call raises.
    """
    try:
        if not context_manager:
            return {"status": "error", "message": "Context manager not available"}

        basic_status = {
            "context_manager_available": True,
            "total_sessions": len(context_manager.session_conversations),
            "max_tokens": context_manager.max_tokens,
            "max_turns": context_manager.max_turns,
            "strategy": context_manager.strategy
        }

        summary_status = {
            "summarization_enabled": context_manager.enable_summarization,
            "current_summary_method": context_manager.current_summary_method,
            "available_summary_methods": list(context_manager.summary_models.keys()),
            "summary_threshold": context_manager.summary_threshold,
            "max_summary_tokens": context_manager.max_summary_tokens
        }

        cleanup_status = context_manager.get_auto_cleanup_config()

        # Iterate over a snapshot of (session_id, turns). The per-session calls
        # below go back into the context manager. If anything adds or drops a
        # session meanwhile (e.g. an auto-cleanup job, not verified here),
        # iterating the live dict would raise "dictionary changed size during
        # iteration".
        session_details = {}
        for session_id, turns in list(context_manager.session_conversations.items()):
            session_details[session_id] = {
                "turns": len(turns),
                "turn_summaries": len(context_manager.turn_summaries.get(session_id, [])),
                "compression_history": len(context_manager.compression_history.get(session_id, [])),
                "context_summary": context_manager.get_context_summary(session_id),
                "summary_stats": context_manager.get_summary_stats(session_id)
            }

        return {
            "status": "success",
            "basic_status": basic_status,
            "summary_status": summary_status,
            "cleanup_status": cleanup_status,
            "session_details": session_details,
            "timestamp": time.time()
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}
# ============================================================================
# User memory settings management API
# ============================================================================
async def get_user_memory_settings(user_id: str):
    """Return a user's memory settings.

    An unset ``keep_memory_on_room_change`` (``None``) is reported as ``True``.

    Args:
        user_id: User whose settings are read.

    Returns:
        ``{"status": "success", "user_id": ..., "settings": {...}}`` or
        ``{"status": "error", "message": ...}``.
    """
    try:
        from lily_llm_core.user_memory_manager import user_memory_manager

        stored = user_memory_manager.get_memory_setting(user_id, "keep_memory_on_room_change")
        keep_memory = True if stored is None else stored
        return {
            "status": "success",
            "user_id": user_id,
            "settings": {"keep_memory_on_room_change": keep_memory},
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}
async def update_user_memory_settings(
    user_id: str,
    keep_memory_on_room_change: bool = Form(True)
):
    """Persist a user's ``keep_memory_on_room_change`` preference.

    Args:
        user_id: User whose setting is updated.
        keep_memory_on_room_change: New value to store.

    Returns:
        A success dict echoing the stored setting, or an error dict when the
        update reports failure or raises.
    """
    try:
        from lily_llm_core.user_memory_manager import user_memory_manager

        updated = user_memory_manager.update_memory_setting(
            user_id, "keep_memory_on_room_change", keep_memory_on_room_change
        )
        if not updated:
            return {"status": "error", "message": "์ค์ ์ ๋ฐ์ดํธ ์คํจ"}

        return {
            "status": "success",
            "message": f"์ฌ์ฉ์ {user_id} ๋ฉ๋ชจ๋ฆฌ ์ค์ ์ ๋ฐ์ดํธ ์๋ฃ",
            "settings": {"keep_memory_on_room_change": keep_memory_on_room_change},
        }
    except Exception as e:
        return {"status": "error", "message": str(e)}
async def handle_room_change(user_id: str, new_room_id: str = Form(...)):
    """Apply the user's memory policy when they move to another room.

    If ``keep_memory_on_room_change`` is enabled, or has never been set, nothing
    is cleared. An unset value is treated as ``True``, which matches
    ``get_user_memory_settings``. Otherwise:

    * every context-manager session whose id contains ``f"user_{user_id}"`` is
      cleared, and
    * documents uploaded by this user are removed from the room context of
      ``new_room_id``. This step is best effort: failures are logged as
      warnings.

    Args:
        user_id: User changing rooms.
        new_room_id: Room being entered.

    Returns:
        A success dict with ``memory_preserved`` (and ``context_cleared`` when
        memory was reset), or ``{"status": "error", "message": ...}``.
    """
    try:
        from lily_llm_core.user_memory_manager import user_memory_manager
        from lily_llm_core.integrated_memory_manager import integrated_memory_manager

        keep_memory = user_memory_manager.get_memory_setting(user_id, "keep_memory_on_room_change")
        # Unset means "keep memory". Previously None was falsy and fell through
        # to the wipe path, contradicting the default reported elsewhere.
        if keep_memory is None:
            keep_memory = True

        if keep_memory:
            logger.info(f"๐ ์ฌ์ฉ์ {user_id}๊ฐ room {new_room_id}๋ก ์ด๋ - ๋ฉ๋ชจ๋ฆฌ ์ ์ง")
            return {
                "status": "success",
                "message": f"Room {new_room_id}๋ก ์ด๋ - ๋ฉ๋ชจ๋ฆฌ ์ ์ง๋จ",
                "memory_preserved": True
            }

        logger.info(f"๐ ์ฌ์ฉ์ {user_id}๊ฐ room {new_room_id}๋ก ์ด๋ - ๋ฉ๋ชจ๋ฆฌ ์ด๊ธฐํ")

        if context_manager:
            # NOTE(review): substring match, so "user_1" also matches sessions
            # of "user_12". Tighten once the session-id format is confirmed.
            user_sessions = [
                session_id for session_id in context_manager.session_conversations.keys()
                if f"user_{user_id}" in session_id
            ]
            for session_id in user_sessions:
                context_manager.clear_session_context(session_id)
                logger.info(f"๐๏ธ ์ธ์ ์ปจํ ์คํธ ์ด๊ธฐํ: {session_id}")

        def _uploader_of(doc):
            """Return the uploader recorded on a dict- or attribute-style document, else None."""
            if isinstance(doc, dict):
                return doc.get('uploaded_by')
            return getattr(doc, 'uploaded_by', None)

        # NOTE(review): documents are removed from the room being entered
        # (new_room_id), not the room being left. Confirm this is intended.
        try:
            room_manager = integrated_memory_manager.room_context_manager
            room_context = room_manager.get_room_context(new_room_id)
            if room_context and room_context.documents:
                original_count = len(room_context.documents)
                # Keep every document not uploaded by this user. Documents with
                # no recorded uploader are kept; the old filter silently dropped
                # non-dict documents that lacked an ``uploaded_by`` attribute.
                room_context.documents = [
                    doc for doc in room_context.documents
                    if _uploader_of(doc) != user_id
                ]
                room_manager.save_room_context(new_room_id, room_context)
                removed_count = original_count - len(room_context.documents)
                logger.info(f"๏ฟฝ๏ฟฝ๏ธ Room {new_room_id}์์ ์ฌ์ฉ์ {user_id} ๋ฌธ์ {removed_count}๊ฐ ์ ๊ฑฐ")
        except Exception as e:
            logger.warning(f"โ ๏ธ Room ์ปจํ ์คํธ ์ด๊ธฐํ ์คํจ: {e}")

        return {
            "status": "success",
            "message": f"Room {new_room_id}๋ก ์ด๋ - ๋ฉ๋ชจ๋ฆฌ ์ด๊ธฐํ๋จ",
            "memory_preserved": False,
            "context_cleared": True
        }
    except Exception as e:
        logger.error(f"โ Room ๋ณ๊ฒฝ ์ฒ๋ฆฌ ์คํจ: {e}")
        return {"status": "error", "message": str(e)}