Spaces:
Sleeping
Sleeping
Upload 20 files
Browse files- .gitignore +12 -0
- Dockerfile +21 -0
- README.md +40 -10
- app/__init__.py +0 -0
- app/config.py +35 -0
- app/database.py +40 -0
- app/db_models.py +41 -0
- app/main.py +59 -0
- app/models.py +78 -0
- app/routes/__init__.py +0 -0
- app/routes/ask.py +39 -0
- app/routes/conversations.py +37 -0
- app/routes/health.py +127 -0
- app/routes/ingest.py +99 -0
- app/services/__init__.py +0 -0
- app/services/conversation_store.py +203 -0
- app/services/document_processor.py +92 -0
- app/services/rag_chain.py +204 -0
- app/services/vector_store.py +93 -0
- requirements.txt +21 -0
.gitignore
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
*.py[cod]
|
| 3 |
+
*$py.class
|
| 4 |
+
.env
|
| 5 |
+
venv/
|
| 6 |
+
uploads/
|
| 7 |
+
*.egg-info/
|
| 8 |
+
dist/
|
| 9 |
+
build/
|
| 10 |
+
.pytest_cache/
|
| 11 |
+
*.db
|
| 12 |
+
*.db-journal
|
Dockerfile
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.10-slim
|
| 2 |
+
|
| 3 |
+
WORKDIR /app
|
| 4 |
+
|
| 5 |
+
# Install dependencies
|
| 6 |
+
COPY requirements.txt .
|
| 7 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 8 |
+
|
| 9 |
+
# Copy application code
|
| 10 |
+
COPY . .
|
| 11 |
+
|
| 12 |
+
# Create necessary directories
|
| 13 |
+
RUN mkdir -p uploads
|
| 14 |
+
|
| 15 |
+
# Note: For production, use volumes to persist:
|
| 16 |
+
# - /app/uploads (uploaded documents)
|
| 17 |
+
# - /app/conversations.db (SQLite database)
|
| 18 |
+
|
| 19 |
+
EXPOSE 8000
|
| 20 |
+
|
| 21 |
+
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000"]
|
README.md
CHANGED
|
@@ -1,10 +1,40 @@
|
|
| 1 |
-
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Backend - AI RAG Chatbot API
|
| 2 |
+
|
| 3 |
+
## Setup
|
| 4 |
+
|
| 5 |
+
```bash
|
| 6 |
+
cd backend
|
| 7 |
+
python -m venv venv
|
| 8 |
+
venv\Scripts\activate # Windows
|
| 9 |
+
# source venv/bin/activate # Mac/Linux
|
| 10 |
+
|
| 11 |
+
pip install -r requirements.txt
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
## Configuration
|
| 15 |
+
|
| 16 |
+
Copy `.env.example` to `.env` and fill in your API keys:
|
| 17 |
+
|
| 18 |
+
```bash
|
| 19 |
+
cp .env.example .env
|
| 20 |
+
```
|
| 21 |
+
|
| 22 |
+
## Run
|
| 23 |
+
|
| 24 |
+
```bash
|
| 25 |
+
uvicorn app.main:app --reload --port 8000
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
## API Endpoints
|
| 29 |
+
|
| 30 |
+
| Method | Endpoint | Description |
|
| 31 |
+
| ------ | ------------------------------- | ---------------------------- |
|
| 32 |
+
| GET | `/health` | Health check |
|
| 33 |
+
| POST | `/api/ask` | Ask question (non-streaming) |
|
| 34 |
+
| POST | `/api/ask/stream` | Ask question (SSE streaming) |
|
| 35 |
+
| POST | `/api/ingest` | Upload & ingest document |
|
| 36 |
+
| POST | `/api/ingest/batch` | Batch upload documents |
|
| 37 |
+
| GET | `/api/conversations` | List conversations |
|
| 38 |
+
| GET | `/api/conversations/{id}` | Get conversation |
|
| 39 |
+
| DELETE | `/api/conversations/{id}` | Delete conversation |
|
| 40 |
+
| DELETE | `/api/conversations` | Clear all conversations |
|
app/__init__.py
ADDED
|
File without changes
|
app/config.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic_settings import BaseSettings
from typing import List
import os


class Settings(BaseSettings):
    """Application configuration, populated from the environment / .env file."""

    # Google Gemini
    GOOGLE_API_KEY: str = ""

    # Pinecone
    PINECONE_API_KEY: str = ""
    PINECONE_INDEX_NAME: str = "health-tech-kb"

    # App settings
    APP_ENV: str = "development"
    CORS_ORIGINS: str = "http://localhost:3000"
    UPLOAD_DIR: str = "./uploads"
    CHUNK_SIZE: int = 1000
    CHUNK_OVERLAP: int = 200
    EMBEDDING_MODEL: str = "sentence-transformers/all-MiniLM-L6-v2"
    LLM_MODEL: str = "gemini-2.5-flash"

    @property
    def cors_origins_list(self) -> List[str]:
        """CORS_ORIGINS is a comma-separated string; expose it as a clean list."""
        raw_entries = self.CORS_ORIGINS.split(",")
        return [entry.strip() for entry in raw_entries]

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"


settings = Settings()

