precison9 committed on
Commit
62a1756
·
1 Parent(s): e9d753d

deploy FastAPI backend

Browse files
app/__pycache__/config.cpython-311.pyc ADDED
Binary file (1.66 kB). View file
 
app/__pycache__/logging_config.cpython-311.pyc ADDED
Binary file (2.65 kB). View file
 
app/__pycache__/main.cpython-311.pyc ADDED
Binary file (1.28 kB). View file
 
app/auth/__pycache__/jwt_handler.cpython-311.pyc ADDED
Binary file (4.04 kB). View file
 
app/auth/__pycache__/models.cpython-311.pyc ADDED
Binary file (2.29 kB). View file
 
app/auth/__pycache__/routes.cpython-311.pyc ADDED
Binary file (9.46 kB). View file
 
app/auth/jwt_handler.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime, timedelta, timezone
2
+ from typing import Any, Dict, Optional
3
+ from jose import jwt, JWTError
4
+ from passlib.hash import argon2
5
+ from fastapi import HTTPException, status
6
+ from app.config import settings
7
+
8
# Token-signing configuration, sourced from application settings (env/.env).
SECRET_KEY = settings.secret_key          # signing key for all JWTs
ALGORITHM = settings.algorithm            # JWS algorithm name, e.g. "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = settings.access_token_expire_minutes
REFRESH_TOKEN_EXPIRE_DAYS = settings.refresh_token_expire_days
13
+ def _now() -> datetime:
14
+ return datetime.now(timezone.utc)
15
+
16
def create_access_token(subject: str, role: Optional[str] = None) -> str:
    """Build a signed JWT access token for *subject*.

    Args:
        subject: Value stored in the ``sub`` claim (here: the username).
        role: Optional ``role`` claim, added only when truthy.

    Returns:
        The encoded JWT string.
    """
    # Fix: take one timestamp so ``iat`` and ``exp`` are computed from the
    # same instant (the original called _now() twice).
    issued_at = _now()
    payload: Dict[str, Any] = {
        "sub": subject,
        "exp": issued_at + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES),
        "iat": issued_at,
        "type": "access",
    }
    if role:
        payload["role"] = role
    return jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
22
+
23
def create_refresh_token(subject: str) -> str:
    """Build a signed JWT refresh token for *subject*.

    Fix (consistency with create_access_token): a single timestamp feeds
    both ``iat`` and ``exp`` instead of two separate ``_now()`` calls.
    """
    issued_at = _now()
    payload: Dict[str, Any] = {
        "sub": subject,
        "exp": issued_at + timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS),
        "iat": issued_at,
        "type": "refresh",
    }
    return jwt.encode(payload, SECRET_KEY, algorithm=ALGORITHM)
27
+
28
def decode_token(token: str) -> Dict[str, Any]:
    """Decode *token* and return its claims.

    Raises:
        ValueError: if the token is malformed, expired, or its signature
            does not verify.
    """
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise ValueError("Invalid token or signature")
    return claims
