from fastapi import FastAPI, HTTPException
import time
from typing import List, Dict
import logging
# Internal module imports
from .budget_packer import enhanced_greedy_pack
from .cross_encoder import SafeCrossEncoderManager
from .capsule_logger import ExplainCapsuleLogger
# --- FastAPI app and core component initialization ---
app = FastAPI(
    title="CRoM-EfficientLLM Server",
    description="Context Reranking and Management for Efficient LLMs",
    version="1.0.1"
)
logging.basicConfig(level=logging.INFO)

# Component instantiation (module-level singletons shared by all requests).
# TODO: load the model name and other settings from a config file (config.yaml).
ce_manager = SafeCrossEncoderManager(model_name="ms-marco-TinyBERT-L-2-v2")
capsule_logger = ExplainCapsuleLogger(log_directory="artifacts/logs")
# --- Response schema and helper functions ---
class ProcessResponseV2:
    """Helper that assembles the extended /process endpoint response payload."""

    @staticmethod
    def create_response(query: str, packed_chunks: List[Dict],
                        processing_stats: Dict, cross_encoder_status: str,
                        processing_time: float) -> Dict:
        """Build the enriched response dict for a successful /process call.

        `processing_time` is given in seconds and reported in milliseconds.
        """
        meta_block = {
            "cross_encoder_status": cross_encoder_status,
            "processing_time_ms": processing_time * 1000,
            "timestamp": time.time(),
        }
        return {
            "success": True,
            "query": query,
            "chunks": packed_chunks,
            "stats": processing_stats,  # packing statistics
            "meta": meta_block,
        }
# --- API endpoint definitions ---
@app.post("/process", summary="Rerank and pack text chunks")
def process_chunks(query: str, chunks: List[Dict], budget: int = 4096):
    """Rerank the given chunks against the query and pack them into the budget.

    Args:
        query: Query string the chunks are scored against.
        chunks: Candidate chunks; each dict is expected to carry a "text" key
            (missing "text" defaults to the empty string).
        budget: Budget passed through to the greedy packer.

    Returns:
        The enriched response dict built by ProcessResponseV2.create_response.

    Raises:
        HTTPException: 500 on any internal failure; the original error is
            logged and recorded via the capsule logger before re-raising.
    """
    start_time = time.time()
    try:
        # 1. Rerank with the Cross-Encoder (when active).
        doc_texts = [chunk.get("text", "") for chunk in chunks]
        scores = ce_manager.rerank(query, doc_texts)
        # NOTE(review): scores are written onto the incoming chunk dicts in
        # place; callers holding references to `chunks` will observe this.
        for chunk, score in zip(chunks, scores):
            chunk["score"] = score

        # 2. Pack within the budget.
        packed_chunks, stats = enhanced_greedy_pack(chunks, budget=budget, score_key="score")

        # 3. Build the final response.
        processing_time = time.time() - start_time
        response_data = ProcessResponseV2.create_response(
            query=query,
            packed_chunks=packed_chunks,
            processing_stats=stats,
            cross_encoder_status=ce_manager.get_status_for_response(),
            processing_time=processing_time
        )

        # 4. Log the explain capsule.
        capsule = capsule_logger.create_explain_capsule(
            query=query,
            response_data=response_data,
            processing_stats=stats,
            cross_encoder_status=ce_manager.get_status_for_response()
        )
        capsule_logger.log_capsule(capsule)
        return response_data
    except Exception as e:
        logging.error(f"Error during /process: {e}", exc_info=True)
        # Record the failure as an error capsule before surfacing it.
        capsule_logger.log_error({
            "endpoint": "/process",
            "error": str(e),
            "query": query,
        })
        raise HTTPException(status_code=500, detail=f"Internal Server Error: {e}")
@app.get("/healthz", summary="Health check")
def health_check():
"""์๋ฒ์ ์ํ๋ฅผ ํ์ธํฉ๋๋ค."""
return {"status": "ok", "cross_encoder": ce_manager.get_status_for_response()}
@app.get("/metrics", summary="Get Prometheus metrics")
def get_metrics():
"""Prometheus ๋ฉํธ๋ฆญ์ ๋
ธ์ถํฉ๋๋ค."""
# TODO: Prometheus-client๋ฅผ ์ฌ์ฉํ์ฌ ์ค์ ๋ฉํธ๋ฆญ์ ๊ตฌํํด์ผ ํจ
return {"message": "Metrics endpoint is active. Implement with prometheus-client."}