# Ensure upload directory exists
os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
|
app/database.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Database configuration and session management."""
|
| 2 |
+
|
| 3 |
+
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine, async_sessionmaker
|
| 4 |
+
from sqlalchemy.orm import declarative_base
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
# Database URL - SQLite for simplicity
|
| 8 |
+
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite+aiosqlite:///./conversations.db")
|
| 9 |
+
|
| 10 |
+
# Create async engine
|
| 11 |
+
engine = create_async_engine(
|
| 12 |
+
DATABASE_URL,
|
| 13 |
+
echo=False, # Set to True for SQL debugging
|
| 14 |
+
future=True,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
# Session factory
|
| 18 |
+
async_session_maker = async_sessionmaker(
|
| 19 |
+
engine,
|
| 20 |
+
class_=AsyncSession,
|
| 21 |
+
expire_on_commit=False,
|
| 22 |
+
)
|
| 23 |
+
|
| 24 |
+
# Base class for models
|
| 25 |
+
Base = declarative_base()
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
async def get_db() -> AsyncSession:
|
| 29 |
+
"""Get database session."""
|
| 30 |
+
async with async_session_maker() as session:
|
| 31 |
+
try:
|
| 32 |
+
yield session
|
| 33 |
+
finally:
|
| 34 |
+
await session.close()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
async def init_db():
|
| 38 |
+
"""Initialize database tables."""
|
| 39 |
+
async with engine.begin() as conn:
|
| 40 |
+
await conn.run_sync(Base.metadata.create_all)
|
app/db_models.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SQLAlchemy database models for conversations."""
|
| 2 |
+
|
| 3 |
+
from sqlalchemy import Column, String, Integer, Text, ForeignKey, JSON
|
| 4 |
+
from sqlalchemy.orm import relationship
|
| 5 |
+
from app.database import Base
|
| 6 |
+
import uuid
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class DBConversation(Base):
|
| 10 |
+
"""Database model for conversations."""
|
| 11 |
+
|
| 12 |
+
__tablename__ = "conversations"
|
| 13 |
+
|
| 14 |
+
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
|
| 15 |
+
title = Column(String, default="New Chat")
|
| 16 |
+
created_at = Column(String, nullable=False)
|
| 17 |
+
updated_at = Column(String, nullable=False)
|
| 18 |
+
|
| 19 |
+
# Relationship to messages
|
| 20 |
+
messages = relationship(
|
| 21 |
+
"DBMessage",
|
| 22 |
+
back_populates="conversation",
|
| 23 |
+
cascade="all, delete-orphan",
|
| 24 |
+
order_by="DBMessage.id"
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class DBMessage(Base):
|
| 29 |
+
"""Database model for conversation messages."""
|
| 30 |
+
|
| 31 |
+
__tablename__ = "messages"
|
| 32 |
+
|
| 33 |
+
id = Column(Integer, primary_key=True, autoincrement=True)
|
| 34 |
+
conversation_id = Column(String, ForeignKey("conversations.id"), nullable=False)
|
| 35 |
+
role = Column(String, nullable=False) # 'user' or 'assistant'
|
| 36 |
+
content = Column(Text, nullable=False)
|
| 37 |
+
timestamp = Column(String, nullable=False)
|
| 38 |
+
sources = Column(JSON, default=list) # Stored as JSON array
|
| 39 |
+
|
| 40 |
+
# Relationship to conversation
|
| 41 |
+
conversation = relationship("DBConversation", back_populates="messages")
|
app/main.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""FastAPI application entry point: app construction, CORS, route wiring,
and startup initialization of the database and vector store."""

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.config import settings
from app.routes import ask, ingest, conversations, health

app = FastAPI(
    title="Health-Tech AI RAG Chatbot",
    description="AI-powered chatbot with Retrieval-Augmented Generation",
    version="1.0.0",
)

# CORS middleware (origins come from the CORS_ORIGINS env setting)
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Register routes. health.router declares full "/api/..." paths itself,
# so it is mounted without a prefix; the others are mounted under "/api".
app.include_router(health.router, tags=["Health"])
app.include_router(ask.router, prefix="/api", tags=["Chat"])
app.include_router(ingest.router, prefix="/api", tags=["Ingestion"])
app.include_router(conversations.router, prefix="/api", tags=["Conversations"])


# NOTE(review): @app.on_event is deprecated in current FastAPI in favor of
# lifespan handlers — consider migrating when convenient.
@app.on_event("startup")
async def startup_event():
    """Initialize services on startup."""
    # Imported here (not at module top) to avoid import-time side effects.
    from app.services.vector_store import vector_store_service
    from app.database import init_db
    import os

    # Ensure uploads directory exists
    os.makedirs(settings.UPLOAD_DIR, exist_ok=True)
    upload_path = os.path.abspath(settings.UPLOAD_DIR)
    print(f"📁 Upload directory: {upload_path}")

    # Initialize database
    await init_db()
    # NOTE(review): path assumes the default sqlite DATABASE_URL
    # ("sqlite+aiosqlite:///./conversations.db"); the log is wrong if
    # DATABASE_URL is overridden.
    db_path = os.path.abspath("conversations.db")
    print(f"✅ Database initialized: {db_path}")

    # Initialize vector store
    await vector_store_service.initialize()
    print("✅ Vector store initialized")

    # Count existing documents
    if os.path.exists(settings.UPLOAD_DIR):
        doc_count = len([f for f in os.listdir(settings.UPLOAD_DIR) if os.path.isfile(os.path.join(settings.UPLOAD_DIR, f))])
        print(f"📄 Existing documents: {doc_count}")

    # NOTE(review): "port 8000" is hardcoded in this message; the actual
    # port is whatever uvicorn was launched with.
    print(f"🚀 Backend ready on port 8000 (env: {settings.APP_ENV})")


if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app.main:app", host="0.0.0.0", port=8000, reload=True)
|
app/models.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Pydantic request/response schemas shared by the API routes."""

from pydantic import BaseModel, Field
from typing import Optional, List
from datetime import datetime, timezone
import uuid


def _utc_now_iso() -> str:
    """Current UTC time as ``YYYY-MM-DDTHH:MM:SS.ffffffZ``.

    Fix: replaces five duplicated ``datetime.utcnow()`` lambdas;
    ``datetime.utcnow`` is deprecated, so this uses the timezone-aware
    ``datetime.now`` and strips the offset, keeping the exact same
    string format (naive ISO-8601 plus a literal "Z").
    """
    return datetime.now(timezone.utc).replace(tzinfo=None).isoformat() + "Z"


class MessageRequest(BaseModel):
    """Request body for /ask and /ask/stream."""

    question: str = Field(..., min_length=1, max_length=2000, description="User question")
    conversation_id: Optional[str] = Field(None, description="Conversation ID for context")


class SourceDocument(BaseModel):
    """A retrieved chunk cited as a source for an answer."""

    content: str
    source: str
    page: Optional[int] = None
    chunk_id: Optional[str] = None


class MessageResponse(BaseModel):
    """Non-streaming answer returned by /ask."""

    answer: str
    conversation_id: str
    sources: List[SourceDocument] = []
    timestamp: str = Field(default_factory=_utc_now_iso)


class IngestRequest(BaseModel):
    """Optional metadata accompanying a document upload."""

    metadata: Optional[dict] = None


class IngestResponse(BaseModel):
    """Result of a single-document ingestion."""

    message: str
    documents_processed: int
    chunks_created: int
    filename: str


class ConversationMessage(BaseModel):
    """One turn of a conversation."""

    role: str  # "user" or "assistant"
    content: str
    sources: List[SourceDocument] = []
    timestamp: str = Field(default_factory=_utc_now_iso)


class Conversation(BaseModel):
    """Full conversation including message history."""

    # NOTE(review): the default title here is "New Conversation" while the
    # DB layer defaults to "New Chat"; stored conversations carry the DB
    # value — confirm whether the mismatch is intentional.
    id: str = Field(default_factory=lambda: str(uuid.uuid4()))
    title: str = "New Conversation"
    messages: List[ConversationMessage] = []
    created_at: str = Field(default_factory=_utc_now_iso)
    updated_at: str = Field(default_factory=_utc_now_iso)


class ConversationListItem(BaseModel):
    """Summary row returned by the conversation-list endpoint."""

    id: str
    title: str
    message_count: int
    created_at: str
    updated_at: str


class HealthResponse(BaseModel):
    """Payload of GET /health."""

    status: str = "healthy"
    version: str = "1.0.0"
    environment: str = "development"


class VectorStatsResponse(BaseModel):
    """Vector-store statistics; ``error`` is populated when stats lookup fails."""

    total_vectors: int = 0
    total_documents: int = 0
    index_name: str = ""
    dimension: int = 384
    error: Optional[str] = None


class DocumentInfo(BaseModel):
    """Metadata describing one uploaded file."""

    name: str
    size: int  # bytes
    type: str  # PDF, Markdown, etc.
    uploaded_at: str
|
app/routes/__init__.py
ADDED
|
File without changes
|
app/routes/ask.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse
from app.models import MessageRequest, MessageResponse
from app.services.rag_chain import rag_service

router = APIRouter()


@router.post("/ask", response_model=MessageResponse)
async def ask_question(request: MessageRequest):
    """Answer a question via the RAG pipeline and return the full response."""
    try:
        payload = await rag_service.ask(
            question=request.question,
            conversation_id=request.conversation_id,
        )
        return MessageResponse(**payload)
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))


