# Source provenance (was raw web-page residue, not valid Python):
# Author: Flamehaven
# Commit: "feat: Implement core CRoM modules (packer, encoder, logger, server)" (2844db6)
from fastapi import FastAPI, HTTPException
import time
from typing import List, Dict
import logging
# Internal module imports
from .budget_packer import enhanced_greedy_pack
from .cross_encoder import SafeCrossEncoderManager
from .capsule_logger import ExplainCapsuleLogger
# --- FastAPI app and core component initialization ---
app = FastAPI(
    title="CRoM-EfficientLLM Server",
    description="Context Reranking and Management for Efficient LLMs",
    version="1.0.1"
)
logging.basicConfig(level=logging.INFO)

# Component instantiation (module-level singletons shared by all requests)
# TODO: load the model name and other settings from a config file (config.yaml)
ce_manager = SafeCrossEncoderManager(model_name="ms-marco-TinyBERT-L-2-v2")
capsule_logger = ExplainCapsuleLogger(log_directory="artifacts/logs")
# --- ์‘๋‹ต ์Šคํ‚ค๋งˆ ๋ฐ ํ—ฌํผ ํ•จ์ˆ˜ ---
class ProcessResponseV2:
    """Helper for building the extended /process endpoint response schema."""

    @staticmethod
    def create_response(query: str, packed_chunks: List[Dict],
                        processing_stats: Dict, cross_encoder_status: str,
                        processing_time: float) -> Dict:
        """Assemble the enriched /process response payload.

        Args:
            query: Original user query, echoed back in the response.
            packed_chunks: Chunks selected by the packer.
            processing_stats: Packing statistics.
            cross_encoder_status: Status string reported by the cross-encoder manager.
            processing_time: Wall-clock processing duration in seconds
                (exposed as milliseconds in the response).

        Returns:
            Dict with a success flag, the query, chunks, stats, and a meta section.
        """
        meta = {
            "cross_encoder_status": cross_encoder_status,
            "processing_time_ms": processing_time * 1000,
            "timestamp": time.time(),
        }
        return {
            "success": True,
            "query": query,
            "chunks": packed_chunks,
            "stats": processing_stats,
            "meta": meta,
        }
# --- API ์—”๋“œํฌ์ธํŠธ ์ •์˜ ---
@app.post("/process", summary="Rerank and pack text chunks")
def process_chunks(query: str, chunks: List[Dict], budget: int = 4096):
    """Rerank the given chunks against the query and pack them within budget.

    Args:
        query: User query to rerank against.
        chunks: Chunk dicts; each is expected to carry a "text" key.
            NOTE: the dicts are mutated in place (a "score" key is added).
        budget: Packing budget passed to the packer (default 4096).

    Returns:
        Response dict built by ProcessResponseV2.create_response.

    Raises:
        HTTPException: 500 on any internal failure (error is logged and
            recorded via the capsule logger first).
    """
    start_time = time.time()
    try:
        # 1. Rerank with the cross-encoder (when active).
        doc_texts = [chunk.get("text", "") for chunk in chunks]
        scores = ce_manager.rerank(query, doc_texts)
        for chunk, score in zip(chunks, scores):
            chunk["score"] = score

        # 2. Pack the scored chunks to fit the budget.
        packed_chunks, stats = enhanced_greedy_pack(chunks, budget=budget, score_key="score")

        # 3. Build the final response. Fetch the cross-encoder status once and
        #    reuse it below (the original queried it twice per request).
        ce_status = ce_manager.get_status_for_response()
        processing_time = time.time() - start_time
        response_data = ProcessResponseV2.create_response(
            query=query,
            packed_chunks=packed_chunks,
            processing_stats=stats,
            cross_encoder_status=ce_status,
            processing_time=processing_time,
        )

        # 4. Log an explain capsule for observability.
        capsule = capsule_logger.create_explain_capsule(
            query=query,
            response_data=response_data,
            processing_stats=stats,
            cross_encoder_status=ce_status,
        )
        capsule_logger.log_capsule(capsule)
        return response_data
    except Exception as e:
        # Boundary handler: log with traceback (lazy %-args, not an f-string),
        # record the failure capsule, then map to a 500 for the client.
        logging.error("Error during /process: %s", e, exc_info=True)
        capsule_logger.log_error({
            "endpoint": "/process",
            "error": str(e),
            "query": query,
        })
        # Chain the cause so the original traceback survives the re-raise.
        raise HTTPException(status_code=500, detail=f"Internal Server Error: {e}") from e
@app.get("/healthz", summary="Health check")
def health_check():
    """Report server liveness plus the cross-encoder manager's current status."""
    payload = {
        "status": "ok",
        "cross_encoder": ce_manager.get_status_for_response(),
    }
    return payload
@app.get("/metrics", summary="Get Prometheus metrics")
def get_metrics():
    """Expose Prometheus metrics (placeholder until real metrics are wired up)."""
    # TODO: implement real metrics with prometheus-client.
    placeholder = {"message": "Metrics endpoint is active. Implement with prometheus-client."}
    return placeholder