"""FastAPI server for CRoM-EfficientLLM exposing the /process, /healthz and /metrics endpoints."""
from fastapi import FastAPI, HTTPException
import time
from typing import List, Dict
import logging

# Internal module imports
from .budget_packer import enhanced_greedy_pack
from .cross_encoder import SafeCrossEncoderManager
from .capsule_logger import ExplainCapsuleLogger

# --- FastAPI app and core component initialization ---

# Application object that the endpoint decorators further down register on.
app = FastAPI(
    title="CRoM-EfficientLLM Server",
    description="Context Reranking and Management for Efficient LLMs",
    version="1.0.1"
)

# Set up basic logging at INFO level when this module is imported.
logging.basicConfig(level=logging.INFO)

# Component instantiation: module-level instances used by the endpoint handlers.
# TODO: load the model name and similar settings from a config file (config.yaml)
ce_manager = SafeCrossEncoderManager(model_name="ms-marco-TinyBERT-L-2-v2")
capsule_logger = ExplainCapsuleLogger(log_directory="artifacts/logs")


# --- Response schema and helper functions ---

class ProcessResponseV2:
    """Helper that builds the extended response payload of the /process endpoint."""

    @staticmethod
    def create_response(query: str, packed_chunks: List[Dict],
                       processing_stats: Dict, cross_encoder_status: str,
                       processing_time: float) -> Dict:
        """Assemble the response dictionary for a successful request.

        Args:
            query: The query string, echoed back under "query".
            packed_chunks: The packed chunks, returned under "chunks".
            processing_stats: Packing statistics, returned under "stats".
            cross_encoder_status: Status string stored in the "meta" section.
            processing_time: Elapsed time in seconds; reported in
                milliseconds as "processing_time_ms".

        Returns:
            A dict with the keys "success", "query", "chunks", "stats" and
            "meta"; "meta" also carries the current ``time.time()`` value
            as "timestamp".
        """
        # Seconds -> milliseconds for the reported duration.
        elapsed_ms = processing_time * 1000

        meta_section = {
            "cross_encoder_status": cross_encoder_status,
            "processing_time_ms": elapsed_ms,
            "timestamp": time.time(),
        }

        return dict(
            success=True,
            query=query,
            chunks=packed_chunks,
            stats=processing_stats,  # packing statistics
            meta=meta_section,
        )

# --- API endpoint definitions ---

@app.post("/process", summary="Rerank and pack text chunks")
def process_chunks(query: str, chunks: List[Dict], budget: int = 4096):
    """
    Rerank the given chunks against the query and pack them to fit the budget.

    Args:
        query: Query string handed to the cross-encoder manager.
        chunks: Chunk dictionaries. Each one's "text" value ("" when the key
            is missing) is scored, and the score is written back into the
            same dict under "score".
        budget: Budget forwarded to ``enhanced_greedy_pack``.

    Returns:
        The response dictionary built by ``ProcessResponseV2.create_response``.

    Raises:
        HTTPException: With status 500 when reranking, packing, or response
            construction fails.
    """
    # perf_counter is monotonic, so the measured duration cannot be skewed
    # (or turn negative) if the wall clock is adjusted during the request.
    started = time.perf_counter()

    try:
        # 1. Rerank with the cross-encoder (when enabled)
        doc_texts = [chunk.get("text", "") for chunk in chunks]
        scores = ce_manager.rerank(query, doc_texts)
        for chunk, score in zip(chunks, scores):
            chunk["score"] = score

        # 2. Pack to fit the budget
        packed_chunks, stats = enhanced_greedy_pack(chunks, budget=budget, score_key="score")

        # 3. Build the final response
        # Read the status once so the response and the capsule agree on it.
        ce_status = ce_manager.get_status_for_response()
        response_data = ProcessResponseV2.create_response(
            query=query,
            packed_chunks=packed_chunks,
            processing_stats=stats,
            cross_encoder_status=ce_status,
            processing_time=time.perf_counter() - started,
        )
    except Exception as e:
        logging.exception("Error during /process: %s", e)
        # Error logging. Guarded so that a failing logger cannot replace the
        # HTTPException below with an unrelated exception.
        try:
            capsule_logger.log_error({
                "endpoint": "/process",
                "error": str(e),
                "query": query,
            })
        except Exception:
            logging.exception("Failed to record the /process error")
        # NOTE(review): the detail echoes the exception text to the client;
        # kept unchanged for compatibility, but consider a generic message.
        raise HTTPException(status_code=500, detail=f"Internal Server Error: {e}") from e

    # 4. Explain-capsule logging. Best effort: the response is already
    # computed, so a logging failure is recorded but must not fail the request.
    try:
        capsule = capsule_logger.create_explain_capsule(
            query=query,
            response_data=response_data,
            processing_stats=stats,
            cross_encoder_status=ce_status,
        )
        capsule_logger.log_capsule(capsule)
    except Exception:
        logging.exception("Failed to log the explain capsule for /process")

    return response_data

@app.get("/healthz", summary="Health check")
def health_check():
    """Report the server status together with the cross-encoder status."""
    encoder_status = ce_manager.get_status_for_response()
    return dict(status="ok", cross_encoder=encoder_status)

@app.get("/metrics", summary="Get Prometheus metrics")
def get_metrics():
    """Expose Prometheus metrics (currently a placeholder message only)."""
    # TODO: implement real metrics with prometheus-client
    placeholder = dict(
        message="Metrics endpoint is active. Implement with prometheus-client."
    )
    return placeholder