@router.post("/ask/stream")
async def ask_question_stream(request: MessageRequest):
    """Answer a question via the RAG pipeline, streamed as server-sent events."""
    try:
        token_stream = rag_service.ask_stream(
            question=request.question,
            conversation_id=request.conversation_id,
        )
        sse_headers = {
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",
        }
        return StreamingResponse(
            token_stream,
            media_type="text/event-stream",
            headers=sse_headers,
        )
    except Exception as exc:
        raise HTTPException(status_code=500, detail=str(exc))
|
app/routes/conversations.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, HTTPException
from typing import List
from app.models import Conversation, ConversationListItem
from app.services.conversation_store import conversation_store

router = APIRouter()


@router.get("/conversations", response_model=List[ConversationListItem])
async def list_conversations():
    """Return a summary of every stored conversation."""
    return await conversation_store.list_conversations()


@router.get("/conversations/{conversation_id}", response_model=Conversation)
async def get_conversation(conversation_id: str):
    """Fetch one conversation, including its full message history."""
    conv = await conversation_store.get_conversation(conversation_id)
    if conv is None:
        raise HTTPException(status_code=404, detail="Conversation not found")
    return conv


@router.delete("/conversations/{conversation_id}")
async def delete_conversation(conversation_id: str):
    """Remove one conversation by id; 404 if it does not exist."""
    deleted = await conversation_store.delete_conversation(conversation_id)
    if not deleted:
        raise HTTPException(status_code=404, detail="Conversation not found")
    return {"message": "Conversation deleted"}


@router.delete("/conversations")
async def clear_all_conversations():
    """Wipe every stored conversation."""
    await conversation_store.clear_all()
    return {"message": "All conversations cleared"}
|
app/routes/health.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from fastapi import APIRouter, HTTPException
from fastapi.responses import FileResponse
from app.config import settings
from app.models import HealthResponse, VectorStatsResponse, DocumentInfo
from app.services.vector_store import vector_store_service
from typing import List
import os
from datetime import datetime

router = APIRouter()

# Human-readable labels for known extensions (fallback: uppercased extension).
_TYPE_LABELS = {
    '.pdf': 'PDF',
    '.md': 'Markdown',
    '.txt': 'Text',
    '.csv': 'CSV',
}


def _resolve_upload(filename: str) -> str:
    """Map *filename* to its path inside UPLOAD_DIR.

    Raises 400 on path-traversal attempts and 404 when the file does not
    exist. Shared by the download and delete endpoints, which previously
    duplicated these checks.
    """
    # Security: prevent directory traversal
    if ".." in filename or "/" in filename or "\\" in filename:
        raise HTTPException(status_code=400, detail="Invalid filename")
    filepath = os.path.join(settings.UPLOAD_DIR, filename)
    if not os.path.exists(filepath):
        raise HTTPException(status_code=404, detail="Document not found")
    return filepath


