deploy FastAPI backend
Browse files- app/__pycache__/config.cpython-311.pyc +0 -0
- app/__pycache__/logging_config.cpython-311.pyc +0 -0
- app/__pycache__/main.cpython-311.pyc +0 -0
- app/auth/__pycache__/jwt_handler.cpython-311.pyc +0 -0
- app/auth/__pycache__/models.cpython-311.pyc +0 -0
- app/auth/__pycache__/routes.cpython-311.pyc +0 -0
- app/auth/jwt_handler.py +57 -0
- app/auth/models.py +28 -0
- app/auth/routes.py +145 -0
- app/config.py +17 -0
- app/database/__pycache__/connection.cpython-311.pyc +0 -0
- app/database/__pycache__/schemas.cpython-311.pyc +0 -0
- app/database/connection.py +7 -0
- app/database/schemas.py +73 -0
- app/logging_config.py +50 -0
- app/main.py +24 -0
- app/rag/__pycache__/models.cpython-311.pyc +0 -0
- app/rag/__pycache__/rag_processor.cpython-311.pyc +0 -0
- app/rag/__pycache__/routes.cpython-311.pyc +0 -0
- app/rag/models.py +39 -0
- app/rag/rag_processor.py +151 -0
- app/rag/routes.py +168 -0
- app/request.py +142 -0
- requirements.txt +17 -0
app/__pycache__/config.cpython-311.pyc
ADDED
|
Binary file (1.66 kB). View file
|
|
|
app/__pycache__/logging_config.cpython-311.pyc
ADDED
|
Binary file (2.65 kB). View file
|
|
|
app/__pycache__/main.cpython-311.pyc
ADDED
|
Binary file (1.28 kB). View file
|
|
|
app/auth/__pycache__/jwt_handler.cpython-311.pyc
ADDED
|
Binary file (4.04 kB). View file
|
|
|
app/auth/__pycache__/models.cpython-311.pyc
ADDED
|
Binary file (2.29 kB). View file
|
|
|
app/auth/__pycache__/routes.cpython-311.pyc
ADDED
|
Binary file (9.46 kB). View file
|
|
|
app/auth/jwt_handler.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""JWT access/refresh token creation, decoding, and refresh-token hashing."""
from datetime import datetime, timedelta, timezone
from typing import Any, Dict, Optional

from jose import jwt, JWTError
from passlib.hash import argon2
from fastapi import HTTPException, status

from app.config import settings

SECRET_KEY = settings.secret_key
ALGORITHM = settings.algorithm
ACCESS_TOKEN_EXPIRE_MINUTES = settings.access_token_expire_minutes
REFRESH_TOKEN_EXPIRE_DAYS = settings.refresh_token_expire_days


def _now() -> datetime:
    """Current time as a timezone-aware UTC datetime."""
    return datetime.now(timezone.utc)


def create_access_token(subject: str, role: Optional[str] = None) -> str:
    """Create a signed, short-lived access token for *subject*.

    The optional *role* claim is included only when provided (truthy).
    """
    issued = _now()  # single timestamp so iat and exp are mutually consistent
    payload: Dict[str, Any] = {
        "sub": subject,
        "exp": issued + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
        "iat": issued,
        "type": "access",
    }
    if role:
        payload["role"] = role
    return jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)


def create_refresh_token(subject: str) -> str:
    """Create a signed, long-lived refresh token for *subject*."""
    issued = _now()
    payload = {
        "sub": subject,
        "exp": issued + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS),
        "iat": issued,
        "type": "refresh",
    }
    return jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)


def decode_token(token: str) -> Dict[str, Any]:
    """Decode and verify *token*; raise ValueError on any JWT error.

    Raises:
        ValueError: bad signature, malformed token, or expired token.
    """
    try:
        return jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError as err:
        # Chain the cause so the original JWT failure is not lost.
        raise ValueError("Invalid token or signature") from err


def verify_access_token(token: str) -> str:
    """Validate an access token and return its subject.

    Raises:
        HTTPException: 401 when the token is invalid, expired, missing a
            subject, or is not of type "access".
    """
    credentials_exception = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise credentials_exception
    subject: Optional[str] = payload.get("sub")
    if subject is None or payload.get("type") != "access":
        raise credentials_exception
    return subject


def hash_refresh_token(raw_refresh: str) -> str:
    """Hash a raw refresh token with argon2 for at-rest storage."""
    return argon2.hash(raw_refresh)


def verify_refresh_token(raw_refresh: str, hash_value: str) -> bool:
    """Check a raw refresh token against its stored argon2 hash.

    Returns False (rather than raising) for malformed hashes so callers can
    treat any failure as a non-match.
    """
    try:
        return argon2.verify(raw_refresh, hash_value)
    except Exception:
        return False
|
app/auth/models.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Pydantic models for authentication payloads and stored users."""
from datetime import datetime, timezone
from typing import Optional

from pydantic import BaseModel, Field, EmailStr


def _utc_now() -> datetime:
    """Timezone-aware UTC timestamp (datetime.utcnow is deprecated and naive)."""
    return datetime.now(timezone.utc)


class UserCreate(BaseModel):
    """Registration request payload."""
    username: str = Field(..., min_length=3, max_length=50)
    email: EmailStr = Field(..., description="User email (must be unique)")
    company: str = Field(default="", max_length=128)
    password: str = Field(..., min_length=8, description="User password (will be hashed).")


class UserInDB(BaseModel):
    """User document as stored in MongoDB (includes the password hash)."""
    username: str
    email: str
    company: str
    hashed_password: str
    # Aware-UTC default, consistent with the session timestamps stored by
    # the auth routes (the old naive utcnow default mixed naive/aware values).
    created_at: datetime = Field(default_factory=_utc_now)


class UserPublic(BaseModel):
    """User representation safe to return to clients (no password hash)."""
    username: str
    email: str
    company: str
    created_at: datetime


class Token(BaseModel):
    """OAuth2-style token response."""
    access_token: str
    token_type: str = "bearer"
    refresh_token: str
    expires_in: int  # access-token lifetime in seconds