33
+
34
def verify_access_token(token: str) -> str:
    """Validate an access token and return its subject (username).

    Raises:
        HTTPException: 401 when the token is malformed/expired, is not an
            access token, or carries no subject.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )
    try:
        claims = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM])
    except JWTError:
        raise unauthorized
    subject: Optional[str] = claims.get("sub")
    # Refresh tokens must never be accepted where an access token is expected.
    if subject is None or claims.get("type") != "access":
        raise unauthorized
    return subject
49
+
50
def hash_refresh_token(raw_refresh: str) -> str:
    """Hash a raw refresh token with argon2 before persisting it.

    Only the hash is stored server-side; the raw token lives on the client.
    """
    return argon2.hash(raw_refresh)
52
+
53
def verify_refresh_token(raw_refresh: str, hash_value: str) -> bool:
    """Check *raw_refresh* against a stored argon2 hash.

    Returns False on any failure (mismatch or malformed hash) instead of
    raising, so callers can treat verification as a plain boolean.
    """
    try:
        return argon2.verify(raw_refresh, hash_value)
    except Exception:  # deliberately broad: a corrupt hash must not 500
        return False
app/auth/models.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ from typing import Optional
3
+ from pydantic import BaseModel, Field, EmailStr
4
+
5
class UserCreate(BaseModel):
    """Payload accepted by POST /auth/register."""

    username: str = Field(..., min_length=3, max_length=50)
    email: EmailStr = Field(..., description="User email (must be unique)")
    company: str = Field(default="", max_length=128)
    password: str = Field(..., min_length=8, description="User password (will be hashed).")
10
+
11
class UserInDB(BaseModel):
    """Shape of a user document as stored in the ``users`` collection."""

    username: str
    email: str
    company: str
    hashed_password: str  # argon2 hash, never the raw password
    # NOTE(review): datetime.utcnow yields a naive timestamp (and is
    # deprecated in Python 3.12) while sessions elsewhere use aware UTC —
    # confirm which convention the collections should follow.
    created_at: datetime = Field(default_factory=datetime.utcnow)
17
+
18
class UserPublic(BaseModel):
    """User representation returned to clients (no password material)."""

    username: str
    email: str
    company: str
    created_at: datetime
23
+
24
class Token(BaseModel):
    """OAuth2-style token response returned by /auth/login and /auth/refresh."""

    access_token: str
    token_type: str = "bearer"
    refresh_token: str
    expires_in: int  # access-token lifetime in seconds
app/auth/routes.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime, timezone
2
+ from typing import Optional
3
+ from fastapi import APIRouter, Depends, HTTPException, status
4
+ from fastapi.security import OAuth2PasswordRequestForm, OAuth2PasswordBearer
5
+ from motor.motor_asyncio import AsyncIOMotorDatabase
6
+ from passlib.context import CryptContext
7
+ from pydantic import BaseModel
8
+
9
+ from app.database.connection import get_db
10
+ from app.auth.models import UserCreate, UserPublic, Token
11
+ from app.auth.jwt_handler import (
12
+ create_access_token,
13
+ create_refresh_token,
14
+ decode_token,
15
+ hash_refresh_token,
16
+ verify_access_token,
17
+ verify_refresh_token,
18
+ )
19
+ from app.config import settings
20
+
21
router = APIRouter(prefix="/auth", tags=["Authentication"])
# argon2 password-hashing context and bearer-token extraction scheme.
pwd_context = CryptContext(schemes=["argon2"], deprecated="auto")
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")
+
25
async def get_current_user(
    token: str = Depends(oauth2_scheme),
    db: AsyncIOMotorDatabase = Depends(get_db),
) -> UserPublic:
    """FastAPI dependency: resolve the bearer access token to its user.

    Raises HTTP 401 when the token is invalid or the user no longer exists.
    """
    try:
        username = verify_access_token(token)
    except HTTPException as e:
        raise e  # propagate the 401 built by verify_access_token unchanged
    except Exception:
        raise HTTPException(status_code=401, detail="Invalid token")
    user = await db.users.find_one({"username": username})
    if not user:
        raise HTTPException(status_code=401, detail="User not found")
    return UserPublic(**user)
39
+
40
@router.post("/register", response_model=UserPublic, status_code=status.HTTP_201_CREATED)
async def register(user: UserCreate, db: AsyncIOMotorDatabase = Depends(get_db)):
    """Create a new user; username and email are normalized to lowercase.

    Returns 400 when the username or email is already taken.
    """
    username = user.username.strip().lower()
    email = user.email.lower()
    # NOTE(review): check-then-insert races under concurrency — the unique
    # indexes declared in MONGO_INDEXES are the real duplicate guard.
    if await db.users.find_one({"$or": [{"username": username}, {"email": email}]}):
        raise HTTPException(status_code=400, detail="Username or email already exists")
    hashed = pwd_context.hash(user.password)
    doc = {
        "username": username,
        "email": email,
        "company": user.company,
        "hashed_password": hashed,
        "created_at": datetime.utcnow(),
    }
    await db.users.insert_one(doc)
    return UserPublic(**doc)
56
+
57
@router.post("/login", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends(), db=Depends(get_db)):
    """Verify credentials, mint access+refresh tokens, and record a session.

    Only the argon2 hash of the refresh token is persisted, never the token.
    """
    username = form_data.username.strip().lower()
    user = await db.users.find_one({"username": username})
    if not user or not pwd_context.verify(form_data.password, user["hashed_password"]):
        raise HTTPException(status_code=401, detail="Incorrect username or password")
    access_token = create_access_token(username)
    refresh_token = create_refresh_token(username)
    payload = decode_token(refresh_token)  # read back "exp" for the session TTL
    await db.sessions.insert_one(
        {
            "user_id": username,
            "refresh_token_hash": hash_refresh_token(refresh_token),
            "created_at": datetime.now(timezone.utc),
            "expires_at": datetime.fromtimestamp(payload["exp"], tz=timezone.utc),
            "revoked_at": None,
        }
    )
    return Token(
        access_token=access_token,
        refresh_token=refresh_token,
        expires_in=settings.access_token_expire_minutes * 60,
    )
80
+
81
class RefreshIn(BaseModel):
    """Request body carrying the refresh token for /refresh and /logout."""

    refresh_token: str
83
+
84
@router.post("/refresh", response_model=Token)
async def refresh_token(payload: RefreshIn, db=Depends(get_db)):
    """Rotate a refresh token: revoke its session and issue fresh tokens.

    NOTE(review): only the single most recent active session is checked, so
    a valid refresh token from an older session (e.g. a second device) is
    rejected — confirm whether multi-session support is intended.
    """
    try:
        decoded = decode_token(payload.refresh_token)
    except ValueError:
        raise HTTPException(status_code=401, detail="Invalid refresh token")
    if decoded.get("type") != "refresh":
        raise HTTPException(status_code=401, detail="Invalid token type")
    username = decoded.get("sub")
    session_doc = await db.sessions.find_one(
        {
            "user_id": username,
            "revoked_at": None,
            "expires_at": {"$gt": datetime.now(timezone.utc)},
        },
        sort=[("created_at", -1)],
    )
    if not session_doc or not verify_refresh_token(payload.refresh_token, session_doc["refresh_token_hash"]):
        raise HTTPException(status_code=401, detail="Refresh token not recognized")
    new_access = create_access_token(username)
    new_refresh = create_refresh_token(username)
    # Rotation: revoke the old session before persisting the new one.
    await db.sessions.update_one(
        {"_id": session_doc["_id"]}, {"$set": {"revoked_at": datetime.now(timezone.utc)}}
    )
    payload_new = decode_token(new_refresh)
    await db.sessions.insert_one(
        {
            "user_id": username,
            "refresh_token_hash": hash_refresh_token(new_refresh),
            "created_at": datetime.now(timezone.utc),
            "expires_at": datetime.fromtimestamp(payload_new["exp"], tz=timezone.utc),
            "revoked_at": None,
        }
    )
    return Token(access_token=new_access, refresh_token=new_refresh, expires_in=settings.access_token_expire_minutes * 60)
119
+
120
@router.post("/logout")
async def logout(payload: RefreshIn, db=Depends(get_db)):
    """Revoke the session matching the supplied refresh token.

    Always answers {"ok": True}: logout is idempotent and leaks nothing
    about whether the token was valid.

    NOTE(review): like /refresh, only the most recent active session is
    examined — a token from an older session is silently ignored.
    """
    try:
        decoded = decode_token(payload.refresh_token)
    except ValueError:
        return {"ok": True}
    username = decoded.get("sub")
    session_doc = await db.sessions.find_one(
        {
            "user_id": username,
            "revoked_at": None,
            "expires_at": {"$gt": datetime.now(timezone.utc)},
        },
        sort=[("created_at", -1)],
    )
    if not session_doc:
        return {"ok": True}
    if verify_refresh_token(payload.refresh_token, session_doc["refresh_token_hash"]):
        await db.sessions.update_one(
            {"_id": session_doc["_id"]}, {"$set": {"revoked_at": datetime.now(timezone.utc)}}
        )
    return {"ok": True}
142
+
143
@router.get("/profile", response_model=UserPublic)
async def read_users_me(current_user: UserPublic = Depends(get_current_user)):
    """Return the authenticated user's public profile."""
    return current_user