@router.get("/health", response_model=HealthResponse)
async def health_check():
    """Health check endpoint."""
    return HealthResponse(
        status="healthy",
        version="1.0.0",
        environment=settings.APP_ENV,
    )


@router.get("/api/stats/vectors", response_model=VectorStatsResponse)
async def get_vector_stats():
    """Get vector store statistics.

    Never raises: on any failure a zeroed response with ``error`` set is
    returned so the UI can still render.
    """
    try:
        # Initialize if needed
        await vector_store_service.initialize()

        # Get stats from Pinecone
        index = vector_store_service.pc.Index(settings.PINECONE_INDEX_NAME)
        stats = index.describe_index_stats()

        # Count documents in uploads folder
        upload_dir = settings.UPLOAD_DIR
        documents = 0
        if os.path.exists(upload_dir):
            documents = len([f for f in os.listdir(upload_dir)
                             if os.path.isfile(os.path.join(upload_dir, f))])

        return VectorStatsResponse(
            total_vectors=stats.total_vector_count,
            total_documents=documents,
            index_name=settings.PINECONE_INDEX_NAME,
            dimension=stats.dimension,
        )
    except Exception as e:
        return VectorStatsResponse(
            total_vectors=0,
            total_documents=0,
            index_name=settings.PINECONE_INDEX_NAME,
            dimension=384,
            error=str(e),
        )


@router.get("/api/documents", response_model=List[DocumentInfo])
async def list_documents():
    """List all uploaded documents with metadata, newest first."""
    upload_dir = settings.UPLOAD_DIR
    documents: List[DocumentInfo] = []

    if not os.path.exists(upload_dir):
        return documents

    for filename in os.listdir(upload_dir):
        filepath = os.path.join(upload_dir, filename)
        if not os.path.isfile(filepath):
            continue
        stat = os.stat(filepath)
        ext = os.path.splitext(filename)[1].lower()
        documents.append(DocumentInfo(
            name=filename,
            size=stat.st_size,
            type=_TYPE_LABELS.get(ext, ext.upper()),
            uploaded_at=datetime.fromtimestamp(stat.st_mtime).isoformat(),
        ))

    # Sort by upload date descending
    documents.sort(key=lambda d: d.uploaded_at, reverse=True)
    return documents


@router.get("/api/documents/{filename}")
async def download_document(filename: str):
    """Download a specific document.

    Fix: the route template had been corrupted to the literal string
    ``/api/documents/(unknown)``; restored the ``{filename}`` path
    parameter that the ``filename: str`` signature expects.
    """
    filepath = _resolve_upload(filename)
    return FileResponse(
        path=filepath,
        filename=filename,
        media_type="application/octet-stream"
    )


@router.delete("/api/documents/{filename}")
async def delete_document(filename: str):
    """Delete a specific document from the uploads folder.

    Fix: route template and success message restored (both had been
    corrupted to ``(unknown)``).
    NOTE(review): this removes only the file on disk — vectors previously
    ingested from it are not removed from the index.
    """
    filepath = _resolve_upload(filename)
    try:
        os.remove(filepath)
        return {"message": f"Document '{filename}' deleted successfully"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to delete document: {str(e)}")
|
app/routes/ingest.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
import json
import aiofiles
from fastapi import APIRouter, HTTPException, UploadFile, File, Form
from typing import Optional
from app.models import IngestResponse
from app.services.document_processor import document_processor
from app.services.vector_store import vector_store_service
from app.config import settings

router = APIRouter()

ALLOWED_EXTENSIONS = {".pdf", ".csv", ".md", ".txt"}


def _sanitized_name(raw: Optional[str]) -> str:
    """Strip directory components from a client-supplied filename.

    Security fix: the original joined ``file.filename`` directly into
    UPLOAD_DIR, allowing path traversal (e.g. ``../../some/path``).
    Raises 400 for empty or purely-relative names.
    """
    name = os.path.basename(raw or "")
    if not name or name in {".", ".."}:
        raise HTTPException(status_code=400, detail="Invalid filename")
    return name


async def _save_upload(file: UploadFile, filename: str) -> str:
    """Persist the upload under UPLOAD_DIR and return the saved path."""
    file_path = os.path.join(settings.UPLOAD_DIR, filename)
    async with aiofiles.open(file_path, "wb") as f:
        content = await file.read()
        await f.write(content)
    return file_path


@router.post("/ingest", response_model=IngestResponse)
async def ingest_document(
    file: UploadFile = File(...),
    metadata: Optional[str] = Form(None),
):
    """Upload and ingest a document into the vector store."""
    filename = _sanitized_name(file.filename)

    # Validate file extension
    ext = os.path.splitext(filename)[1].lower()
    if ext not in ALLOWED_EXTENSIONS:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type: {ext}. Allowed: {', '.join(ALLOWED_EXTENSIONS)}",
        )

    # Save uploaded file
    try:
        file_path = await _save_upload(file, filename)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to save file: {str(e)}")

    # Parse metadata: JSON when valid, otherwise keep as a free-text description
    meta = None
    if metadata:
        try:
            meta = json.loads(metadata)
        except json.JSONDecodeError:
            meta = {"description": metadata}

    # Process and ingest
    try:
        chunks = await document_processor.process_file(file_path, meta)
        num_added = await vector_store_service.add_documents(chunks)

        return IngestResponse(
            message=f"Successfully ingested '{filename}'",
            documents_processed=1,
            chunks_created=num_added,
            filename=filename,
        )
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Failed to process document: {str(e)}"
        )