|
app/auth/routes.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from datetime import datetime, timezone
from typing import Optional

from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer
from motor.motor_asyncio import AsyncIOMotorDatabase
from passlib.context import CryptContext
from pydantic import BaseModel

from app.database.connection import get_db
from app.auth.models import UserCreate, UserPublic, Token
from app.auth.jwt_handler import (
    create_access_token,
    create_refresh_token,
    decode_token,
    hash_refresh_token,
    verify_access_token,
    verify_refresh_token,
)
from app.config import settings

router = APIRouter(prefix="/auth", tags=["Authentication"])
pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")


async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncIOMotorDatabase = Depends(get_db),
) -> UserPublic:
    """Dependency: resolve the bearer token to the caller's public profile."""
    try:
        username = verify_access_token(token)
    except HTTPException:
        raise  # propagate the detailed 401 from the token verifier
    except Exception:
        raise HTTPException(status_code=401, detail="Invalid token")
    record = await db.users.find_one({"username": username})
    if record is None:
        raise HTTPException(status_code=401, detail="User not found")
    return UserPublic(**record)
|
| 39 |
+
|
| 40 |
+
@router.post("/register", response_model=UserPublic, status_code=status.HTTP_201_CREATED)
async def register(user: UserCreate, db: AsyncIOMotorDatabase = Depends(get_db)):
    """Create a new user account.

    Usernames are trimmed/lower-cased and emails lower-cased so the
    uniqueness check is case-insensitive.
    """
    username = user.username.strip().lower()
    email = user.email.lower()
    # NOTE(review): this pre-check races with concurrent registrations; it
    # must be backed by the unique indexes on username/email (MONGO_INDEXES).
    if await db.users.find_one({"$or": [{"username": username}, {"email": email}]}):
        raise HTTPException(status_code=400, detail="Username or email already exists")
    doc = {
        "username": username,
        "email": email,
        "company": user.company,
        "hashed_password": pwd_context.hash(user.password),
        # Aware UTC, consistent with the session timestamps stored at login.
        "created_at": datetime.now(timezone.utc),
    }
    await db.users.insert_one(doc)
    # Build the public view explicitly: insert_one mutates `doc` by adding an
    # ObjectId "_id", and the dict also carries hashed_password — neither
    # belongs in a **-unpacked response model.
    return UserPublic(
        username=doc["username"],
        email=doc["email"],
        company=doc["company"],
        created_at=doc["created_at"],
    )
|
| 56 |
+
|
| 57 |
+
@router.post("/login", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db=Depends(get_db)):
    """Authenticate a user and issue an access + stored refresh token pair."""
    username = form_data.username.strip().lower()
    user = await db.users.find_one({"username": username})
    if user is None:
        # Run a dummy hash verification so the response time does not reveal
        # whether the username exists (timing-based enumeration).
        pwd_context.dummy_verify()
        raise HTTPException(status_code=401, detail="Incorrect username or password")
    if not pwd_context.verify(form_data.password, user["hashed_password"]):
        raise HTTPException(status_code=401, detail="Incorrect username or password")
    access_token = create_access_token(username)
    refresh_token = create_refresh_token(username)
    payload = decode_token(refresh_token)  # read exp back for the session record
    await db.sessions.insert_one(
        {
            "user_id": username,
            # Only the argon2 hash of the refresh token is persisted.
            "refresh_token_hash": hash_refresh_token(refresh_token),
            "created_at": datetime.now(timezone.utc),
            "expires_at": datetime.fromtimestamp(payload["exp"], tz=timezone.utc),
            "revoked_at": None,
        }
    )
    return Token(
        access_token=access_token,
        refresh_token=refresh_token,
        expires_in=settings.access_token_expire_minutes * 60,
    )
|
| 80 |
+
|
| 81 |
+
class RefreshIn(BaseModel):
    """Request body carrying a raw refresh token."""
    refresh_token: str


async def _find_matching_session(db, username: str, raw_refresh: str):
    """Return the active, unexpired session whose stored hash matches
    *raw_refresh*, or None.

    Scans every active session (newest first) rather than only the most
    recent one — otherwise a login from a second device silently invalidates
    the first device's still-valid refresh token.
    """
    cursor = db.sessions.find(
        {
            "user_id": username,
            "revoked_at": None,
            "expires_at": {"$gt": datetime.now(timezone.utc)},
        },
        sort=[("created_at", -1)],
    )
    async for session_doc in cursor:
        if verify_refresh_token(raw_refresh, session_doc["refresh_token_hash"]):
            return session_doc
    return None


@router.post("/refresh", response_model=Token)
async def refresh_token(payload: RefreshIn, db=Depends(get_db)):
    """Rotate a refresh token: revoke the matching session, issue a new pair."""
    try:
        decoded = decode_token(payload.refresh_token)
    except ValueError:
        raise HTTPException(status_code=401, detail="Invalid refresh token")
    if decoded.get("type") != "refresh":
        raise HTTPException(status_code=401, detail="Invalid token type")
    username = decoded.get("sub")
    session_doc = await _find_matching_session(db, username, payload.refresh_token)
    if session_doc is None:
        raise HTTPException(status_code=401, detail="Refresh token not recognized")
    new_access = create_access_token(username)
    new_refresh = create_refresh_token(username)
    # Revoke the consumed session, then persist the rotated one.
    await db.sessions.update_one(
        {"_id": session_doc["_id"]}, {"$set": {"revoked_at": datetime.now(timezone.utc)}}
    )
    payload_new = decode_token(new_refresh)
    await db.sessions.insert_one(
        {
            "user_id": username,
            "refresh_token_hash": hash_refresh_token(new_refresh),
            "created_at": datetime.now(timezone.utc),
            "expires_at": datetime.fromtimestamp(payload_new["exp"], tz=timezone.utc),
            "revoked_at": None,
        }
    )
    return Token(
        access_token=new_access,
        refresh_token=new_refresh,
        expires_in=settings.access_token_expire_minutes * 60,
    )