app/config.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    """Application settings resolved from the environment and ``.env``.

    Fix: the original gave every field an ``os.getenv(...)`` default, which
    (a) bypassed pydantic-settings' own env/.env resolution — values present
    only in ``.env`` were never seen — and (b) froze values at import time.
    Plain typed fields let BaseSettings read each value from the process
    environment first, then the .env file, with validation errors for
    missing required settings.
    """

    mongo_uri: str                           # MONGO_URI
    database_name: str                       # DATABASE_NAME
    groq_api_key: str                        # GROQ_API_KEY
    secret_key: str                          # SECRET_KEY (JWT signing)
    algorithm: str                           # ALGORITHM, e.g. "HS256"
    access_token_expire_minutes: int = 30    # ACCESS_TOKEN_EXPIRE_MINUTES
    refresh_token_expire_days: int = 7       # REFRESH_TOKEN_EXPIRE_DAYS

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"


# Single shared settings instance, imported throughout the app.
settings = Settings()
app/database/__pycache__/connection.cpython-311.pyc ADDED
Binary file (625 Bytes). View file
 
app/database/__pycache__/schemas.cpython-311.pyc ADDED
Binary file (5.06 kB). View file
 
app/database/connection.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
from motor.motor_asyncio import AsyncIOMotorClient, AsyncIOMotorDatabase
from app.config import settings

# One module-level client: motor maintains its own connection pool, so a
# single instance is shared by every request.
client = AsyncIOMotorClient(settings.mongo_uri)

async def get_db() -> AsyncIOMotorDatabase:
    """FastAPI dependency returning the application database handle."""
    return client[settings.database_name]
app/database/schemas.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional, Dict, Any
2
+ from datetime import datetime
3
+ from pydantic import BaseModel, Field, EmailStr
4
+
5
# User Schemas (adapted)
class UserCreate(BaseModel):
    """Registration payload (duplicate of app.auth.models.UserCreate)."""

    username: str = Field(..., min_length=3, max_length=50)
    email: EmailStr = Field(..., description="User email (must be unique)")
    company: str = Field(default="", max_length=128)
    password: str = Field(..., min_length=8, description="User password (will be hashed).")
11
+
12
class UserDB(BaseModel):
    """Full user document as stored in Mongo (includes the password hash)."""

    id: Optional[str] = Field(None, alias="_id")
    username: str
    email: str
    password_hash: str = Field(alias="hashed_password")  # Mongo field name
    company: str
    # NOTE(review): naive utcnow defaults (deprecated in 3.12) — sessions
    # elsewhere store timezone-aware datetimes; confirm the convention.
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    is_active: bool = True
    roles: List[str] = Field(default_factory=lambda: ["user"])
22
+
23
class SessionDB(BaseModel):
    """Refresh-token session: one document per issued refresh token."""

    id: Optional[str] = Field(None, alias="_id")
    user_id: str
    refresh_token_hash: str  # argon2 hash; the raw token is never stored
    created_at: datetime = Field(default_factory=datetime.utcnow)
    expires_at: datetime
    revoked_at: Optional[datetime] = None  # set on logout / rotation
    meta: Dict[str, Any] = Field(default_factory=dict)
31
+
32
# Conversation Schemas
class Message(BaseModel):
    """Single chat turn: a role (e.g. "user"/"assistant") plus its text."""

    role: str
    content: str
36
+
37
class ConversationDB(BaseModel):
    """Chat conversation document owned by ``user_id``."""

    id: Optional[str] = Field(None, alias="_id")
    user_id: str
    messages: List[Message] = Field(default_factory=list)
    created_at: datetime = Field(default_factory=datetime.utcnow)
    updated_at: datetime = Field(default_factory=datetime.utcnow)
    meta: Dict[str, Any] = Field(default_factory=dict)  # e.g., {"model": "llama-3.1-8b-instant"}
44
+
45
# Audit Log (optional, for security)
class AuditLogDB(BaseModel):
    """Security audit-trail entry (who did what, from where, when)."""

    id: Optional[str] = Field(None, alias="_id")
    user_id: Optional[str] = None  # None for anonymous / failed actions
    action: str
    ip: Optional[str] = None
    user_agent: Optional[str] = None
    created_at: datetime = Field(default_factory=datetime.utcnow)
    meta: Dict[str, Any] = Field(default_factory=dict)