@router.post("/ingest/batch")
async def ingest_batch(files: list[UploadFile] = File(...)):
    """Upload and ingest multiple documents.

    Never raises per-file: each file's outcome (success / skipped / error)
    is reported in the ``results`` list.
    """
    results = []

    for file in files:
        # Security fix: sanitize client filenames here too (was missing).
        try:
            filename = _sanitized_name(file.filename)
        except HTTPException:
            results.append(
                {"filename": file.filename, "status": "skipped", "reason": "Invalid filename"}
            )
            continue

        ext = os.path.splitext(filename)[1].lower()
        if ext not in ALLOWED_EXTENSIONS:
            results.append(
                {"filename": file.filename, "status": "skipped", "reason": f"Unsupported type: {ext}"}
            )
            continue

        try:
            file_path = await _save_upload(file, filename)
            chunks = await document_processor.process_file(file_path)
            num_added = await vector_store_service.add_documents(chunks)

            results.append(
                {
                    "filename": file.filename,
                    "status": "success",
                    "chunks_created": num_added,
                }
            )
        except Exception as e:
            results.append(
                {"filename": file.filename, "status": "error", "reason": str(e)}
            )

    return {"results": results, "total_files": len(files)}
|
app/services/__init__.py
ADDED
|
File without changes
|
app/services/conversation_store.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Database-backed conversation store."""
|
| 2 |
+
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
from datetime import datetime
|
| 5 |
+
from sqlalchemy import select, desc
|
| 6 |
+
from sqlalchemy.orm import selectinload
|
| 7 |
+
from app.models import Conversation, ConversationMessage, ConversationListItem
|
| 8 |
+
from app.database import async_session_maker
|
| 9 |
+
from app.db_models import DBConversation, DBMessage
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ConversationStore:
    """Database-backed conversation storage (async SQLAlchemy sessions).

    Every public method opens its own session via ``async_session_maker``,
    so a single module-level instance can be shared safely across requests.
    """

    @staticmethod
    def _now() -> str:
        """Current UTC time as an ISO-8601 string with a trailing 'Z'."""
        return datetime.utcnow().isoformat() + "Z"

    @staticmethod
    def _to_conversation(db_conv) -> Conversation:
        """Convert a DB conversation row (messages loaded) to the Pydantic model."""
        return Conversation(
            id=db_conv.id,
            title=db_conv.title,
            messages=[
                ConversationMessage(
                    role=msg.role,
                    content=msg.content,
                    timestamp=msg.timestamp,
                    sources=msg.sources or [],
                )
                for msg in db_conv.messages
            ],
            created_at=db_conv.created_at,
            updated_at=db_conv.updated_at,
        )

    async def get_or_create(self, conversation_id: Optional[str] = None) -> Conversation:
        """Return the conversation with ``conversation_id``, or create a new one.

        A missing or unknown id falls through to creating a fresh conversation.
        """
        async with async_session_maker() as session:
            if conversation_id:
                result = await session.execute(
                    select(DBConversation)
                    .options(selectinload(DBConversation.messages))
                    .where(DBConversation.id == conversation_id)
                )
                db_conv = result.scalar_one_or_none()
                if db_conv:
                    return self._to_conversation(db_conv)

            # Create a new conversation. The placeholder title is replaced
            # by the first user message (see add_message).
            now = self._now()
            db_conv = DBConversation(
                title="New Chat",
                created_at=now,
                updated_at=now,
            )
            session.add(db_conv)
            await session.commit()
            await session.refresh(db_conv)  # populate the generated id

            return Conversation(
                id=db_conv.id,
                title=db_conv.title,
                messages=[],
                created_at=db_conv.created_at,
                updated_at=db_conv.updated_at,
            )

    async def add_message(
        self,
        conversation_id: str,
        role: str,
        content: str,
        sources: Optional[list] = None,
    ) -> ConversationMessage:
        """Append a message to a conversation and return its Pydantic form.

        Args:
            conversation_id: Id of an existing conversation.
            role: Message role ("user" or "assistant").
            content: Message text.
            sources: Optional list of source references to store with it.

        Raises:
            ValueError: if the conversation does not exist.
        """
        async with async_session_maker() as session:
            result = await session.execute(
                select(DBConversation).where(DBConversation.id == conversation_id)
            )
            db_conv = result.scalar_one_or_none()
            if not db_conv:
                raise ValueError(f"Conversation {conversation_id} not found")

            now = self._now()
            db_message = DBMessage(
                conversation_id=conversation_id,
                role=role,
                content=content,
                timestamp=now,
                sources=sources or [],
            )
            session.add(db_message)

            # Bump recency, and derive the title from the first user message.
            db_conv.updated_at = now
            if role == "user" and db_conv.title == "New Chat":
                db_conv.title = content[:80] + ("..." if len(content) > 80 else "")

            await session.commit()

            return ConversationMessage(
                role=role,
                content=content,
                timestamp=now,
                sources=sources or [],
            )

    async def get_conversation(self, conversation_id: str) -> Optional[Conversation]:
        """Return a conversation by id, or None if it does not exist."""
        async with async_session_maker() as session:
            result = await session.execute(
                select(DBConversation)
                .options(selectinload(DBConversation.messages))
                .where(DBConversation.id == conversation_id)
            )
            db_conv = result.scalar_one_or_none()
            return self._to_conversation(db_conv) if db_conv else None

    async def get_history(self, conversation_id: str, limit: int = 10) -> List[dict]:
        """Return the last ``limit`` messages as role/content dicts.

        The query fetches newest-first for efficiency and the result is
        reversed so callers receive chronological order.
        """
        async with async_session_maker() as session:
            result = await session.execute(
                select(DBMessage)
                .where(DBMessage.conversation_id == conversation_id)
                .order_by(desc(DBMessage.id))
                .limit(limit)
            )
            messages = result.scalars().all()

        return [
            {"role": msg.role, "content": msg.content}
            for msg in reversed(messages)
        ]

    async def list_conversations(self) -> List[ConversationListItem]:
        """List all conversations, most recently updated first."""
        async with async_session_maker() as session:
            result = await session.execute(
                select(DBConversation)
                .options(selectinload(DBConversation.messages))
                .order_by(desc(DBConversation.updated_at))
            )
            conversations = result.scalars().all()

            return [
                ConversationListItem(
                    id=conv.id,
                    title=conv.title,
                    message_count=len(conv.messages),
                    created_at=conv.created_at,
                    updated_at=conv.updated_at,
                )
                for conv in conversations
            ]

    async def delete_conversation(self, conversation_id: str) -> bool:
        """Delete a conversation; return True if it existed."""
        async with async_session_maker() as session:
            result = await session.execute(
                select(DBConversation).where(DBConversation.id == conversation_id)
            )
            db_conv = result.scalar_one_or_none()

            if db_conv:
                await session.delete(db_conv)
                await session.commit()
                return True
            return False

    async def clear_all(self):
        """Delete every conversation.

        NOTE(review): messages are presumably removed via the relationship
        cascade — confirm against db_models.
        """
        async with async_session_maker() as session:
            # Single query (the original issued the same SELECT twice and
            # discarded the first result).
            result = await session.execute(select(DBConversation))
            for conv in result.scalars().all():
                await session.delete(conv)

            await session.commit()


conversation_store = ConversationStore()
|
app/services/document_processor.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Document processing service for PDF, CSV, and Markdown files."""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
from typing import List, Optional
|
| 5 |
+
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
| 6 |
+
from langchain_community.document_loaders import (
|
| 7 |
+
PyPDFLoader,
|
| 8 |
+
CSVLoader,
|
| 9 |
+
TextLoader,
|
| 10 |
+
)
|
| 11 |
+
from langchain.schema import Document
|
| 12 |
+
from app.config import settings
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class DocumentProcessor:
    """Handles loading and chunking of PDF, CSV, Markdown, and text files."""

    def __init__(self):
        # Recursive splitter: prefers paragraph, then line, then sentence,
        # then word boundaries before hard character cuts.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=settings.CHUNK_SIZE,
            chunk_overlap=settings.CHUNK_OVERLAP,
            length_function=len,
            separators=["\n\n", "\n", ". ", " ", ""],
        )

        # Maps a lowercase file extension to its loader method.
        self.supported_extensions = {
            ".pdf": self._load_pdf,
            ".csv": self._load_csv,
            ".md": self._load_markdown,
            ".txt": self._load_text,
        }

    def get_supported_extensions(self) -> List[str]:
        """Return the supported file extensions (e.g. ['.pdf', '.csv', ...])."""
        return list(self.supported_extensions.keys())

    async def process_file(
        self, file_path: str, metadata: Optional[dict] = None
    ) -> List[Document]:
        """Load ``file_path`` and return its content as chunked documents.

        Args:
            file_path: Path to a file with a supported extension.
            metadata: Extra metadata merged into every document before chunking.

        Returns:
            Chunked documents, each stamped with source/file_type/chunk metadata.

        Raises:
            ValueError: if the file extension is not supported.
        """
        ext = os.path.splitext(file_path)[1].lower()

        if ext not in self.supported_extensions:
            raise ValueError(
                f"Unsupported file type: {ext}. "
                f"Supported: {', '.join(self.supported_extensions.keys())}"
            )

        # Load documents with the extension-specific loader.
        loader_fn = self.supported_extensions[ext]
        documents = loader_fn(file_path)

        # Stamp provenance metadata on every loaded page/row.
        filename = os.path.basename(file_path)
        for doc in documents:
            doc.metadata["source"] = filename
            doc.metadata["file_type"] = ext
            if metadata:
                doc.metadata.update(metadata)

        # Split into overlapping chunks.
        chunks = self.text_splitter.split_documents(documents)

        # Give each chunk a file-scoped id. Fix: the id is prefixed with the
        # actual filename (the previous code used a hard-coded "(unknown)"
        # literal, so chunks from different files collided on the same prefix).
        for i, chunk in enumerate(chunks):
            chunk.metadata["chunk_id"] = f"{filename}_chunk_{i}"
            chunk.metadata["chunk_index"] = i
            chunk.metadata["total_chunks"] = len(chunks)

        return chunks

    def _load_pdf(self, file_path: str) -> List[Document]:
        """Load a PDF file (one document per page)."""
        loader = PyPDFLoader(file_path)
        return loader.load()

    def _load_csv(self, file_path: str) -> List[Document]:
        """Load a CSV file (one document per row)."""
        loader = CSVLoader(file_path, encoding="utf-8")
        return loader.load()

    def _load_markdown(self, file_path: str) -> List[Document]:
        """Load a Markdown file as plain text."""
        loader = TextLoader(file_path, encoding="utf-8")
        return loader.load()

    def _load_text(self, file_path: str) -> List[Document]:
        """Load a plain text file."""
        loader = TextLoader(file_path, encoding="utf-8")
        return loader.load()


