File size: 2,419 Bytes
1591012
c9abf3f
 
 
 
 
 
 
a8e2a77
 
331e3ee
1591012
c9abf3f
1591012
c9abf3f
 
 
 
 
 
 
 
 
 
 
 
 
 
1591012
 
c9abf3f
 
1591012
c9abf3f
 
1591012
c9abf3f
 
 
1591012
 
c9abf3f
1591012
 
 
 
 
 
 
c9abf3f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a8e2a77
 
 
 
 
 
331e3ee
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from fastapi import Depends, HTTPException, status, Query
from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy import select
from jose import JWTError, jwt
from app.database import async_session_maker
from app.models import User
from app.config import settings
from fastapi import Request
from chromadb import AsyncHttpClient
from chromadb.api.models.Collection import Collection
from typing import Optional

security = HTTPBearer(auto_error=False)

async def get_db():
    async with async_session_maker() as session:
        try:
            yield session
            await session.commit()
        except Exception:
            await session.rollback()
            raise
        finally:
            await session.close()


async def get_current_user(
        credentials: Optional[HTTPAuthorizationCredentials] = Depends(security),
        token_query: Optional[str] = Query(None, alias="token"),
        db: AsyncSession = Depends(get_db)
) -> User:
    
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED, 
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    token = None
    if credentials:
        token = credentials.credentials
    elif token_query:
        token = token_query
        
    if not token:
        raise credentials_exception

    try:
        payload = jwt.decode(token, settings.SECRET_KEY, algorithms=[settings.ALGORITHM])
        username: str = payload.get("sub")
        if username is None:
            raise credentials_exception
    except JWTError:
        raise credentials_exception
    
    result = await db.execute(select(User).filter(User.username == username))
    user = result.scalar_one_or_none()

    if user is None:
        raise credentials_exception
    
    return user



async def get_chroma_client(request: Request) -> AsyncHttpClient:
    client = getattr(request.app.state, "chroma_client", None)
    if client is None:
        raise RuntimeError("ChromaDB client is not initialized in App State.")
    return client

def get_chroma_collection(request: Request) -> Collection:
    collection = getattr(request.app.state, "chroma_collection", None)
    if collection is None:
        raise RuntimeError("ChromaDB Collection not loaded during application startup.")
    return collection