54
+
55
# MongoDB Indexes for Performance
# Declarative index spec — presumably applied by startup/migration code not
# visible in this file. The unique indexes on users are the real guard
# against duplicate registration (the route-level check alone is racy).
MONGO_INDEXES = {
    "users": [
        {"keys": [("username", 1)], "unique": True},
        {"keys": [("email", 1)], "unique": True},
        {"keys": [("created_at", -1)]},
    ],
    "sessions": [
        {"keys": [("user_id", 1), ("created_at", -1)]},  # latest-session lookup
        {"keys": [("expires_at", 1)]},  # expiry filtering / cleanup
    ],
    "conversations": [
        {"keys": [("user_id", 1), ("created_at", -1)]},
    ],
    "audit_logs": [
        {"keys": [("user_id", 1), ("created_at", -1)]},
        {"keys": [("action", 1), ("created_at", -1)]},
    ],
}
app/logging_config.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import json, sys, os, logging
3
+ from logging.config import dictConfig
4
+ from uvicorn.config import LOG_LEVELS
5
+
6
+ SERVICE_NAME = os.getenv("SERVICE_NAME", "backend")
7
+ ENV = os.getenv("ENV", "production")
8
+
9
class JsonFormatter(logging.Formatter):
    """Render each log record as one JSON object per line (for aggregators)."""

    def format(self, record: logging.LogRecord) -> str:
        payload = {
            "level": record.levelname,
            "logger": record.name,
            "msg": record.getMessage(),
            "time": self.formatTime(record, self.datefmt),
            "service": SERVICE_NAME,
            "env": ENV,
        }
        exc = record.exc_info
        if exc:
            payload["exc_info"] = self.formatException(exc)
        return json.dumps(payload, ensure_ascii=False)
22
+
23
def setup_logging():
    """Configure process-wide logging (JSON in production, plain otherwise).

    Fix: the original upper-cased ``LOG_LEVEL`` and then tested membership
    in uvicorn's ``LOG_LEVELS`` — whose keys are lowercase — so the check
    always failed and every configured level silently fell back to INFO.
    Validate against the stdlib's known level names instead:
    ``logging.getLevelName`` returns an int only for names it recognizes.
    """
    log_level = os.getenv("LOG_LEVEL", "INFO").upper()
    if not isinstance(logging.getLevelName(log_level), int):
        log_level = "INFO"

    dictConfig({
        "version": 1,
        "disable_existing_loggers": False,
        "formatters": {
            "json": {"()": JsonFormatter},
            "plain": {"format": "%(levelname)s:%(name)s:%(message)s"},
        },
        "handlers": {
            "console": {
                "class": "logging.StreamHandler",
                "stream": sys.stdout,
                # Structured JSON only in production; human-readable otherwise.
                "formatter": "json" if ENV == "production" else "plain",
                "level": log_level,
            },
        },
        "loggers": {
            # Take over uvicorn's loggers so everything shares one format.
            "uvicorn": {"handlers": ["console"], "level": log_level, "propagate": False},
            "uvicorn.error": {"handlers": ["console"], "level": log_level, "propagate": False},
            "uvicorn.access": {"handlers": ["console"], "level": log_level, "propagate": False},
            "fastapi": {"handlers": ["console"], "level": log_level, "propagate": False},
            "": {"handlers": ["console"], "level": log_level},  # root logger
        },
    })
app/main.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.auth.routes import router as auth_router
from app.rag.routes import router as rag_router
from app.logging_config import setup_logging

# Configure logging before the app (and uvicorn) start emitting records.
setup_logging()

app = FastAPI(title="GrokRAG API", description="SaaS RAG Chat with Groq", version="1.1.0")

# CORS: only the local dev frontend is allowed; widen per deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

app.include_router(auth_router)  # auth router carries its own /auth prefix
app.include_router(rag_router, prefix="/rag")

@app.get("/")
async def root():
    """Landing/health endpoint."""
    return {"message": "Welcome to GrokRAG API"}
app/rag/__pycache__/models.cpython-311.pyc ADDED
Binary file (2.24 kB). View file
 
app/rag/__pycache__/rag_processor.cpython-311.pyc ADDED
Binary file (11.6 kB). View file
 
app/rag/__pycache__/routes.cpython-311.pyc ADDED
Binary file (8.84 kB). View file
 
app/rag/models.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Optional
2
+ from pydantic import BaseModel, Field
3
+ from bson import ObjectId
4
+ from datetime import datetime
5
+
6
# Groq model ids accepted by the chat endpoint; send_message rejects any
# model not listed here with HTTP 400.
ALLOWED_MODELS = [
    "allam-2-7b",
    "groq/compound",
    "groq/compound-mini",
    "llama-3.1-8b-instant",
    "llama-3.3-70b-versatile",
    "llama-3.1-70b-versatile",
    "meta-llama/llama-4-maverick-17b-128e-instruct",
    "meta-llama/llama-4-scout-17b-16e-instruct",
    "meta-llama/llama-guard-4-12b",
    "meta-llama/llama-prompt-guard-2-22m",
    "meta-llama/llama-prompt-guard-2-86m",
    "moonshotai/kimi-k2-instruct",
    "moonshotai/kimi-k2-instruct-0905",
    "openai/gpt-oss-120b",
    "openai/gpt-oss-20b",
    "openai/gpt-oss-safeguard-20b",
    "qwen/qwen3-32b",
]
25
+
26
class Message(BaseModel):
    """One chat turn: a role plus its text content."""

    role: str
    content: str
29
+
30
class Conversation(BaseModel):
    """In-memory conversation model; Mongo _id carried as a string."""

    id: Optional[str] = None  # str(ObjectId)
    user_id: str
    messages: List[Message] = []  # safe: pydantic deep-copies field defaults
    # NOTE(review): naive utcnow default, deprecated in 3.12 — confirm.
    created_at: datetime = Field(default_factory=datetime.utcnow)