document_processor = DocumentProcessor()
|
app/services/rag_chain.py
ADDED
|
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RAG chain service using LangChain with Gemini."""
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from typing import AsyncGenerator, List, Optional
|
| 5 |
+
from langchain_google_genai import ChatGoogleGenerativeAI
|
| 6 |
+
from langchain.prompts import ChatPromptTemplate, MessagesPlaceholder
|
| 7 |
+
from langchain.schema import Document, HumanMessage, AIMessage
|
| 8 |
+
from app.config import settings
|
| 9 |
+
from app.services.vector_store import vector_store_service
|
| 10 |
+
from app.services.conversation_store import conversation_store
|
| 11 |
+
from app.models import SourceDocument
|
| 12 |
+
|
| 13 |
+
import asyncio
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
SYSTEM_PROMPT = """You are an intelligent AI assistant for a health-tech knowledge base. \
|
| 17 |
+
Your role is to answer questions accurately based on the provided context from our documents.
|
| 18 |
+
|
| 19 |
+
INSTRUCTIONS:
|
| 20 |
+
- Answer questions based ONLY on the provided context below.
|
| 21 |
+
- When listing items (like "pillars", "steps", "types"), make sure to include ALL items found in the context.
|
| 22 |
+
- If the context doesn't contain relevant information, say "I don't have enough information in the knowledge base to answer that question."
|
| 23 |
+
- Be thorough and complete - don't skip information that's present in the context.
|
| 24 |
+
- When referencing information, mention the source document when possible.
|
| 25 |
+
- Format your answers with proper markdown for readability.
|
| 26 |
+
- If asked about topics outside the knowledge base, politely redirect to relevant topics.
|
| 27 |
+
|
| 28 |
+
CONTEXT FROM KNOWLEDGE BASE:
|
| 29 |
+
{context}
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class RAGService:
    """Retrieval-Augmented Generation service using Gemini."""

    def __init__(self):
        # Placeholder for a cached LLM instance; instances are currently
        # created per request by _get_llm (the streaming flag varies).
        self.llm: Optional[ChatGoogleGenerativeAI] = None

    def _get_llm(self, streaming: bool = False) -> ChatGoogleGenerativeAI:
        """Create a Gemini chat model, optionally configured for streaming."""
        return ChatGoogleGenerativeAI(
            model=settings.LLM_MODEL,
            google_api_key=settings.GOOGLE_API_KEY,
            temperature=0.3,
            streaming=streaming,
        )

    def _build_prompt(self) -> ChatPromptTemplate:
        """Build the RAG prompt: system context + prior history + question."""
        return ChatPromptTemplate.from_messages(
            [
                ("system", SYSTEM_PROMPT),
                MessagesPlaceholder(variable_name="chat_history"),
                ("human", "{question}"),
            ]
        )

    def _format_context(self, documents: List[Document]) -> str:
        """Format retrieved documents into a numbered, delimited context string."""
        if not documents:
            return "No relevant documents found."

        context_parts = []
        for i, doc in enumerate(documents, 1):
            source = doc.metadata.get("source", "Unknown")
            page = doc.metadata.get("page", "")
            page_str = f" (Page {page})" if page else ""
            context_parts.append(
                f"[Source {i}: {source}{page_str}]\n{doc.page_content}"
            )

        return "\n\n---\n\n".join(context_parts)

    def _format_sources(self, documents: List[Document]) -> List[SourceDocument]:
        """Convert documents to SourceDocument models, deduped by (source, page)."""
        sources = []
        seen = set()
        for doc in documents:
            source_key = (doc.metadata.get("source", ""), doc.metadata.get("page"))
            if source_key not in seen:
                seen.add(source_key)
                sources.append(
                    SourceDocument(
                        content=doc.page_content[:300],  # preview snippet only
                        source=doc.metadata.get("source", "Unknown"),
                        page=doc.metadata.get("page"),
                        chunk_id=doc.metadata.get("chunk_id"),
                    )
                )
        return sources

    def _build_chat_history(self, history: List[dict]) -> list:
        """Convert {'role','content'} dicts to LangChain message objects."""
        messages = []
        for msg in history:
            if msg["role"] == "user":
                messages.append(HumanMessage(content=msg["content"]))
            elif msg["role"] == "assistant":
                messages.append(AIMessage(content=msg["content"]))
        return messages

    async def _prepare(
        self, question: str, conversation_id: Optional[str]
    ) -> tuple:
        """Shared setup for ask/ask_stream (extracted from duplicated code).

        Resolves the conversation, persists the user message, retrieves
        relevant documents, and builds the prompt inputs.

        Returns:
            (conv_id, context, chat_history, documents) where chat_history
            already excludes the just-stored question so the prompt does
            not contain it twice.
        """
        conversation = await conversation_store.get_or_create(conversation_id)
        conv_id = conversation.id

        await conversation_store.add_message(conv_id, "user", question)

        documents = await vector_store_service.similarity_search(question, k=6)
        context = self._format_context(documents)
        chat_history = self._build_chat_history(
            await conversation_store.get_history(conv_id, limit=8)
        )

        return conv_id, context, chat_history[:-1], documents

    async def ask(
        self,
        question: str,
        conversation_id: Optional[str] = None,
    ) -> dict:
        """Answer a question with RAG (non-streaming).

        Returns:
            Dict with "answer", "conversation_id", and "sources".
        """
        conv_id, context, chat_history, documents = await self._prepare(
            question, conversation_id
        )

        chain = self._build_prompt() | self._get_llm(streaming=False)
        response = await chain.ainvoke(
            {
                "context": context,
                "chat_history": chat_history,
                "question": question,
            }
        )

        answer = response.content
        sources = self._format_sources(documents)

        # Persist the assistant turn so it appears in future history.
        await conversation_store.add_message(conv_id, "assistant", answer, sources)

        return {
            "answer": answer,
            "conversation_id": conv_id,
            "sources": sources,
        }

    async def ask_stream(
        self,
        question: str,
        conversation_id: Optional[str] = None,
    ) -> AsyncGenerator[str, None]:
        """Answer a question with RAG, yielding Server-Sent Event strings.

        Event order: one 'metadata' event (conversation id + sources),
        'token' events as text arrives, optionally an 'error' event, then
        a final 'done' event.
        """
        conv_id, context, chat_history, documents = await self._prepare(
            question, conversation_id
        )
        sources = self._format_sources(documents)

        # Metadata first so the client can render sources immediately.
        yield f"data: {json.dumps({'type': 'metadata', 'conversation_id': conv_id, 'sources': [s.model_dump() for s in sources]})}\n\n"

        chain = self._build_prompt() | self._get_llm(streaming=True)

        full_response = []
        try:
            async for chunk in chain.astream(
                {
                    "context": context,
                    "chat_history": chat_history,
                    "question": question,
                }
            ):
                if hasattr(chunk, 'content') and chunk.content:
                    full_response.append(chunk.content)
                    yield f"data: {json.dumps({'type': 'token', 'content': chunk.content})}\n\n"
        except Exception as e:
            # Surface the failure to the client; whatever streamed so far is
            # still persisted below.
            yield f"data: {json.dumps({'type': 'error', 'content': str(e)})}\n\n"

        complete_answer = "".join(full_response)
        await conversation_store.add_message(conv_id, "assistant", complete_answer, sources)

        yield f"data: {json.dumps({'type': 'done'})}\n\n"