|
| 119 |
+
|
| 120 |
+
@router.post("/logout")
async def logout(payload: RefreshIn, db=Depends(get_db)):
    """Revoke the session matching the presented refresh token.

    Always returns {"ok": True}: logout is idempotent and never leaks
    whether the token was valid.
    """
    try:
        decoded = decode_token(payload.refresh_token)
    except ValueError:
        return {"ok": True}
    username = decoded.get("sub")
    # Scan all active sessions (newest first), not only the latest one, so a
    # device can log out even when another device logged in more recently.
    cursor = db.sessions.find(
        {
            "user_id": username,
            "revoked_at": None,
            "expires_at": {"$gt": datetime.now(timezone.utc)},
        },
        sort=[("created_at", -1)],
    )
    async for session_doc in cursor:
        if verify_refresh_token(payload.refresh_token, session_doc["refresh_token_hash"]):
            await db.sessions.update_one(
                {"_id": session_doc["_id"]},
                {"$set": {"revoked_at": datetime.now(timezone.utc)}},
            )
            break
    return {"ok": True}
|
| 142 |
+
|
| 143 |
+
@router.get("/profile", response_model=UserPublic)
async def read_users_me(current_user: UserPublic = Depends(get_current_user)):
    """Return the authenticated caller's public profile."""
    return current_user
|
app/config.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Application settings loaded from the environment / .env file."""
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    """Typed settings resolved by pydantic-settings.

    Fields are declared instead of eagerly read with os.getenv: the getenv
    calls ran once at class-definition time, bypassed validation, and left
    silently-None defaults when a variable was unset. Declared fields are
    resolved from the process environment or the .env file, and missing
    required values raise a clear validation error at startup.
    """

    # Required — no safe defaults for these.
    mongo_uri: str
    database_name: str
    groq_api_key: str
    secret_key: str
    # Optional with sensible defaults; override via env (ALGORITHM, ...).
    algorithm: str = "HS256"
    access_token_expire_minutes: int = 30
    refresh_token_expire_days: int = 7

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"


settings = Settings()
|
app/database/__pycache__/connection.cpython-311.pyc
ADDED
|
Binary file (625 Bytes). View file
|
|
|
app/database/__pycache__/schemas.cpython-311.pyc
ADDED
|
Binary file (5.06 kB). View file
|
|
|
app/database/connection.py
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Shared MongoDB client and the FastAPI dependency that exposes it."""
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase

from app.config import settings

# One client per process; motor manages its own connection pool.
client = AsyncIOMotorClient(settings.mongo_uri)


async def get_db() -> AsyncIOMotorDatabase:
    """FastAPI dependency returning the configured application database."""
    return client[settings.database_name]
|
app/database/schemas.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""MongoDB document schemas and index definitions."""
from typing import List, Optional, Dict, Any
from datetime import datetime, timezone
from pydantic import BaseModel, Field, EmailStr


def _utc_now() -> datetime:
    """Aware UTC timestamp (replaces the deprecated, naive datetime.utcnow)."""
    return datetime.now(timezone.utc)


# User Schemas (adapted)
class UserCreate(BaseModel):
    """Registration payload (NOTE(review): duplicated in app.auth.models)."""
    username: str = Field(..., min_length=3, max_length=50)
    email: EmailStr = Field(..., description="User email (must be unique)")
    company: str = Field(default="", max_length=128)
    password: str = Field(..., min_length=8, description="User password (will be hashed).")


class UserDB(BaseModel):
    """User document; Mongo `_id` maps to `id`, `hashed_password` to `password_hash`."""
    id: Optional[str] = Field(None, alias="_id")
    username: str
    email: str
    password_hash: str = Field(alias="hashed_password")
    company: str
    created_at: datetime = Field(default_factory=_utc_now)
    updated_at: datetime = Field(default_factory=_utc_now)
    is_active: bool = True
    roles: List[str] = Field(default_factory=lambda: ["user"])


class SessionDB(BaseModel):
    """Refresh-token session; only the token's argon2 hash is stored."""
    id: Optional[str] = Field(None, alias="_id")
    user_id: str
    refresh_token_hash: str
    created_at: datetime = Field(default_factory=_utc_now)
    expires_at: datetime
    revoked_at: Optional[datetime] = None
    meta: Dict[str, Any] = Field(default_factory=dict)


# Conversation Schemas
class Message(BaseModel):
    """Single chat turn."""
    role: str
    content: str


class ConversationDB(BaseModel):
    """Chat history for one user."""
    id: Optional[str] = Field(None, alias="_id")
    user_id: str
    messages: List[Message] = Field(default_factory=list)
    created_at: datetime = Field(default_factory=_utc_now)
    updated_at: datetime = Field(default_factory=_utc_now)
    meta: Dict[str, Any] = Field(default_factory=dict)  # e.g., {"model": "llama-3.1-8b-instant"}


# Audit Log (optional, for security)
class AuditLogDB(BaseModel):
    """Security audit-trail entry."""
    id: Optional[str] = Field(None, alias="_id")
    user_id: Optional[str] = None
    action: str
    ip: Optional[str] = None
    user_agent: Optional[str] = None
    created_at: datetime = Field(default_factory=_utc_now)
    meta: Dict[str, Any] = Field(default_factory=dict)