35
+
36
class ChatRequest(BaseModel):
    """Chat request body (routes read form fields instead — verify usage).

    NOTE(review): the extra ``enum=`` kwarg is NOT validated by pydantic v2;
    it only lands in the JSON schema. Actual model enforcement happens in
    rag/routes.py via the explicit ALLOWED_MODELS check.
    """

    model: str = Field(..., description="Groq model", enum=ALLOWED_MODELS)
    enable_web_search: bool = False
    message: str = Field(..., min_length=1)
app/rag/rag_processor.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import logging
4
+ from typing import List, Tuple, Optional
5
+ import faiss
6
+ from sentence_transformers import SentenceTransformer
7
+ from PyPDF2 import PdfReader
8
+ from docx import Document
9
+ import pytesseract
10
+ from PIL import Image
11
+ import io
12
+ import openpyxl
13
+ import pandas as pd
14
+ from duckduckgo_search import DDGS
15
+ from fastapi import UploadFile
16
+
17
logger = logging.getLogger(__name__)

# Sentence-transformers model used for chunk/query embeddings; loaded lazily
# and cached process-wide because construction may download weights.
_EMBED_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
_embedder: Optional[SentenceTransformer] = None

def _get_embedder() -> SentenceTransformer:
    """Return the shared SentenceTransformer, loading it on first use."""
    global _embedder
    if _embedder is None:
        logger.info(f"Loading embedding model: {_EMBED_MODEL_NAME}")
        _embedder = SentenceTransformer(_EMBED_MODEL_NAME)
    return _embedder
28
+
29
+ # Enhanced File Extraction
30
+ def extract_text(file: UploadFile) -> str:
31
+ ext = os.path.splitext(file.filename)[1].lower()
32
+ content = file.file.read()
33
+ file_bytes = io.BytesIO(content)
34
+ if ext == ".pdf":
35
+ try:
36
+ reader = PdfReader(file_bytes)
37
+ return "\n".join(page.extract_text() or "" for page in reader.pages)
38
+ except Exception as e:
39
+ logger.error(f"PDF extract failed: {e}")
40
+ return ""
41
+ elif ext == ".docx":
42
+ try:
43
+ doc = Document(file_bytes)
44
+ return "\n".join(p.text for p in doc.paragraphs if p.text)
45
+ except Exception as e:
46
+ logger.error(f"DOCX extract failed: {e}")
47
+ return ""
48
+ elif ext in [".xlsx", ".xls"]:
49
+ try:
50
+ wb = openpyxl.load_workbook(file_bytes, read_only=True, data_only=True)
51
+ text = []
52
+ for sheet in wb:
53
+ for row in sheet.iter_rows(values_only=True):
54
+ text.append(" ".join(str(cell) for cell in row if cell is not None))
55
+ return "\n".join(text)
56
+ except Exception as e:
57
+ logger.error(f"Excel extract failed: {e}")
58
+ return ""
59
+ elif ext == ".csv":
60
+ try:
61
+ df = pd.read_csv(file_bytes)
62
+ return df.to_string()
63
+ except Exception as e:
64
+ logger.error(f"CSV extract failed: {e}")
65
+ return ""
66
+ elif ext in [".jpg", ".jpeg", ".png", ".gif"]: # OCR for images
67
+ try:
68
+ img = Image.open(file_bytes)
69
+ return pytesseract.image_to_string(img)
70
+ except Exception as e:
71
+ logger.error(f"Image OCR failed: {e}")
72
+ return ""
73
+ else: # Fallback text
74
+ try:
75
+ return content.decode("utf-8", errors="ignore")
76
+ except Exception as e:
77
+ logger.error(f"Text extract failed: {e}")
78
+ return ""
79
+
80
def clean_text(text: str) -> str:
    """Collapse runs of spaces/tabs, squeeze 3+ newlines to 2, and strip."""
    t = re.sub(r"[ \t]+", " ", text)
    t = re.sub(r"\n{3,}", "\n\n", t)
    return t.strip()

def chunk_text(text: str, max_tokens: int = 400, overlap: int = 50) -> List[str]:
    """Split *text* into word-window chunks of at most *max_tokens* words,
    with consecutive chunks sharing roughly *overlap* words.

    Fix: when ``overlap >= max_tokens`` the original never advanced
    (``start = end - overlap`` moved backwards or stalled) and looped
    forever; the overlap is now clamped into ``[0, max_tokens - 1]`` so
    every iteration makes forward progress. ``max_tokens <= 0`` (another
    hang in the original) now raises ValueError.

    Raises:
        ValueError: if ``max_tokens`` is not positive.
    """
    text = clean_text(text)
    if not text:
        return []
    if max_tokens <= 0:
        raise ValueError("max_tokens must be positive")
    overlap = max(0, min(overlap, max_tokens - 1))  # guarantee forward progress
    words = text.split()
    chunks, start = [], 0
    while start < len(words):
        end = min(len(words), start + max_tokens)
        chunk = " ".join(words[start:end]).strip()
        if chunk:
            chunks.append(chunk)
        if end == len(words):
            break
        start = max(0, end - overlap)
    return chunks
100
+
101
class RagIndex:
    """Bundle of a FAISS inner-product index and the chunk texts behind it.

    build_faiss_index stores L2-normalized vectors, so inner-product scores
    equal cosine similarity; ``chunks[i]`` corresponds to vector id ``i``.
    """

    def __init__(self, index: faiss.IndexFlatIP, dim: int, chunks: List[str]):
        self.index = index    # FAISS index over chunk embeddings
        self.dim = dim        # embedding dimensionality
        self.chunks = chunks  # source text, aligned with vector ids