rag_service = RAGService()
|
app/services/vector_store.py
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Vector store service using Pinecone."""
|
| 2 |
+
|
| 3 |
+
from typing import List, Optional
|
| 4 |
+
from langchain_huggingface import HuggingFaceEmbeddings
|
| 5 |
+
from langchain_pinecone import PineconeVectorStore
|
| 6 |
+
from langchain.schema import Document
|
| 7 |
+
from pinecone import Pinecone, ServerlessSpec
|
| 8 |
+
from app.config import settings
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class VectorStoreService:
    """Manages Pinecone vector store operations (lazy-initialized)."""

    # Output dimension of the all-MiniLM-L6-v2 embedding model.
    EMBEDDING_DIM = 384

    def __init__(self):
        self.embeddings: Optional[HuggingFaceEmbeddings] = None
        self.vector_store: Optional[PineconeVectorStore] = None
        self.pc: Optional[Pinecone] = None
        self._initialized = False

    async def initialize(self):
        """Initialize embeddings, the Pinecone client, index, and store.

        Idempotent: returns immediately if already initialized.
        """
        if self._initialized:
            return

        # HuggingFace embeddings on CPU; normalized vectors pair with the
        # cosine metric used by the index.
        self.embeddings = HuggingFaceEmbeddings(
            model_name=settings.EMBEDDING_MODEL,
            model_kwargs={'device': 'cpu'},
            encode_kwargs={'normalize_embeddings': True}
        )

        self.pc = Pinecone(api_key=settings.PINECONE_API_KEY)

        # Create the index on first ever use.
        existing_indexes = [idx.name for idx in self.pc.list_indexes()]
        if settings.PINECONE_INDEX_NAME not in existing_indexes:
            self.pc.create_index(
                name=settings.PINECONE_INDEX_NAME,
                dimension=self.EMBEDDING_DIM,
                metric="cosine",
                spec=ServerlessSpec(cloud="aws", region="us-east-1"),
            )

        self.vector_store = PineconeVectorStore(
            index_name=settings.PINECONE_INDEX_NAME,
            embedding=self.embeddings,
            pinecone_api_key=settings.PINECONE_API_KEY,
        )

        self._initialized = True

    async def _ensure_initialized(self):
        """Lazily initialize on first use (shared by the async entry points)."""
        if not self._initialized:
            await self.initialize()

    async def add_documents(self, documents: List[Document]) -> int:
        """Embed and upsert documents; return how many were submitted."""
        await self._ensure_initialized()
        self.vector_store.add_documents(documents)
        return len(documents)

    async def similarity_search(
        self, query: str, k: int = 4
    ) -> List[Document]:
        """Return the k most similar documents for ``query``."""
        await self._ensure_initialized()
        return self.vector_store.similarity_search(query, k=k)

    async def similarity_search_with_score(
        self, query: str, k: int = 4
    ) -> List[tuple]:
        """Return (document, score) pairs for the k most similar documents."""
        await self._ensure_initialized()
        return self.vector_store.similarity_search_with_score(query, k=k)

    def get_retriever(self, k: int = 4):
        """Return a retriever for use in chains.

        Raises:
            RuntimeError: if initialize() has not been awaited yet — this
                method is synchronous and cannot lazy-initialize.
        """
        if not self._initialized:
            raise RuntimeError("Vector store not initialized")

        return self.vector_store.as_retriever(
            search_type="similarity",
            search_kwargs={"k": k},
        )


vector_store_service = VectorStoreService()
|
requirements.txt
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
fastapi==0.115.6
|
| 2 |
+
uvicorn[standard]==0.34.0
|
| 3 |
+
python-dotenv==1.0.1
|
| 4 |
+
langchain==0.3.14
|
| 5 |
+
langchain-core==0.3.14
|
| 6 |
+
langchain-google-genai==2.0.8
|
| 7 |
+
langchain-community==0.3.14
|
| 8 |
+
langchain-pinecone==0.2.0
|
| 9 |
+
langchain-huggingface==0.1.0
|
| 10 |
+
sentence-transformers==3.3.1
|
| 11 |
+
pinecone-client==5.0.1
|
| 12 |
+
google-generativeai==0.8.3
|
| 13 |
+
pypdf==5.1.0
|
| 14 |
+
python-multipart==0.0.20
|
| 15 |
+
pydantic==2.10.4
|
| 16 |
+
pydantic-settings==2.7.1
|
| 17 |
+
unstructured==0.16.12
|
| 18 |
+
aiofiles==24.1.0
|
| 19 |
+
uuid6==2024.7.10
|
| 20 |
+
sqlalchemy==2.0.25
|
| 21 |
+
aiosqlite==0.19.0
|