"""FastAPI text-embedding service backed by an ONNX EmbeddingGemma model."""

from contextlib import asynccontextmanager
from typing import Any, Dict, List, Union
import os

import asyncpg
import numpy as np
import uvicorn
from fastapi import Depends, FastAPI, HTTPException
from pydantic import BaseModel

# ONNX and HuggingFace Hub imports
from huggingface_hub import hf_hub_download
import onnxruntime as ort
from transformers import AutoTokenizer

from database_conn import connect_to_db, close_db_connection, get_db_connection
from test_router import router as test_router

HF_TOKEN = os.getenv("HUGGING_FACE_HUB_TOKEN")


# --- Wrapper class ---
# Provides the same methods as the previous SentenceTransformer usage
# (encode_document, encode_query, similarity) on top of ONNX Runtime.
class OnnxGemmaWrapper:
    """SentenceTransformer-compatible wrapper around an ONNX embedding model."""

    def __init__(self, model_id: str, token: Union[str, None] = None):
        """Download the ONNX model and create an inference session.

        Args:
            model_id: HuggingFace Hub repository id of the ONNX model.
            token: Optional HuggingFace access token (for gated repos).
        """
        print(f"Loading ONNX model: {model_id}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id, token=token)

        # Download the ONNX graph and its external-weights file; the graph
        # references model.onnx_data by relative path, so both must be cached.
        model_path = hf_hub_download(
            model_id, subfolder="onnx", filename="model.onnx", token=token
        )
        hf_hub_download(
            model_id, subfolder="onnx", filename="model.onnx_data", token=token
        )

        # Create the inference session (CUDA provider when available, CPU otherwise).
        available_providers = ort.get_available_providers()
        if "CUDAExecutionProvider" in available_providers:
            print("CUDA detected. Using GPU.")
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        else:
            print("CUDA not detected. Using CPU.")
            providers = ["CPUExecutionProvider"]
        self.session = ort.InferenceSession(model_path, providers=providers)

        # Task prefixes EmbeddingGemma expects for asymmetric retrieval.
        self.prefixes = {
            "query": "task: search result | query: ",
            "document": "title: none | text: ",
        }
        print("ONNX Model loaded successfully.")

    def _run_inference(self, texts: List[str]) -> np.ndarray:
        """Tokenize *texts* and run the ONNX session, returning embeddings.

        NOTE(review): assumes outputs[1] is the pooled sentence embedding of
        shape (batch, 768) — confirm against the exported model's output
        signature (outputs[0] would be the last hidden state).
        """
        inputs = self.tokenizer(
            texts, padding=True, truncation=True, return_tensors="np"
        )
        outputs = self.session.run(None, dict(inputs))
        return outputs[1]

    def encode_document(self, documents: List[str]) -> np.ndarray:
        """Embed *documents* with the document prefix; returns shape (N, 768)."""
        prefixed_docs = [self.prefixes["document"] + doc for doc in documents]
        return self._run_inference(prefixed_docs)

    def encode_query(self, query: str) -> np.ndarray:
        """Embed a single *query* with the query prefix.

        Returns:
            A 1-D array of shape (768,) — the batch dimension produced by
            ``_run_inference`` is dropped.
        """
        prefixed_query = [self.prefixes["query"] + query]
        return self._run_inference(prefixed_query)[0]

    def similarity(self, query_emb: np.ndarray, doc_embs: np.ndarray) -> np.ndarray:
        """Dot-product similarity between one query and N document embeddings.

        Args:
            query_emb: Query embedding, shape (768,) or (1, 768).
            doc_embs: Document embeddings, shape (N, 768).

        Returns:
            1-D array of N similarity scores.
        """
        # Promote a 1-D query to a row vector so the matmul below works.
        if query_emb.ndim == 1:
            query_emb = query_emb.reshape(1, -1)
        scores = query_emb @ doc_embs.T  # (1, N)
        return scores.flatten()


# Loaded once at startup by the lifespan handler; None when loading failed.
model = None


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Connect to the DB and load the model on startup; clean up on shutdown."""
    global model
    try:
        await connect_to_db()
    except Exception as e:
        # Keep the server up even without a DB so model endpoints still work.
        print(f"!!! [Startup] DB Connection FAILED: {e!r}")

    # --- Model load ---
    try:
        model = OnnxGemmaWrapper(
            model_id="onnx-community/embeddinggemma-300m-ONNX",
            token=HF_TOKEN,
        )
    except Exception as e:
        # Endpoints will report "Model not loaded" instead of the app crashing.
        print(f"Error loading ONNX model: {e}")
        model = None

    yield

    await close_db_connection()
    print(">>> [Shutdown] FastAPI Server graceful shutdown complete.")


app = FastAPI(
    title="Gemma Embedding Service (ONNX)",
    description="Implements text embedding generation via REST API using ONNX Runtime.",
    version="1.0.0",
    lifespan=lifespan,
)


# Root endpoint (GET /)
@app.get("/")
def read_root():
    """Health check; always reports success."""
    return {"success": True, "data": "ok", "msg": ""}


app.include_router(test_router, prefix="/api/test")


class Item(BaseModel):
    """Demo payload for the /items endpoints."""
    name: str
    price: float
    is_offer: bool | None = None


class MakeTextEmbedding(BaseModel):
    """Request body: one query plus candidate documents."""
    query: str
    documents: List[str]


class EmbeddingOutput(BaseModel):
    """Response envelope for embedding endpoints."""
    success: bool
    msg: str
    data: Union[List[List[float]], None] = None


@app.post("/make_text_embedding", summary="Calculate semantic similarity and find the best match")
async def calculate_similarity(data: MakeTextEmbedding):
    """Return the raw embedding vectors for ``data.documents``."""
    result = {"success": True, "data": None, "msg": ""}
    try:
        if model is None:
            result["success"] = False
            result["msg"] = "Model not loaded. Service is unavailable."
            return result

        # ONNX Runtime needs no torch.no_grad(); inference is already
        # gradient-free.
        document_embeddings = model.encode_document(data.documents)
        result["data"] = document_embeddings.tolist()
        result["msg"] = f"document_embeddings.shape:{document_embeddings.shape}"
        return result
    except Exception as e:
        result["success"] = False
        result["msg"] = f"server error. {e!r}"
        return result


@app.post("/string_distance_compare", summary="Calculate semantic similarity and find the best match")
async def string_distance_compare(data: MakeTextEmbedding):
    """Score ``data.documents`` against ``data.query`` by embedding similarity."""
    result = {"success": True, "data": None, "msg": ""}
    try:
        if model is None:
            result["success"] = False
            result["msg"] = "Model not loaded. Service is unavailable."
            return result

        # Encode query and documents, then score with a dot product.
        query_embeddings = model.encode_query(data.query)
        document_embeddings = model.encode_document(data.documents)
        similarities = model.similarity(query_embeddings, document_embeddings)

        result["data"] = similarities.tolist()
        result["msg"] = (
            f"query_embeddings.shape: {query_embeddings.shape}, "
            f"document_embeddings.shape: {document_embeddings.shape}"
        )
        return result
    except Exception as e:
        result["success"] = False
        result["msg"] = f"server error. {e!r}"
        return result


# ----------------------------------------------------
# DB-related endpoints and misc APIs
# ----------------------------------------------------
@app.get("/time", response_model=Dict[str, Any])
async def get_db_time(conn: asyncpg.Connection = Depends(get_db_connection)):
    """Return the database server's current time (SELECT NOW())."""
    result = {"success": True, "data": None, "msg": ""}
    try:
        records = await conn.fetch("SELECT NOW();")
        result["data"] = dict(records[0])
    except Exception as e:
        result["success"] = False
        result["msg"] = f"Database query error: {e!r}"
    return result


@app.get("/items")
def read_item(q: str | None = None):
    """Echo back the optional query parameter."""
    return {"q": q, "description": "This is a query test."}


@app.post("/items/")
def create_item(item: Item):
    """Create a demo item; items priced over 100 get a 'Premium' name."""
    if item.price > 100.0:
        item.name = f"Premium {item.name}"
    return {"message": "Item created successfully", "item_data": item}


if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)