106
+
107
def build_faiss_index(chunks: List[str]) -> RagIndex:
    """Embed *chunks* and build an in-memory FAISS index over them.

    Embeddings are L2-normalized so the inner-product index yields cosine
    similarity scores.
    """
    emb = _get_embedder()
    vectors = emb.encode(chunks, convert_to_numpy=True, normalize_embeddings=True)
    dim = vectors.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(vectors)
    return RagIndex(index=index, dim=dim, chunks=chunks)
114
+
115
def search(index: RagIndex, query: str, top_k: int = 6) -> List[Tuple[str, float]]:
    """Return up to *top_k* (chunk, similarity-score) pairs for *query*."""
    emb = _get_embedder()
    q = emb.encode([query], convert_to_numpy=True, normalize_embeddings=True)
    D, I = index.index.search(q, top_k)
    hits = []
    for score, idx in zip(D[0], I[0]):
        if idx == -1:  # FAISS pads with -1 when fewer than top_k vectors exist
            continue
        hits.append((index.chunks[idx], float(score)))
    return hits
125
+
126
def build_context_from_files(files: List[UploadFile], prompt: str, top_k: int = 6) -> str:
    """Distill uploaded files into a prompt-ready RAG context block.

    Extracts text from every file, chunks the combined text, builds a
    throwaway FAISS index, and returns the *top_k* chunks most relevant to
    *prompt* as labelled sections. Returns "" when nothing usable was
    extracted.
    """
    all_text = []
    for file in files:
        txt = extract_text(file)
        if txt:
            all_text.append(txt)
        file.file.seek(0)  # Reset so the upload remains readable downstream
    big_text = "\n\n".join(all_text)
    chunks = chunk_text(big_text, max_tokens=450, overlap=80)
    if not chunks:
        return ""
    idx = build_faiss_index(chunks)
    hits = search(idx, prompt, top_k=top_k)
    context_sections = [f"[DOC#{i} score={score:.3f}]\n{chunk}" for i, (chunk, score) in enumerate(hits, 1)]
    return "\n\n".join(context_sections)
141
+
142
+ # Web search tool
143
+ def web_search(query: str) -> str:
144
+ try:
145
+ with DDGS() as ddgs:
146
+ results = [r for r in ddgs.text(query, max_results=5)]
147
+ sections = [f"[WEB#{i}] Title: {r['title']}\nSnippet: {r['body']}\nURL: {r['href']}" for i, r in enumerate(results, 1)]
148
+ return "\n\n".join(sections) if sections else "No results found."
149
+ except Exception as e:
150
+ logger.error(f"Web search failed: {e}")
151
+ return "Web search error."
app/rag/routes.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List
2
+ from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, Form
3
+ from fastapi.responses import StreamingResponse
4
+ from motor.motor_asyncio import AsyncIOMotorDatabase
5
+ from bson import ObjectId
6
+ from groq import Groq
7
+ import json
8
+ import logging
9
+ from datetime import datetime
10
+
11
+ from app.database.connection import get_db
12
+ from app.database.schemas import ConversationDB
13
+ from app.auth.routes import get_current_user
14
+ from app.auth.models import UserPublic
15
+ from app.rag.models import ALLOWED_MODELS, Message
16
+ from app.rag.rag_processor import build_context_from_files, web_search
17
+ from app.config import settings
18
+
19
router = APIRouter(tags=["RAG Chat"])
logger = logging.getLogger(__name__)

# System instructions prepended to every request sent to the model.
SYSTEM_PROMPT = """You are a helpful assistant. Use the provided context if relevant. If web search is enabled and you need up-to-date information, use the web_search tool. Reason step-by-step before deciding to use tools."""

# OpenAI-style tool schema advertising the web_search function to the model.
WEB_SEARCH_TOOL = {
    "type": "function",
    "function": {
        "name": "web_search",
        "description": "Search the web using DuckDuckGo for up-to-date information.",
        "parameters": {
            "type": "object",
            "properties": {"query": {"type": "string", "description": "The search query"}},
            "required": ["query"],
        },
    },
}
36
+
37
@router.post("/conversations", status_code=status.HTTP_201_CREATED)
async def create_conversation(
    current_user: UserPublic = Depends(get_current_user),
    db: AsyncIOMotorDatabase = Depends(get_db),
):
    """Create an empty conversation owned by the current user.

    Fix: ``.dict()`` is the deprecated pydantic-v1 API; with the project's
    pinned pydantic 2.x the supported equivalent is ``model_dump()``.
    """
    conv = ConversationDB(user_id=current_user.username)
    # ``id`` is excluded so Mongo assigns the _id itself.
    result = await db.conversations.insert_one(conv.model_dump(exclude={"id"}))
    conv.id = str(result.inserted_id)
    return {"conversation_id": conv.id}
46
+
47
@router.get("/conversations/{conv_id}")
async def get_conversation(
    conv_id: str,
    current_user: UserPublic = Depends(get_current_user),
    db: AsyncIOMotorDatabase = Depends(get_db),
):
    """Fetch one conversation, scoped to the requesting user.

    Fix: the bare ``except:`` around ObjectId parsing also swallowed
    SystemExit/KeyboardInterrupt; narrowed to ``except Exception``.
    """
    try:
        oid = ObjectId(conv_id)
    except Exception:  # bson raises InvalidId/TypeError for malformed ids
        raise HTTPException(status_code=400, detail="Invalid conversation ID")
    conv = await db.conversations.find_one({"_id": oid, "user_id": current_user.username})
    if not conv:
        raise HTTPException(status_code=404, detail="Conversation not found")
    # ObjectId is not JSON-serializable: expose it as a string "id" field.
    conv["id"] = str(conv["_id"])
    del conv["_id"]
    return conv