# MongoDB Indexes for Performance
MONGO_INDEXES = {
    "users": [
        {"keys": [("username", 1)], "unique": True},
        {"keys": [("email", 1)], "unique": True},
        {"keys": [("created_at", -1)]},
    ],
    "sessions": [
        {"keys": [("user_id", 1), ("created_at", -1)]},
        {"keys": [("expires_at", 1)]},
    ],
    "conversations": [
        {"keys": [("user_id", 1), ("created_at", -1)]},
    ],
    "audit_logs": [
        {"keys": [("user_id", 1), ("created_at", -1)]},
        {"keys": [("action", 1), ("created_at", -1)]},
    ],
}
|
app/logging_config.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
import json, sys, os, logging
|
| 3 |
+
from logging.config import dictConfig
|
| 4 |
+
from uvicorn.config import LOG_LEVELS
|
| 5 |
+
|
| 6 |
+
SERVICE_NAME = os.getenv("SERVICE_NAME", "backend")
|
| 7 |
+
ENV = os.getenv("ENV", "production")
|
| 8 |
+
|
| 9 |
+
class JsonFormatter(logging.Formatter):
|
| 10 |
+
def format(self, record: logging.LogRecord) -> str:
|
| 11 |
+
base = {
|
| 12 |
+
"level": record.levelname,
|
| 13 |
+
"logger": record.name,
|
| 14 |
+
"msg": record.getMessage(),
|
| 15 |
+
"time": self.formatTime(record, self.datefmt),
|
| 16 |
+
"service": SERVICE_NAME,
|
| 17 |
+
"env": ENV,
|
| 18 |
+
}
|
| 19 |
+
if record.exc_info:
|
| 20 |
+
base["exc_info"] = self.formatException(record.exc_info)
|
| 21 |
+
return json.dumps(base, ensure_ascii=False)
|
| 22 |
+
|
| 23 |
+
def setup_logging():
|
| 24 |
+
log_level = os.getenv("LOG_LEVEL", "INFO").upper()
|
| 25 |
+
if log_level not in LOG_LEVELS:
|
| 26 |
+
log_level = "INFO"
|
| 27 |
+
|
| 28 |
+
dictConfig({
|
| 29 |
+
"version": 1,
|
| 30 |
+
"disable_existing_loggers": False,
|
| 31 |
+
"formatters": {
|
| 32 |
+
"json": {"()": JsonFormatter},
|
| 33 |
+
"plain": {"format": "%(levelname)s:%(name)s:%(message)s"},
|
| 34 |
+
},
|
| 35 |
+
"handlers": {
|
| 36 |
+
"console": {
|
| 37 |
+
"class": "logging.StreamHandler",
|
| 38 |
+
"stream": sys.stdout,
|
| 39 |
+
"formatter": "json" if ENV == "production" else "plain",
|
| 40 |
+
"level": log_level,
|
| 41 |
+
},
|
| 42 |
+
},
|
| 43 |
+
"loggers": {
|
| 44 |
+
"uvicorn": {"handlers": ["console"], "level": log_level, "propagate": False},
|
| 45 |
+
"uvicorn.error": {"handlers": ["console"], "level": log_level, "propagate": False},
|
| 46 |
+
"uvicorn.access": {"handlers": ["console"], "level": log_level, "propagate": False},
|
| 47 |
+
"fastapi": {"handlers": ["console"], "level": log_level, "propagate": False},
|
| 48 |
+
"": {"handlers": ["console"], "level": log_level}, # root logger
|
| 49 |
+
},
|
| 50 |
+
})
|
app/main.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""FastAPI application entry point: logging, CORS, and router wiring."""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from app.auth.routes import router as auth_router
from app.rag.routes import router as rag_router
from app.logging_config import setup_logging

# Configure logging before the app and its routers emit any records.
setup_logging()

app = FastAPI(title="GrokRAG API", description="SaaS RAG Chat with Groq", version="1.1.0")

# NOTE(review): origins are pinned to the local dev frontend — confirm this
# is made configurable before deploying against another host.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(auth_router)
app.include_router(rag_router, prefix="/rag")


@app.get("/")
async def root():
    """Liveness/landing endpoint."""
    return {"message": "Welcome to GrokRAG API"}
|
app/rag/__pycache__/models.cpython-311.pyc
ADDED
|
Binary file (2.24 kB). View file
|
|
|
app/rag/__pycache__/rag_processor.cpython-311.pyc
ADDED
|
Binary file (11.6 kB). View file
|
|
|
app/rag/__pycache__/routes.cpython-311.pyc
ADDED
|
Binary file (8.84 kB). View file
|
|
|
app/rag/models.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Request/response models for the RAG chat endpoints."""
from typing import List, Optional
from pydantic import BaseModel, Field, field_validator
from bson import ObjectId  # noqa: F401  (used by callers converting ids)
from datetime import datetime, timezone

# Groq model identifiers this API accepts.
ALLOWED_MODELS = [
    "allam-2-7b",
    "groq/compound",
    "groq/compound-mini",
    "llama-3.1-8b-instant",
    "llama-3.3-70b-versatile",
    "llama-3.1-70b-versatile",
    "meta-llama/llama-4-maverick-17b-128e-instruct",
    "meta-llama/llama-4-scout-17b-16e-instruct",
    "meta-llama/llama-guard-4-12b",
    "meta-llama/llama-prompt-guard-2-22m",
    "meta-llama/llama-prompt-guard-2-86m",
    "moonshotai/kimi-k2-instruct",
    "moonshotai/kimi-k2-instruct-0905",
    "openai/gpt-oss-120b",
    "openai/gpt-oss-20b",
    "openai/gpt-oss-safeguard-20b",
    "qwen/qwen3-32b",
]


class Message(BaseModel):
    """Single chat turn."""
    role: str
    content: str


class Conversation(BaseModel):
    """In-memory conversation view; `id` is str(ObjectId) once persisted."""
    id: Optional[str] = None
    user_id: str
    messages: List[Message] = Field(default_factory=list)
    # Aware UTC default (datetime.utcnow is deprecated and naive).
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))


class ChatRequest(BaseModel):
    """Chat request; `model` must be one of ALLOWED_MODELS."""
    model: str = Field(..., description="Groq model")
    enable_web_search: bool = False
    message: str = Field(..., min_length=1)

    @field_validator("model")
    @classmethod
    def _model_allowed(cls, v: str) -> str:
        # BUG FIX: the original `Field(..., enum=ALLOWED_MODELS)` only
        # decorated the JSON schema — it never validated. Enforce membership.
        if v not in ALLOWED_MODELS:
            raise ValueError(f"model must be one of ALLOWED_MODELS, got {v!r}")
        return v
