| import os |
| import shutil |
| import tempfile |
| from typing import Optional |
| from uuid import UUID |
|
|
| from fastapi import APIRouter, Depends, File, Form, HTTPException, UploadFile |
| from sqlalchemy import text |
| from sqlmodel.ext.asyncio.session import AsyncSession |
|
|
| from src.auth.utils import get_current_user |
| from src.core.database import get_async_session |
| from .schemas import ManualTextRequest |
| from .service import store_manual_text |
| from .embedding import embedding_model |
| from .schemas import ( |
| SemanticSearchRequest, |
| SemanticSearchResult, |
| TokenizeRequest, |
| TokenizeResponse, |
| UploadKBResponse, |
| ) |
| from .service import process_pdf_and_store |
|
|
# All endpoints below are mounted under /chatbot and grouped under the
# "chatbot" tag in the generated OpenAPI docs.
router = APIRouter(prefix="/chatbot", tags=["chatbot"])
|
|
@router.post("/tokenize", response_model=TokenizeResponse)
async def tokenize_text(
    payload: TokenizeRequest,
    user_id: UUID = Depends(get_current_user),
):
    """Tokenize ``payload.text`` with the shared embedding model's tokenizer.

    The text is truncated/padded to the model's 512-token window. Returns the
    input ids and attention mask of the single encoded sequence.

    Raises:
        HTTPException: 500 if tokenization fails for any reason.
    """
    try:
        encoded = embedding_model.tokenizer(
            payload.text,
            return_tensors="np",
            truncation=True,
            padding="longest",
            max_length=512,
        )
    except Exception as e:
        # NOTE(review): surfacing str(e) to clients may leak internal details;
        # consider a generic message plus server-side logging.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Success path outside the try: a failure here (e.g. response model
    # validation) should not be misreported as a tokenizer error.
    return TokenizeResponse(
        input_ids=encoded["input_ids"][0].tolist(),
        attention_mask=encoded["attention_mask"][0].tolist(),
    )
|
|
|
|
@router.post("/semantic-search", response_model=list[SemanticSearchResult])
async def semantic_search(
    payload: SemanticSearchRequest,
    session: AsyncSession = Depends(get_async_session),
    user_id: UUID = Depends(get_current_user),
):
    """Return the ``top_k`` knowledge chunks nearest to ``payload.embedding``.

    Uses pgvector's ``<#>`` operator (negative inner product), so smaller
    scores mean closer matches — hence the ascending ORDER BY.

    Raises:
        HTTPException: 400 if the query embedding is empty.
    """
    if not payload.embedding:
        raise HTTPException(status_code=400, detail="Embedding cannot be empty.")

    # Default to 3 results when top_k is absent (or 0, via the falsy check).
    top_k = payload.top_k or 3

    # pgvector accepts its textual literal form: "[v1,v2,...]".
    query_vec = "[" + ",".join(str(x) for x in payload.embedding) + "]"

    sql = text(
        """
        SELECT id, kb_id, chunk_text,image_url,
        embedding <#> :query_vec AS score
        FROM knowledge_chunk
        ORDER BY embedding <#> :query_vec ASC
        LIMIT :top_k
        """
    )

    result = await session.execute(sql, {"query_vec": query_vec, "top_k": top_k})

    return [
        SemanticSearchResult(
            chunk_id=str(row.id),
            kb_id=str(row.kb_id),
            text=row.chunk_text,
            image_url=row.image_url,
            score=float(row.score),
        )
        for row in result.fetchall()
    ]
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|