63
+
64
@router.post("/conversations/{conv_id}/messages")
async def send_message(
    conv_id: str,
    model: str = Form(...),
    enable_web_search: bool = Form(False),
    message: str = Form(...),
    files: List[UploadFile] = None,
    current_user: UserPublic = Depends(get_current_user),
    db: AsyncIOMotorDatabase = Depends(get_db),
):
    """Append a user message to a conversation and stream the model's reply.

    Optional uploads are distilled into a RAG context block; when
    ``enable_web_search`` is set the model may call the web_search tool
    (up to 3 rounds) before the final streamed answer is produced.

    Fixes vs. the original:
    - the assistant message carrying ``tool_calls`` is now appended to the
      chat history *before* the ``role="tool"`` results — the chat-completions
      tool protocol requires tool messages to follow that assistant turn,
      otherwise the follow-up request is rejected by the API;
    - bare ``except:`` narrowed to ``except Exception``;
    - deprecated pydantic-v1 ``.dict()`` replaced with ``model_dump()``.
    """
    if model not in ALLOWED_MODELS:
        raise HTTPException(status_code=400, detail="Invalid model")
    try:
        oid = ObjectId(conv_id)
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid conversation ID")
    conv = await db.conversations.find_one({"_id": oid, "user_id": current_user.username})
    if not conv:
        raise HTTPException(status_code=404, detail="Conversation not found")

    # Prior turns, revalidated through the Message model.
    messages = [Message(**m) for m in conv.get("messages", [])]

    # Build RAG context from any uploaded documents.
    rag_context = ""
    if files:
        rag_context = build_context_from_files(files, message)

    # System prompt, optionally extended with the document context.
    system_msg = {"role": "system", "content": SYSTEM_PROMPT + (f"\n\nContext: {rag_context}" if rag_context else "")}

    # Append the new user turn.
    user_msg = Message(role="user", content=message)
    messages.append(user_msg)

    client = Groq(api_key=settings.groq_api_key)

    tools = [WEB_SEARCH_TOOL] if enable_web_search else None

    # Tool loop: let the model request web searches, at most 3 rounds.
    chat_history = [system_msg] + [m.model_dump() for m in messages]
    max_tool_loops = 3
    for _ in range(max_tool_loops):
        completion = client.chat.completions.create(
            model=model,
            messages=chat_history,
            temperature=1,
            max_tokens=8192,
            top_p=1,
            stream=False,
            stop=None,
            tools=tools,
        )
        choice = completion.choices[0].message
        if not choice.tool_calls:
            break  # model answered without tools; go stream the final reply
        # Protocol fix: the assistant turn that requested the tools must be
        # in the history before the corresponding tool results.
        chat_history.append(
            {
                "role": "assistant",
                "content": choice.content or "",
                "tool_calls": [
                    {
                        "id": tc.id,
                        "type": "function",
                        "function": {"name": tc.function.name, "arguments": tc.function.arguments},
                    }
                    for tc in choice.tool_calls
                ],
            }
        )
        for tool_call in choice.tool_calls:
            if tool_call.function.name == "web_search":
                args = json.loads(tool_call.function.arguments)
                result = web_search(args["query"])
                chat_history.append(
                    {
                        "role": "tool",
                        "tool_call_id": tool_call.id,
                        "name": "web_search",
                        "content": result,
                    }
                )
    else:
        # for/else: loop exhausted without a tool-free answer.
        logger.warning("Max tool loops reached")
        raise HTTPException(status_code=500, detail="Too many tool calls")

    # Final call, streamed to the client token by token.
    completion = client.chat.completions.create(
        model=model,
        messages=chat_history,
        temperature=1,
        max_tokens=8192,
        top_p=1,
        stream=True,
        stop=None,
    )

    async def generate():
        # NOTE(review): the Groq stream is iterated synchronously inside an
        # async generator, which blocks the event loop between chunks.
        response_content = ""
        for chunk in completion:
            content = chunk.choices[0].delta.content or ""
            response_content += content
            yield content
        # Persist the full exchange once streaming completes.
        messages.append(Message(role="assistant", content=response_content))
        await db.conversations.update_one(
            {"_id": oid},
            {"$set": {"messages": [m.model_dump() for m in messages], "updated_at": datetime.utcnow()}}
        )

    return StreamingResponse(generate(), media_type="text/event-stream")