|
app/rag/rag_processor.py
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Text extraction, chunking, FAISS indexing and web search for RAG."""
import os
import re
import logging
from typing import List, Tuple, Optional
import faiss
from sentence_transformers import SentenceTransformer
from PyPDF2 import PdfReader
from docx import Document
import pytesseract
from PIL import Image
import io
import openpyxl
import pandas as pd
from duckduckgo_search import DDGS
from fastapi import UploadFile

logger = logging.getLogger(__name__)

_EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
_embedder: Optional[SentenceTransformer] = None


def _get_embedder() -> SentenceTransformer:
    """Return the process-wide embedding model, loading it on first use."""
    global _embedder
    if _embedder is None:
        logger.info(f"Loading embedding model: {_EMBED_MODEL_NAME}")
        _embedder = SentenceTransformer(_EMBED_MODEL_NAME)
    return _embedder
|
| 28 |
+
|
| 29 |
+
# Enhanced File Extraction
def extract_text(file: UploadFile) -> str:
    """Extract plain text from an upload based on its filename extension.

    Supports PDF, DOCX, XLSX/XLS, CSV and common image formats (OCR); any
    other extension is decoded as UTF-8 text. Parse failures are logged and
    yield "" — this function never raises.
    """
    ext = os.path.splitext(file.filename)[1].lower()
    content = file.file.read()
    file_bytes = io.BytesIO(content)

    if ext == ".pdf":
        try:
            pages = PdfReader(file_bytes).pages
            return "\n".join(page.extract_text() or "" for page in pages)
        except Exception as e:
            logger.error(f"PDF extract failed: {e}")
            return ""
    if ext == ".docx":
        try:
            paragraphs = Document(file_bytes).paragraphs
            return "\n".join(p.text for p in paragraphs if p.text)
        except Exception as e:
            logger.error(f"DOCX extract failed: {e}")
            return ""
    if ext in (".xlsx", ".xls"):
        # NOTE(review): openpyxl reads .xlsx only — a legacy .xls will raise
        # here and produce "" via the handler below; confirm that's intended.
        try:
            wb = openpyxl.load_workbook(file_bytes, read_only=True, data_only=True)
            rows = []
            for sheet in wb:
                for row in sheet.iter_rows(values_only=True):
                    rows.append(" ".join(str(cell) for cell in row if cell is not None))
            return "\n".join(rows)
        except Exception as e:
            logger.error(f"Excel extract failed: {e}")
            return ""
    if ext == ".csv":
        try:
            return pd.read_csv(file_bytes).to_string()
        except Exception as e:
            logger.error(f"CSV extract failed: {e}")
            return ""
    if ext in (".jpg", ".jpeg", ".png", ".gif"):  # OCR for images
        try:
            return pytesseract.image_to_string(Image.open(file_bytes))
        except Exception as e:
            logger.error(f"Image OCR failed: {e}")
            return ""
    # Fallback: treat as text.
    try:
        return content.decode("utf-8", errors="ignore")
    except Exception as e:
        logger.error(f"Text extract failed: {e}")
        return ""
|
| 79 |
+
|
| 80 |
+
def clean_text(text: str) -> str:
    """Collapse runs of spaces/tabs and squeeze 3+ newlines to a blank line."""
    t = re.sub(r"[ \t]+", " ", text)
    t = re.sub(r"\n{3,}", "\n\n", t)
    return t.strip()


def chunk_text(text: str, max_tokens: int = 400, overlap: int = 50) -> List[str]:
    """Split *text* into word-windows of at most *max_tokens* words,
    consecutive windows sharing *overlap* words.

    Returns [] for empty/whitespace-only input.

    BUG FIX: when overlap >= max_tokens (or max_tokens <= 0) the window start
    never advanced and the original loop spun forever; overlap is now clamped
    below max_tokens and non-positive max_tokens raises ValueError.
    """
    text = clean_text(text)
    if not text:
        return []
    if max_tokens <= 0:
        raise ValueError("max_tokens must be positive")
    overlap = max(0, min(overlap, max_tokens - 1))  # guarantee forward progress
    words = text.split()
    chunks: List[str] = []
    start = 0
    while start < len(words):
        end = min(len(words), start + max_tokens)
        chunk = " ".join(words[start:end]).strip()
        if chunk:
            chunks.append(chunk)
        if end == len(words):
            break
        start = max(0, end - overlap)
    return chunks
|
| 100 |
+
|
| 101 |
+
class RagIndex:
    """Pairs a FAISS inner-product index with the source chunks it indexes."""

    def __init__(self, index: faiss.IndexFlatIP, dim: int, chunks: List[str]):
        self.index = index    # FAISS index over normalized embeddings
        self.dim = dim        # embedding dimensionality
        self.chunks = chunks  # texts aligned with index row ids


def build_faiss_index(chunks: List[str]) -> RagIndex:
    """Embed *chunks* (L2-normalized) and index them for inner-product search."""
    embedder = _get_embedder()
    vectors = embedder.encode(chunks, convert_to_numpy=True, normalize_embeddings=True)
    dim = vectors.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(vectors)
    return RagIndex(index=index, dim=dim, chunks=chunks)


def search(index: RagIndex, query: str, top_k: int = 6) -> List[Tuple[str, float]]:
    """Return up to *top_k* (chunk, similarity-score) pairs for *query*."""
    embedder = _get_embedder()
    query_vec = embedder.encode([query], convert_to_numpy=True, normalize_embeddings=True)
    scores, ids = index.index.search(query_vec, top_k)
    return [
        (index.chunks[i], float(s))
        for s, i in zip(scores[0], ids[0])
        if i != -1  # FAISS pads with -1 when fewer than top_k hits exist
    ]
|
| 125 |
+
|
| 126 |
+
def build_context_from_files(files: List[UploadFile], prompt: str, top_k: int = 6) -> str:
    """Extract, chunk, index the uploads and return the chunks most relevant
    to *prompt* formatted as numbered [DOC#n score=...] sections."""
    texts = []
    for upload in files:
        extracted = extract_text(upload)
        if extracted:
            texts.append(extracted)
        upload.file.seek(0)  # Reset so the caller can re-read the upload
    chunks = chunk_text("\n\n".join(texts), max_tokens=450, overlap=80)
    if not chunks:
        return ""
    rag_index = build_faiss_index(chunks)
    hits = search(rag_index, prompt, top_k=top_k)
    sections = [
        f"[DOC#{i} score={score:.3f}]\n{chunk}"
        for i, (chunk, score) in enumerate(hits, 1)
    ]
    return "\n\n".join(sections)


# Web search tool
def web_search(query: str) -> str:
    """Search DuckDuckGo and format the top results; never raises."""
    try:
        with DDGS() as ddgs:
            results = list(ddgs.text(query, max_results=5))
            sections = [
                f"[WEB#{i}] Title: {r['title']}\nSnippet: {r['body']}\nURL: {r['href']}"
                for i, r in enumerate(results, 1)
            ]
            return "\n\n".join(sections) if sections else "No results found."
    except Exception as e:
        logger.error(f"Web search failed: {e}")
        return "Web search error."
|
app/rag/routes.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
import logging
from datetime import datetime, timezone
from typing import List

from bson import ObjectId
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, Form
from fastapi.responses import StreamingResponse
from groq import Groq
from motor.motor_asyncio import AsyncIOMotorDatabase

from app.auth.models import UserPublic
from app.auth.routes import get_current_user
from app.config import settings
from app.database.connection import get_db
from app.database.schemas import ConversationDB
from app.rag.models import ALLOWED_MODELS, Message
from app.rag.rag_processor import build_context_from_files, web_search
| 18 |
+
|
| 19 |
+
# Router for all RAG-chat endpoints; mounted by the application at include time.
router = APIRouter(tags=["RAG Chat"])
# Module-level logger named after this module, per logging convention.
logger = logging.getLogger(__name__)

# Base system prompt prepended to every conversation. RAG context built from
# uploaded files is appended to it per-request in send_message().
SYSTEM_PROMPT = """You are a helpful assistant. Use the provided context if relevant. If web search is enabled and you need up-to-date information, use the web_search tool. Reason step-by-step before deciding to use tools."""

# OpenAI-style function-calling schema advertised to the model when the client
# enables web search; the call is executed locally via rag_processor.web_search.
WEB_SEARCH_TOOL = {
    "type": "function",
    "function": {
        "name": "web_search",
        "description": "Search the web using DuckDuckGo for up-to-date information.",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string", "description": "The search query"}},
            "required": ["query"],
        },
    },
}
|
| 36 |
+
|
| 37 |
+
@router.post("/conversations", status_code=status.HTTP_201_CREATED)
|
| 38 |
+
async def create_conversation(
|
| 39 |
+
current_user: UserPublic = Depends(get_current_user),
|
| 40 |
+
db: AsyncIOMotorDatabase = Depends(get_db),
|
| 41 |
+
):
|
| 42 |
+
conv = ConversationDB(user_id=current_user.username)
|
| 43 |
+
result = await db.conversations.insert_one(conv.dict(exclude={"id"}))
|
| 44 |
+
conv.id = str(result.inserted_id)
|
| 45 |
+
return {"conversation_id": conv.id}
|
| 46 |
+
|
| 47 |
+
@router.get("/conversations/{conv_id}")
|
| 48 |
+
async def get_conversation(
|
| 49 |
+
conv_id: str,
|
| 50 |
+
current_user: UserPublic = Depends(get_current_user),
|
| 51 |
+
db: AsyncIOMotorDatabase = Depends(get_db),
|
| 52 |
+
):
|
| 53 |
+
try:
|
| 54 |
+
oid = ObjectId(conv_id)
|
| 55 |
+
except:
|
| 56 |
+
raise HTTPException(status_code=400, detail="Invalid conversation ID")
|
| 57 |
+
conv = await db.conversations.find_one({"_id": oid, "user_id": current_user.username})
|
| 58 |
+
if not conv:
|
| 59 |
+
raise HTTPException(status_code=404, detail="Conversation not found")
|
| 60 |
+
conv["id"] = str(conv["_id"])
|
| 61 |
+
del conv["_id"]
|
| 62 |
+
return conv
|
| 63 |
+
|
| 64 |
+
@router.post("/conversations/{conv_id}/messages")
|
| 65 |
+
async def send_message(
|
| 66 |
+
conv_id: str,
|
| 67 |
+
model: str = Form(...),
|
| 68 |
+
enable_web_search: bool = Form(False),
|
| 69 |
+
message: str = Form(...),
|
| 70 |
+
files: List[UploadFile] = None,
|
| 71 |
+
current_user: UserPublic = Depends(get_current_user),
|
| 72 |
+
db: AsyncIOMotorDatabase = Depends(get_db),
|
| 73 |
+
):
|
| 74 |
+
if model not in ALLOWED_MODELS:
|
| 75 |
+
raise HTTPException(status_code=400, detail="Invalid model")
|
| 76 |
+
try:
|
| 77 |
+
oid = ObjectId(conv_id)
|
| 78 |
+
except:
|
| 79 |
+
raise HTTPException(status_code=400, detail="Invalid conversation ID")
|
| 80 |
+
conv = await db.conversations.find_one({"_id": oid, "user_id": current_user.username})
|
| 81 |
+
if not conv:
|
| 82 |
+
raise HTTPException(status_code=404, detail="Conversation not found")
|
| 83 |
+
|
| 84 |
+
# Load messages
|
| 85 |
+
messages = [Message(**m) for m in conv.get("messages", [])]
|
| 86 |
+
|
| 87 |
+
# Build RAG context if files
|
| 88 |
+
rag_context = ""
|
| 89 |
+
if files:
|
| 90 |
+
rag_context = build_context_from_files(files, message)
|
| 91 |
+
|
| 92 |
+
# System prompt with context
|
| 93 |
+
system_msg = {"role": "system", "content": SYSTEM_PROMPT + (f"\n\nContext: {rag_context}" if rag_context else "")}
|
| 94 |
+
|
| 95 |
+
# Append user message
|
| 96 |
+
user_msg = Message(role="user", content=message)
|
| 97 |
+
messages.append(user_msg)
|
| 98 |
+
|
| 99 |
+
# Groq client
|
| 100 |
+
client = Groq(api_key=settings.groq_api_key)
|
| 101 |
+
|
| 102 |
+
# Tools if enabled
|
| 103 |
+
tools = [WEB_SEARCH_TOOL] if enable_web_search else None
|
| 104 |
+
|
| 105 |
+
# Tool loop for reasoning and multiple calls (up to 3 iterations)
|
| 106 |
+
chat_history = [
|
| 107 |
+
system_msg if isinstance(system_msg, dict) else system_msg.dict()
|
| 108 |
+
] + [
|
| 109 |
+
m if isinstance(m, dict) else m.dict() for m in messages
|
| 110 |
+
]
|
| 111 |
+
max_tool_loops = 3
|
| 112 |
+
for _ in range(max_tool_loops):
|
| 113 |
+
completion = client.chat.completions.create(
|
| 114 |
+
model=model,
|
| 115 |
+
messages=chat_history,
|
| 116 |
+
temperature=1,
|
| 117 |
+
max_tokens=8192,
|
| 118 |
+
top_p=1,
|
| 119 |
+
stream=False,
|
| 120 |
+
stop=None,
|
| 121 |
+
tools=tools,
|
| 122 |
+
)
|
| 123 |
+
choice = completion.choices[0].message
|
| 124 |
+
if not choice.tool_calls:
|
| 125 |
+
# No more tools, prepare to stream
|
| 126 |
+
break
|
| 127 |
+
for tool_call in choice.tool_calls:
|
| 128 |
+
if tool_call.function.name == "web_search":
|
| 129 |
+
args = json.loads(tool_call.function.arguments)
|
| 130 |
+
query = args["query"]
|
| 131 |
+
result = web_search(query)
|
| 132 |
+
tool_response = {
|
| 133 |
+
"role": "tool",
|
| 134 |
+
"tool_call_id": tool_call.id,
|
| 135 |
+
"name": "web_search",
|
| 136 |
+
"content": result,
|
| 137 |
+
}
|
| 138 |
+
chat_history.append(tool_response)
|
| 139 |
+
else:
|
| 140 |
+
logger.warning("Max tool loops reached")
|
| 141 |
+
raise HTTPException(status_code=500, detail="Too many tool calls")
|
| 142 |
+
|
| 143 |
+
# Final streaming call
|
| 144 |
+
completion = client.chat.completions.create(
|
| 145 |
+
model=model,
|
| 146 |
+
messages=chat_history,
|
| 147 |
+
temperature=1,
|
| 148 |
+
max_tokens=8192,
|
| 149 |
+
top_p=1,
|
| 150 |
+
stream=True,
|
| 151 |
+
stop=None,
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
# Stream response
|
| 155 |
+
async def generate():
|
| 156 |
+
response_content = ""
|
| 157 |
+
for chunk in completion:
|
| 158 |
+
content = chunk.choices[0].delta.content or ""
|
| 159 |
+
response_content += content
|
| 160 |
+
yield content
|
| 161 |
+
# Save to DB
|
| 162 |
+
messages.append(Message(role="assistant", content=response_content))
|
| 163 |
+
await db.conversations.update_one(
|
| 164 |
+
{"_id": oid},
|
| 165 |
+
{"$set": {"messages": [m.dict() for m in messages], "updated_at": datetime.utcnow()}}
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
return StreamingResponse(generate(), media_type="text/event-stream")
|
app/request.py
ADDED
|
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""import requests
|
| 2 |
+
import json
|
| 3 |
+
|
| 4 |
+
# Base URL
|
| 5 |
+
BASE_URL = "http://localhost:8000"
|
| 6 |
+
|
| 7 |
+
# 1. Register a new user
|
| 8 |
+
register_data = {
|
| 9 |
+
"username": "testuser",
|
| 10 |
+
"email": "test@example.com",
|
| 11 |
+
"company": "Test Co",
|
| 12 |
+
"password": "securepassword123"
|
| 13 |
+
}
|
| 14 |
+
response = requests.post(f"{BASE_URL}/auth/register", json=register_data)
|
| 15 |
+
print("Register:", response.json())
|
| 16 |
+
|
| 17 |
+
# 2. Login (get access/refresh tokens)
|
| 18 |
+
login_data = {
|
| 19 |
+
"username": "testuser",
|
| 20 |
+
"password": "securepassword123"
|
| 21 |
+
}
|
| 22 |
+
response = requests.post(f"{BASE_URL}/auth/login", data=login_data)
|
| 23 |
+
tokens = response.json()
|
| 24 |
+
access_token = tokens["access_token"]
|
| 25 |
+
refresh_token = tokens["refresh_token"]
|
| 26 |
+
print("Login:", tokens)
|
| 27 |
+
|
| 28 |
+
# Headers for authenticated requests
|
| 29 |
+
headers = {"Authorization": f"Bearer {access_token}"}
|
| 30 |
+
|
| 31 |
+
# 3. Create a conversation
|
| 32 |
+
response = requests.post(f"{BASE_URL}/rag/conversations", headers=headers)
|
| 33 |
+
conv_id = response.json()["conversation_id"]
|
| 34 |
+
print("Conversation ID:", conv_id)
|
| 35 |
+
|
| 36 |
+
# 4. Send a message (with optional files, web search)
|
| 37 |
+
# Example: Text-only message
|
| 38 |
+
files = [] # Or: [('files', open('doc.pdf', 'rb'))] for uploads
|
| 39 |
+
data = {
|
| 40 |
+
"model": "llama-3.1-8b-instant",
|
| 41 |
+
"enable_web_search": True,
|
| 42 |
+
"message": "What is the capital of France?"
|
| 43 |
+
}
|
| 44 |
+
response = requests.post(
|
| 45 |
+
f"{BASE_URL}/rag/conversations/{conv_id}/messages",
|
| 46 |
+
headers=headers,
|
| 47 |
+
data=data,
|
| 48 |
+
files=files if files else None,
|
| 49 |
+
stream=True
|
| 50 |
+
)
|
| 51 |
+
for chunk in response.iter_content(chunk_size=1024):
|
| 52 |
+
if chunk:
|
| 53 |
+
print(chunk.decode(), end='', flush=True) # Streaming output
|
| 54 |
+
|
| 55 |
+
# 5. Get conversation history
|
| 56 |
+
response = requests.get(f"{BASE_URL}/rag/conversations/{conv_id}", headers=headers)
|
| 57 |
+
print("History:", response.json())
|
| 58 |
+
|
| 59 |
+
# 6. Refresh token
|
| 60 |
+
refresh_data = {"refresh_token": refresh_token}
|
| 61 |
+
response = requests.post(f"{BASE_URL}/auth/refresh", json=refresh_data)
|
| 62 |
+
new_tokens = response.json()
|
| 63 |
+
print("New Tokens:", new_tokens)
|
| 64 |
+
|
| 65 |
+
# 7. Logout
|
| 66 |
+
logout_data = {"refresh_token": refresh_token}
|
| 67 |
+
response = requests.post(f"{BASE_URL}/auth/logout", json=logout_data)
|
| 68 |
+
print("Logout:", response.json())"""
|
| 69 |
+
import requests
|
| 70 |
+
import json
|
| 71 |
+
|
| 72 |
+
# Base URL
|
| 73 |
+
BASE_URL = "http://localhost:8000"
|
| 74 |
+
|
| 75 |
+
# 1. Login (get access/refresh tokens) - Change credentials if needed
|
| 76 |
+
login_data = {
|
| 77 |
+
"username": "testuser", # Update if your username is different
|
| 78 |
+
"password": "securepassword123" # Update with your actual password
|
| 79 |
+
}
|
| 80 |
+
response = requests.post(f"{BASE_URL}/auth/login", data=login_data)
|
| 81 |
+
|
| 82 |
+
if response.status_code != 200:
|
| 83 |
+
print("Login Failed:", response.status_code, response.text)
|
| 84 |
+
else:
|
| 85 |
+
tokens = response.json()
|
| 86 |
+
access_token = tokens["access_token"]
|
| 87 |
+
refresh_token = tokens["refresh_token"]
|
| 88 |
+
print("Login Success:", tokens)
|
| 89 |
+
|
| 90 |
+
# Headers for authenticated requests
|
| 91 |
+
headers = {"Authorization": f"Bearer {access_token}"}
|
| 92 |
+
|
| 93 |
+
# 2. Create a conversation
|
| 94 |
+
response = requests.post(f"{BASE_URL}/rag/conversations", headers=headers)
|
| 95 |
+
if response.status_code == 201:
|
| 96 |
+
conv_id = response.json()["conversation_id"]
|
| 97 |
+
print("Conversation Created - ID:", conv_id)
|
| 98 |
+
else:
|
| 99 |
+
print("Failed to create conversation:", response.status_code, response.text)
|
| 100 |
+
conv_id = None
|
| 101 |
+
|
| 102 |
+
if conv_id:
|
| 103 |
+
# 3. Send a message (text-only example)
|
| 104 |
+
data = {
|
| 105 |
+
"model": "llama-3.1-8b-instant", # Change model if desired (from ALLOWED_MODELS)
|
| 106 |
+
"enable_web_search": "true", # "true" or "false" as string for form data
|
| 107 |
+
"message": "What is the capital of France?"
|
| 108 |
+
}
|
| 109 |
+
# Optional: Add files for document RAG
|
| 110 |
+
# files = [('files', open('your_document.pdf', 'rb'))]
|
| 111 |
+
|
| 112 |
+
response = requests.post(
|
| 113 |
+
f"{BASE_URL}/rag/conversations/{conv_id}/messages",
|
| 114 |
+
headers=headers,
|
| 115 |
+
data=data,
|
| 116 |
+
# files=files if 'files' in locals() else None,
|
| 117 |
+
stream=True
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
print("\n--- Assistant Response ---")
|
| 121 |
+
if response.status_code == 200:
|
| 122 |
+
for chunk in response.iter_content(chunk_size=1024, decode_unicode=True):
|
| 123 |
+
if chunk:
|
| 124 |
+
print(chunk, end='', flush=True)
|
| 125 |
+
print("\n--- End of Response ---")
|
| 126 |
+
else:
|
| 127 |
+
print("Message Send Failed:", response.status_code)
|
| 128 |
+
print("Response:", response.text)
|
| 129 |
+
|
| 130 |
+
# 4. Get conversation history
|
| 131 |
+
response = requests.get(f"{BASE_URL}/rag/conversations/{conv_id}", headers=headers)
|
| 132 |
+
print("\nConversation History:", json.dumps(response.json(), indent=2))
|
| 133 |
+
|
| 134 |
+
# 5. Refresh token (optional)
|
| 135 |
+
refresh_data = {"refresh_token": refresh_token}
|
| 136 |
+
response = requests.post(f"{BASE_URL}/auth/refresh", json=refresh_data)
|
| 137 |
+
print("Token Refresh:", response.json() if response.status_code == 200 else response.text)
|
| 138 |
+
|
| 139 |
+
# 6. Logout (optional)
|
| 140 |
+
logout_data = {"refresh_token": refresh_token}
|
| 141 |
+
response = requests.post(f"{BASE_URL}/auth/logout", json=logout_data)
|
| 142 |
+
print("Logout:", response.json())
|
requirements.txt
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi
|
| 2 |
+
uvicorn==0.24.0
|
| 3 |
+
pydantic==2.7.2
|
| 4 |
+
motor==3.3.2
|
| 5 |
+
passlib[argon2]==1.7.4
|
| 6 |
+
python-jose[cryptography]==3.3.0
|
| 7 |
+
sentence-transformers>=2.2.2,<3
|
| 8 |
+
faiss-cpu==1.7.4
|
| 9 |
+
PyPDF2==3.0.1
|
| 10 |
+
python-docx==1.1.0
|
| 11 |
+
duckduckgo-search==6.2.13
|
| 12 |
+
huggingface_hub>=0.17.0,<1.0
|
| 13 |
+
transformers<5,>=4.41.2
|
| 14 |
+
tokenizers<0.20,>=0.19.1
|
| 15 |
+
groq==0.11.0 # Or latest: pip install groq --upgrade
python-multipart==0.0.9
|