app/request.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """import requests
2
+ import json
3
+
4
+ # Base URL
5
+ BASE_URL = "http://localhost:8000"
6
+
7
+ # 1. Register a new user
8
+ register_data = {
9
+ "username": "testuser",
10
+ "email": "test@example.com",
11
+ "company": "Test Co",
12
+ "password": "securepassword123"
13
+ }
14
+ response = requests.post(f"{BASE_URL}/auth/register", json=register_data)
15
+ print("Register:", response.json())
16
+
17
+ # 2. Login (get access/refresh tokens)
18
+ login_data = {
19
+ "username": "testuser",
20
+ "password": "securepassword123"
21
+ }
22
+ response = requests.post(f"{BASE_URL}/auth/login", data=login_data)
23
+ tokens = response.json()
24
+ access_token = tokens["access_token"]
25
+ refresh_token = tokens["refresh_token"]
26
+ print("Login:", tokens)
27
+
28
+ # Headers for authenticated requests
29
+ headers = {"Authorization": f"Bearer {access_token}"}
30
+
31
+ # 3. Create a conversation
32
+ response = requests.post(f"{BASE_URL}/rag/conversations", headers=headers)
33
+ conv_id = response.json()["conversation_id"]
34
+ print("Conversation ID:", conv_id)
35
+
36
+ # 4. Send a message (with optional files, web search)
37
+ # Example: Text-only message
38
+ files = [] # Or: [('files', open('doc.pdf', 'rb'))] for uploads
39
+ data = {
40
+ "model": "llama-3.1-8b-instant",
41
+ "enable_web_search": True,
42
+ "message": "What is the capital of France?"
43
+ }
44
+ response = requests.post(
45
+ f"{BASE_URL}/rag/conversations/{conv_id}/messages",
46
+ headers=headers,
47
+ data=data,
48
+ files=files if files else None,
49
+ stream=True
50
+ )
51
+ for chunk in response.iter_content(chunk_size=1024):
52
+ if chunk:
53
+ print(chunk.decode(), end='', flush=True) # Streaming output
54
+
55
+ # 5. Get conversation history
56
+ response = requests.get(f"{BASE_URL}/rag/conversations/{conv_id}", headers=headers)
57
+ print("History:", response.json())
58
+
59
+ # 6. Refresh token
60
+ refresh_data = {"refresh_token": refresh_token}
61
+ response = requests.post(f"{BASE_URL}/auth/refresh", json=refresh_data)
62
+ new_tokens = response.json()
63
+ print("New Tokens:", new_tokens)
64
+
65
+ # 7. Logout
66
+ logout_data = {"refresh_token": refresh_token}
67
+ response = requests.post(f"{BASE_URL}/auth/logout", json=logout_data)
68
+ print("Logout:", response.json())"""
69
+ import requests
70
+ import json
71
+
72
+ # Base URL
73
+ BASE_URL = "http://localhost:8000"
74
+
75
+ # 1. Login (get access/refresh tokens) - Change credentials if needed
76
+ login_data = {
77
+ "username": "testuser", # Update if your username is different
78
+ "password": "securepassword123" # Update with your actual password
79
+ }
80
+ response = requests.post(f"{BASE_URL}/auth/login", data=login_data)
81
+
82
+ if response.status_code != 200:
83
+ print("Login Failed:", response.status_code, response.text)
84
+ else:
85
+ tokens = response.json()
86
+ access_token = tokens["access_token"]
87
+ refresh_token = tokens["refresh_token"]
88
+ print("Login Success:", tokens)
89
+
90
+ # Headers for authenticated requests
91
+ headers = {"Authorization": f"Bearer {access_token}"}
92
+
93
+ # 2. Create a conversation
94
+ response = requests.post(f"{BASE_URL}/rag/conversations", headers=headers)
95
+ if response.status_code == 201:
96
+ conv_id = response.json()["conversation_id"]
97
+ print("Conversation Created - ID:", conv_id)
98
+ else:
99
+ print("Failed to create conversation:", response.status_code, response.text)
100
+ conv_id = None
101
+
102
+ if conv_id:
103
+ # 3. Send a message (text-only example)
104
+ data = {
105
+ "model": "llama-3.1-8b-instant", # Change model if desired (from ALLOWED_MODELS)
106
+ "enable_web_search": "true", # "true" or "false" as string for form data
107
+ "message": "What is the capital of France?"
108
+ }
109
+ # Optional: Add files for document RAG
110
+ # files = [('files', open('your_document.pdf', 'rb'))]
111
+
112
+ response = requests.post(
113
+ f"{BASE_URL}/rag/conversations/{conv_id}/messages",
114
+ headers=headers,
115
+ data=data,
116
+ # files=files if 'files' in locals() else None,
117
+ stream=True
118
+ )
119
+
120
+ print("\n--- Assistant Response ---")
121
+ if response.status_code == 200:
122
+ for chunk in response.iter_content(chunk_size=1024, decode_unicode=True):
123
+ if chunk:
124
+ print(chunk, end='', flush=True)
125
+ print("\n--- End of Response ---")
126
+ else:
127
+ print("Message Send Failed:", response.status_code)
128
+ print("Response:", response.text)
129
+
130
+ # 4. Get conversation history
131
+ response = requests.get(f"{BASE_URL}/rag/conversations/{conv_id}", headers=headers)
132
+ print("\nConversation History:", json.dumps(response.json(), indent=2))
133
+
134
+ # 5. Refresh token (optional)
135
+ refresh_data = {"refresh_token": refresh_token}
136
+ response = requests.post(f"{BASE_URL}/auth/refresh", json=refresh_data)
137
+ print("Token Refresh:", response.json() if response.status_code == 200 else response.text)
138
+
139
+ # 6. Logout (optional)
140
+ logout_data = {"refresh_token": refresh_token}
141
+ response = requests.post(f"{BASE_URL}/auth/logout", json=logout_data)
142
+ print("Logout:", response.json())
requirements.txt ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fastapi
2
+ uvicorn==0.24.0
3
+ pydantic==2.7.2
4
+ motor==3.3.2
5
+ passlib[argon2]==1.7.4
6
+ python-jose[cryptography]==3.3.0
7
+ sentence-transformers>=2.2.2,<3
8
+ faiss-cpu==1.7.4
9
+ PyPDF2==3.0.1
10
+ python-docx==1.1.0
11
+ duckduckgo-search==6.2.13
12
+ huggingface_hub>=0.17.0,<1.0
13
+ transformers<5,>=4.41.2
14
+ tokenizers<0.20,>=0.19.1
15
+ groq==0.11.0 # Or latest: pip install groq --upgrade
16
+ python-multipart==0.0.9
17
+ # (removed duplicate faiss-cpu pin; already pinned to 1.7.4 above)