Spaces:
Running
Running
Crcs1225
committed on
Commit
·
1bc6a5b
0
Parent(s):
Initial commit of RAG FastAPI project
Browse files- .gitignore +10 -0
- Dockerfile +21 -0
- config.py +16 -0
- database.py +117 -0
- gemini_service.py +106 -0
- main.py +217 -0
- models.py +100 -0
- rag_system.py +285 -0
- requirements.txt +0 -0
- run.py +11 -0
.gitignore
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.env
|
| 2 |
+
env
|
| 3 |
+
# Environment variables file
|
| 4 |
+
.env.local
|
| 5 |
+
.env.*.local
|
| 6 |
+
|
| 7 |
+
__pycache__
|
| 8 |
+
*.pyc
|
| 9 |
+
*.pyo
|
| 10 |
+
venv
|
Dockerfile
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Use a slim Python base
FROM python:3.11-slim

# Set working directory
WORKDIR /code

# Install system deps (build tools for packages with C extensions);
# clean the apt cache to keep the image small
RUN apt-get update && apt-get install -y --no-install-recommends build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements and install first so this layer is cached across code changes
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy app code.
# FIX: the project files (run.py, main.py, config.py, ...) live at the
# repository root — the original `COPY ./app /code/app` copied a directory
# that does not exist in this commit, so `run:app` below could never be
# imported. Copy the whole project into the workdir instead.
COPY . /code

# Expose port
EXPOSE 7860

# Run FastAPI with uvicorn
CMD ["uvicorn", "run:app", "--host", "0.0.0.0", "--port", "7860"]
config.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from pydantic_settings import BaseSettings
from typing import Optional


class Settings(BaseSettings):
    """Application configuration, populated from the environment / .env file."""

    # Required secrets — must be supplied via environment variables or .env.
    mongodb_atlas_uri: str
    gemini_api_key: str

    # Database / collection names, with defaults matching the deployed cluster.
    database_name: str = "product"
    products_collection: str = "marketplace"
    conversations_collection: str = "conversations"
    embeddings_collection: str = "embeddings"

    # Port for the local development server (see main.py's __main__ guard).
    port: int = 8000

    class Config:
        # Load values from a local .env file when present.
        env_file = ".env"


# Shared singleton; import `settings` rather than instantiating Settings again.
settings = Settings()
database.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from datetime import datetime
import motor.motor_asyncio
from bson import ObjectId
from bson.errors import InvalidId
from typing import List, Optional, Dict, Any
from config import settings
import json


def _id_filter(raw_id: str) -> Dict[str, Any]:
    """Build an ``_id`` filter that works for both ObjectId and string ids.

    Products created through the API receive plain string ``_id`` values
    (see main.py's create_product), while documents inserted directly by
    Mongo carry ObjectIds. The original code called ``ObjectId(raw_id)``
    unconditionally, which raises ``InvalidId`` for the string ids and made
    API-created products unreadable/undeletable.
    """
    try:
        return {"_id": ObjectId(raw_id)}
    except (InvalidId, TypeError):
        return {"_id": raw_id}


class MongoDB:
    """Async accessor for the products, conversations and embeddings collections."""

    def __init__(self):
        self.client = motor.motor_asyncio.AsyncIOMotorClient(settings.mongodb_atlas_uri)
        self.db = self.client[settings.database_name]
        self.products = self.db[settings.products_collection]
        self.conversations = self.db[settings.conversations_collection]
        self.embeddings = self.db[settings.embeddings_collection]

    async def get_product(self, product_id: str) -> Optional[Dict]:
        """Return a single product document, or None if not found."""
        return await self.products.find_one(_id_filter(product_id))

    async def get_products_by_category(self, category: str, limit: int = 10) -> List[Dict]:
        """Return up to `limit` products whose `category` field matches exactly."""
        # NOTE(review): the live schema may use a different field name for
        # category — verify against the actual collection.
        cursor = self.products.find({"category": category}).limit(limit)
        return await cursor.to_list(length=limit)

    async def search_products(self, query: Dict, limit: int = 10) -> List[Dict]:
        """Run an arbitrary Mongo filter over the products collection."""
        cursor = self.products.find(query).limit(limit)
        return await cursor.to_list(length=limit)

    async def get_all_products(self, limit: int = 50) -> List[Dict]:
        """Return up to `limit` products with no filter applied."""
        cursor = self.products.find().limit(limit)
        return await cursor.to_list(length=limit)

    async def create_product(self, product_data: Dict) -> str:
        """Insert a product document and return its id as a string."""
        result = await self.products.insert_one(product_data)
        return str(result.inserted_id)

    async def update_product(self, product_id: str, update_data: Dict) -> bool:
        """Apply a $set update to one product; True if a document was modified."""
        result = await self.products.update_one(
            _id_filter(product_id),
            {"$set": update_data}
        )
        return result.modified_count > 0

    async def delete_product(self, product_id: str) -> bool:
        """Delete one product; True if a document was removed."""
        result = await self.products.delete_one(_id_filter(product_id))
        return result.deleted_count > 0

    async def get_conversation(self, conversation_id: str) -> Optional[Dict]:
        """Return a conversation document, or None for unknown/malformed ids."""
        return await self.conversations.find_one(_id_filter(conversation_id))

    async def create_conversation(self, user_id: str) -> str:
        """Create an empty conversation for `user_id`; returns its id."""
        conversation_data = {
            "user_id": user_id,
            "messages": [],
            "created_at": datetime.now(),
            "updated_at": datetime.now()
        }
        result = await self.conversations.insert_one(conversation_data)
        return str(result.inserted_id)

    async def add_message_to_conversation(self, conversation_id: str, message: Dict) -> bool:
        """Append a message dict and bump updated_at; True on success."""
        result = await self.conversations.update_one(
            _id_filter(conversation_id),
            {
                "$push": {"messages": message},
                "$set": {"updated_at": datetime.now()}
            }
        )
        return result.modified_count > 0

    async def get_user_conversations(self, user_id: str, limit: int = 10) -> List[Dict]:
        """Return a user's conversations, most recently updated first."""
        cursor = self.conversations.find({"user_id": user_id}).sort("updated_at", -1).limit(limit)
        return await cursor.to_list(length=limit)

    async def store_embedding(self, text: str, embedding: List[float], metadata: Dict) -> str:
        """Persist a text/embedding pair with metadata; returns the new id."""
        doc = {
            "text": text,
            "embedding": embedding,
            "metadata": metadata,
            "created_at": datetime.now()
        }
        result = await self.embeddings.insert_one(doc)
        return str(result.inserted_id)

    async def find_similar_embeddings(self, embedding: List[float], limit: int = 5) -> List[Dict]:
        """Return the `limit` stored embeddings closest to `embedding`.

        This computes Euclidean *distance* in an aggregation pipeline and
        sorts ascending (smallest distance = most similar). It scans every
        document — in production use Atlas `$vectorSearch` instead.
        """
        pipeline = [
            {
                "$addFields": {
                    # Euclidean distance between the stored and query vectors.
                    "similarity": {
                        "$sqrt": {
                            "$sum": {
                                "$map": {
                                    "input": {"$range": [0, {"$size": "$embedding"}]},
                                    "as": "idx",
                                    "in": {
                                        "$pow": [
                                            {"$subtract": [
                                                {"$arrayElemAt": ["$embedding", "$$idx"]},
                                                {"$arrayElemAt": [embedding, "$$idx"]}
                                            ]},
                                            2
                                        ]
                                    }
                                }
                            }
                        }
                    }
                }
            },
            # Ascending: smaller distance first.
            {"$sort": {"similarity": 1}},
            {"$limit": limit}
        ]
        cursor = self.embeddings.aggregate(pipeline)
        return await cursor.to_list(length=limit)

# Database instance shared by the application.
db = MongoDB()
gemini_service.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import google.generativeai as genai
from typing import List, Dict, Any, Optional
import asyncio
import aiohttp
import json
import re
from config import settings

class GeminiService:
    """Async-friendly wrapper around the google-generativeai SDK."""

    # genai.embed_content takes a model *name* string, not a GenerativeModel
    # instance (the original constructed GenerativeModel('embedding-001') and
    # passed it in, which the API rejects).
    EMBEDDING_MODEL = "models/embedding-001"
    # embedding-001 produces 768-dim vectors; the fallback must match so the
    # similarity store never mixes dimensions (original fell back to 384).
    EMBEDDING_DIM = 768

    def __init__(self):
        genai.configure(api_key=settings.gemini_api_key)
        self.model = genai.GenerativeModel('gemini-2.5-flash')

    async def generate_response(self, prompt: str, context: str = "") -> str:
        """Generate a chat response using the Gemini API with RAG context.

        Returns a friendly apology string (not an exception) on API failure,
        so the chat endpoint degrades gracefully.
        """
        try:
            full_prompt = f"""
Context Information:
{context}

User Question: {prompt}

You are a helpful shopping assistant for Daddy's Shop. Use the context information above to answer the user's question accurately and helpfully. If the context doesn't contain relevant information, use your general knowledge but be honest about limitations.

Provide a friendly, professional response focused on helping with shopping needs.
"""

            # The SDK is synchronous; run it in a worker thread so the event
            # loop is not blocked (asyncio.to_thread replaces the deprecated
            # get_event_loop()/run_in_executor pattern).
            response = await asyncio.to_thread(self.model.generate_content, full_prompt)

            return response.text

        except Exception as e:
            print(f"Gemini API error: {e}")
            return "I apologize, but I'm having trouble processing your request right now. Please try again later."

    async def generate_embedding(self, text: str) -> List[float]:
        """Generate an embedding vector for `text` via Gemini's embedding model."""
        try:
            result = await asyncio.to_thread(
                genai.embed_content,
                model=self.EMBEDDING_MODEL,
                content=text,
                task_type="retrieval_document",
            )
            return result['embedding']

        except Exception as e:
            print(f"Embedding generation error: {e}")
            # Fallback zero vector with the correct dimensionality so the
            # embeddings collection stays homogeneous. In production, prefer
            # failing loudly or retrying over storing zero vectors.
            return [0.0] * self.EMBEDDING_DIM

    async def classify_intent(self, message: str) -> Dict[str, Any]:
        """Classify a user message into one of the shop's known intents.

        Returns {"intent", "confidence", "entities"}; falls back to the
        "unknown" intent on any parsing or API error.
        """
        prompt = f"""
Classify the following user message into one of these intents:
- product_inquiry: Questions about products, features, availability
- pricing: Questions about costs, discounts, prices
- shipping: Questions about delivery, shipping costs, timelines
- returns: Questions about returns, refunds, exchanges
- support: General customer support, contact information
- greeting: Hello, hi, greetings
- unknown: Cannot classify

Message: "{message}"

Return ONLY a JSON object with:
{{
    "intent": "classified_intent",
    "confidence": 0.95,
    "entities": ["extracted_entities", "if_any"]
}}
"""

        try:
            response = await asyncio.to_thread(self.model.generate_content, prompt)

            # The model may wrap the JSON in prose/markdown; pull out the
            # first {...} span (re is imported at module level now).
            json_match = re.search(r'\{.*\}', response.text, re.DOTALL)
            if json_match:
                return json.loads(json_match.group())
            else:
                return {"intent": "unknown", "confidence": 0.0, "entities": []}

        except Exception as e:
            print(f"Intent classification error: {e}")
            return {"intent": "unknown", "confidence": 0.0, "entities": []}

# Gemini service instance shared by the application.
gemini_service = GeminiService()
main.py
ADDED
|
@@ -0,0 +1,217 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from fastapi import FastAPI, HTTPException, Depends
from fastapi.middleware.cors import CORSMiddleware
from typing import List, Optional
import uuid
from datetime import datetime

from config import settings
from database import db
from rag_system import rag_system
from models import (
    Product, ProductCreate, ChatRequest, ChatResponse,
    SearchRequest, Conversation, ChatMessage
)

app = FastAPI(
    title="Daddy's Shop RAG Chatbot API",
    description="AI-powered shopping assistant with RAG using MongoDB and Gemini",
    version="1.0.0"
)

# CORS middleware — only the known frontends may call the API with credentials.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["http://localhost:3000", "http://127.0.0.1:3000", "https://yourdomain.com"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

@app.get("/")
async def root():
    """Health check endpoint"""
    return {
        "message": "Daddy's Shop RAG Chatbot API is running!",
        "version": "1.0.0",
        "status": "healthy"
    }

@app.post("/chat", response_model=ChatResponse)
async def chat_endpoint(request: ChatRequest):
    """
    Main chatbot endpoint with RAG capabilities.

    Resumes the given conversation (or creates one), records the user
    message, generates a RAG answer, records the bot reply, and returns it.
    """
    try:
        # Create a new conversation when no ID was provided, or when the
        # provided ID does not resolve to an existing conversation.
        if not request.conversation_id:
            conversation_id = await db.create_conversation(request.user_id)
        else:
            conversation_id = request.conversation_id
            conversation = await db.get_conversation(conversation_id)
            if not conversation:
                conversation_id = await db.create_conversation(request.user_id)

        # Persist the incoming user message.
        user_message = ChatMessage(
            sender="user",
            text=request.message
        )
        await db.add_message_to_conversation(conversation_id, user_message.dict())

        # Re-read the conversation so the history passed to the RAG system
        # includes the message we just stored.
        conversation = await db.get_conversation(conversation_id)
        history = conversation.get('messages', []) if conversation else []

        # Generate response using the RAG system.
        rag_result = await rag_system.generate_chat_response(request.message, history)

        # Persist the bot reply.
        bot_message = ChatMessage(
            sender="bot",
            text=rag_result["response"]
        )
        await db.add_message_to_conversation(conversation_id, bot_message.dict())

        return ChatResponse(
            response=rag_result["response"],
            conversation_id=conversation_id,
            suggested_questions=rag_result["suggested_questions"],
            relevant_products=rag_result["relevant_products"],
            intent=rag_result["intent"],
            confidence=rag_result["confidence"]
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Chat processing error: {str(e)}")

@app.get("/products", response_model=List[Product])
async def get_products(category: Optional[str] = None, limit: int = 20):
    """Get products with optional category filter"""
    try:
        if category:
            products = await db.get_products_by_category(category, limit)
        else:
            products = await db.get_all_products(limit)
        return products
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error fetching products: {str(e)}")

@app.get("/products/{product_id}", response_model=Product)
async def get_product(product_id: str):
    """Get specific product by ID"""
    product = await db.get_product(product_id)
    if not product:
        raise HTTPException(status_code=404, detail="Product not found")
    return product

@app.post("/products", response_model=dict)
async def create_product(product: ProductCreate):
    """Create a new product"""
    try:
        product_data = product.dict()
        # Full 32-char uuid hex. The original truncated the uuid to 8 chars,
        # which risks ID collisions as the catalog grows.
        product_data["_id"] = uuid.uuid4().hex
        product_id = await db.create_product(product_data)
        return {"message": "Product created successfully", "product_id": product_id}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error creating product: {str(e)}")

@app.post("/search", response_model=List[Product])
async def search_products(request: SearchRequest):
    """Semantic search for products"""
    try:
        products = await rag_system.retrieve_relevant_products(
            request.query,
            request.category,
            request.max_results
        )
        return products
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Search error: {str(e)}")

@app.get("/conversations/{user_id}", response_model=List[Conversation])
async def get_user_conversations(user_id: str, limit: int = 10):
    """Get user's conversation history"""
    try:
        conversations = await db.get_user_conversations(user_id, limit)
        return conversations
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error fetching conversations: {str(e)}")

@app.get("/intents")
async def get_available_intents():
    """Get information about available intents"""
    return {
        "intents": [
            "product_inquiry", "pricing", "shipping",
            "returns", "support", "greeting", "unknown"
        ],
        "description": "Intent classification for user messages"
    }

@app.get("/test-products")
async def test_products(limit: int = 3):
    """Test endpoint to see transformed products"""
    try:
        products = await db.get_all_products(limit)
        transformed_products = [rag_system._transform_product_doc(product) for product in products]
        return {
            "original_count": len(products),
            "transformed_count": len(transformed_products),
            "original_sample": products[0] if products else {},
            "transformed_sample": transformed_products[0] if transformed_products else {},
            "all_transformed": transformed_products
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Test error: {str(e)}")

# Sample data initialization.
# NOTE(review): @app.on_event is deprecated in recent FastAPI in favor of
# lifespan handlers; kept here for compatibility with the current codebase.
@app.on_event("startup")
async def startup_event():
    """Seed sample products on startup when the collection is empty."""
    try:
        products = await db.get_all_products(1)
        if not products:
            await _initialize_sample_data()
    except Exception as e:
        # Startup seeding is best-effort; never prevent the app from booting.
        print(f"Startup initialization error: {e}")

async def _initialize_sample_data():
    """Initialize sample product data"""
    sample_products = [
        {
            "name": "Wireless Bluetooth Earbuds",
            "description": "High-quality wireless earbuds with noise cancellation and 24-hour battery life.",
            "price": 79.99,
            "category": "electronics",
            "in_stock": True,
            "tags": ["wireless", "bluetooth", "audio", "noise-cancellation"],
            "features": ["Noise Cancellation", "24h Battery", "Water Resistant", "Touch Controls"]
        },
        {
            "name": "Smart Fitness Watch",
            "description": "Advanced fitness tracker with heart rate monitoring, GPS, and smartphone connectivity.",
            "price": 199.99,
            "category": "electronics",
            "in_stock": True,
            "tags": ["fitness", "smartwatch", "health", "tracking"],
            "features": ["Heart Rate Monitor", "GPS", "Sleep Tracking", "Waterproof"]
        },
        {
            "name": "Organic Cotton T-Shirt",
            "description": "Comfortable and sustainable organic cotton t-shirt available in multiple colors.",
            "price": 24.99,
            "category": "clothing",
            "in_stock": True,
            "tags": ["cotton", "organic", "sustainable", "casual"],
            "features": ["Organic Cotton", "Machine Washable", "Multiple Colors"]
        }
    ]

    for product in sample_products:
        await db.create_product(product)

    print("Sample data initialized successfully")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=settings.port)
models.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from pydantic import BaseModel, Field, validator
from typing import List, Optional, Dict, Any
from datetime import datetime
from enum import Enum

class ProductCategory(str, Enum):
    """Closed set of product categories recognized by the shop."""
    ELECTRONICS = "electronics"
    CLOTHING = "clothing"
    HOME = "home"
    BEAUTY = "beauty"
    SPORTS = "sports"
    BOOKS = "books"
    OTHER = "other"

class ProductCreate(BaseModel):
    """Incoming product payload; lenient validators tolerate messy data."""
    name: str
    description: str = ""
    price: float = 0.0
    category: ProductCategory = ProductCategory.OTHER
    in_stock: bool = True
    tags: List[str] = []
    features: List[str] = []
    image_url: Optional[str] = None

    @validator('price', pre=True)
    def validate_price(cls, v):
        """Coerce price to float; fall back to 0.0 for None/unparseable input."""
        if v is None:
            return 0.0
        try:
            return float(v)
        except (TypeError, ValueError):
            return 0.0

    @validator('category', pre=True)
    def validate_category(cls, v):
        """Coerce arbitrary category values onto the enum (OTHER on failure)."""
        if isinstance(v, ProductCategory):
            return v
        try:
            return ProductCategory(v)
        except ValueError:
            return ProductCategory.OTHER

class Product(ProductCreate):
    """A stored product as returned by the API.

    Inherits all fields and both lenient validators from ProductCreate
    instead of duplicating them verbatim (as the original did), and adds
    the persisted identifier.
    """
    id: str

class ChatMessage(BaseModel):
    """One message in a conversation; sender is either "user" or "bot"."""
    sender: str
    text: str
    timestamp: datetime = Field(default_factory=datetime.now)

class ChatRequest(BaseModel):
    """Chat endpoint input; omit conversation_id to start a new conversation."""
    message: str
    conversation_id: Optional[str] = None
    user_id: Optional[str] = "anonymous"

class ChatResponse(BaseModel):
    """Chat endpoint output, including the RAG system's retrieval context."""
    response: str
    conversation_id: str
    suggested_questions: List[str] = []
    relevant_products: List[Product] = []
    intent: str
    confidence: float

class SearchRequest(BaseModel):
    """Semantic product-search input."""
    query: str
    category: Optional[ProductCategory] = None
    max_results: int = 5

class Conversation(BaseModel):
    """A stored conversation with its full message history."""
    id: str
    user_id: str
    messages: List[ChatMessage] = []
    created_at: datetime = Field(default_factory=datetime.now)
    updated_at: datetime = Field(default_factory=datetime.now)
rag_system.py
ADDED
|
@@ -0,0 +1,285 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Dict, Any, Optional
|
| 2 |
+
from database import db
|
| 3 |
+
from gemini_service import gemini_service
|
| 4 |
+
from models import Product, ProductCategory
|
| 5 |
+
import numpy as np
|
| 6 |
+
from sentence_transformers import SentenceTransformer
|
| 7 |
+
import asyncio
|
| 8 |
+
|
| 9 |
+
class RAGSystem:
|
| 10 |
+
def __init__(self):
|
| 11 |
+
self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
|
| 12 |
+
|
| 13 |
+
def _transform_product_doc(self, product_doc: Dict) -> Dict:
    """Convert a raw MongoDB product document into the Product model's shape.

    Never raises: on any failure a minimal, safely-typed record is returned
    so one bad document cannot break a whole listing.
    """
    try:
        # Prefer Mongo's _id; fall back to an explicit product_id field.
        doc_id = str(product_doc.get('_id', '')) or str(product_doc.get('product_id', ''))

        return {
            "id": doc_id,
            "name": product_doc.get('title', 'Unnamed Product'),
            "description": product_doc.get('product_description', 'No description available'),
            "price": float(product_doc.get('final_price', 0.0)),
            "category": self._map_category(product_doc.get('category', 'other')),
            # Source schema has no stock flag, so assume available.
            "in_stock": True,
            "tags": self._extract_tags(product_doc),
            "features": self._extract_features(product_doc),
            "image_url": self._get_first_image(product_doc.get('images', '')),
        }

    except Exception as e:
        print(f"Error transforming product {product_doc.get('_id')}: {e}")
        # Safe default product when the document is malformed.
        return {
            "id": str(product_doc.get('_id', 'unknown')),
            "name": product_doc.get('title', 'Unnamed Product'),
            "description": "Product information unavailable",
            "price": 0.0,
            "category": ProductCategory.OTHER,
            "in_stock": True,
            "tags": [],
            "features": [],
        }
| 47 |
+
|
| 48 |
+
def _map_category(self, raw_category: str) -> ProductCategory:
    """Map a raw category string onto the ProductCategory enum.

    Empty/None input and unrecognized strings map to OTHER.
    """
    if not raw_category:
        return ProductCategory.OTHER

    lowered = raw_category.lower()

    # Keyword buckets, checked in order; the first bucket with a hit wins
    # (same precedence as the original if/elif chain).
    buckets = (
        (ProductCategory.ELECTRONICS, ('electronic', 'tech', 'computer', 'phone')),
        (ProductCategory.CLOTHING, ('cloth', 'fashion', 'wear', 'top', 'dress', 'shirt', 'jeans')),
        (ProductCategory.HOME, ('home', 'garden', 'furniture', 'decor')),
        (ProductCategory.BEAUTY, ('beauty', 'cosmetic', 'skin', 'hair', 'cream', 'mask', 'makeup')),
        (ProductCategory.SPORTS, ('sport', 'fitness', 'exercise', 'gym')),
        (ProductCategory.BOOKS, ('book', 'literature')),
    )

    for category, keywords in buckets:
        if any(word in lowered for word in keywords):
            return category
    return ProductCategory.OTHER
| 70 |
+
|
| 71 |
+
def _extract_tags(self, product_doc: Dict) -> List[str]:
|
| 72 |
+
"""Extract tags from product details"""
|
| 73 |
+
tags = []
|
| 74 |
+
|
| 75 |
+
# Add category as a tag
|
| 76 |
+
category = product_doc.get('category')
|
| 77 |
+
if category:
|
| 78 |
+
tags.append(category)
|
| 79 |
+
|
| 80 |
+
# Extract from product details if available
|
| 81 |
+
details = product_doc.get('product_details', '')
|
| 82 |
+
if 'Cotton' in details:
|
| 83 |
+
tags.append('Cotton')
|
| 84 |
+
if 'Polyester' in details:
|
| 85 |
+
tags.append('Polyester')
|
| 86 |
+
|
| 87 |
+
return tags
|
| 88 |
+
|
| 89 |
+
def _extract_features(self, product_doc: Dict) -> List[str]:
|
| 90 |
+
"""Extract features from product details"""
|
| 91 |
+
features = []
|
| 92 |
+
|
| 93 |
+
# Extract key features from product details
|
| 94 |
+
details = product_doc.get('product_details', '')
|
| 95 |
+
|
| 96 |
+
# Look for common features
|
| 97 |
+
feature_keywords = [
|
| 98 |
+
'Machine wash', 'Hand wash', 'Dry clean', 'Cotton',
|
| 99 |
+
'Polyester', 'Elastane', 'Spaghetti', 'Sleeveless',
|
| 100 |
+
'Solid pattern', 'Sweetheart neck'
|
| 101 |
+
]
|
| 102 |
+
|
| 103 |
+
for keyword in feature_keywords:
|
| 104 |
+
if keyword in details:
|
| 105 |
+
features.append(keyword)
|
| 106 |
+
|
| 107 |
+
# Add rating as a feature
|
| 108 |
+
rating = product_doc.get('rating')
|
| 109 |
+
if rating:
|
| 110 |
+
features.append(f'{rating}★ Rating')
|
| 111 |
+
|
| 112 |
+
# Add delivery options
|
| 113 |
+
delivery = product_doc.get('delivery_options', '')
|
| 114 |
+
if 'Pay on delivery' in delivery:
|
| 115 |
+
features.append('Pay on Delivery')
|
| 116 |
+
if 'Easy returns' in delivery:
|
| 117 |
+
features.append('Easy Returns')
|
| 118 |
+
|
| 119 |
+
return features
|
| 120 |
+
|
| 121 |
+
def _get_first_image(self, images_str: str) -> Optional[str]:
|
| 122 |
+
"""Extract first image URL from comma-separated string"""
|
| 123 |
+
if not images_str:
|
| 124 |
+
return None
|
| 125 |
+
images = images_str.split(',')
|
| 126 |
+
return images[0].strip() if images else None
|
| 127 |
+
|
| 128 |
+
# ... rest of your existing methods remain the same
|
| 129 |
+
    async def retrieve_relevant_products(self, query: str, category: Optional[str] = None, limit: int = 5) -> List[Dict]:
        """Retrieve products relevant to *query* ranked by embedding similarity.

        Fetches up to 100 products (or up to 50 when filtered by *category*),
        transforms each raw document, scores the transformed text against the
        query embedding with cosine similarity, and returns the top *limit*
        matches. Any failure falls back to keyword search.

        Args:
            query: Free-text user query.
            category: Optional raw category filter passed through to the DB.
            limit: Maximum number of products to return.

        Returns:
            Transformed product dicts, best match first.

        NOTE(review): this re-embeds every candidate product on every call
        (one `_get_embedding` await per product) — consider precomputing and
        persisting product embeddings to cut latency and API cost.
        """
        try:
            # Generate query embedding
            query_embedding = await self._get_embedding(query)

            # Get all products or filtered by category
            if category:
                products = await db.get_products_by_category(category, limit=50)
            else:
                products = await db.get_all_products(limit=100)

            # Transform products first
            transformed_products = [self._transform_product_doc(product) for product in products]

            # Calculate similarity scores on transformed products
            scored_products = []
            for product in transformed_products:
                # Concatenate the searchable text fields (name, description, tags).
                product_text = f"{product.get('name', '')} {product.get('description', '')} {' '.join(product.get('tags', []))}"
                product_embedding = await self._get_embedding(product_text)

                similarity = self._cosine_similarity(query_embedding, product_embedding)
                scored_products.append((product, similarity))

            # Sort by similarity and return top results
            scored_products.sort(key=lambda x: x[1], reverse=True)
            return [product for product, score in scored_products[:limit]]

        except Exception as e:
            print(f"Product retrieval error: {e}")
            # Fallback to simple keyword search with transformation
            return await self._keyword_search_products(query, category, limit)
|
| 161 |
+
|
| 162 |
+
async def _keyword_search_products(self, query: str, category: Optional[str], limit: int) -> List[Dict]:
|
| 163 |
+
"""Fallback keyword-based product search with transformation"""
|
| 164 |
+
search_terms = query.lower().split()
|
| 165 |
+
|
| 166 |
+
if category:
|
| 167 |
+
products = await db.get_products_by_category(category, limit=50)
|
| 168 |
+
else:
|
| 169 |
+
products = await db.get_all_products(limit=100)
|
| 170 |
+
|
| 171 |
+
# Transform all products first
|
| 172 |
+
transformed_products = [self._transform_product_doc(product) for product in products]
|
| 173 |
+
|
| 174 |
+
scored_products = []
|
| 175 |
+
for product in transformed_products:
|
| 176 |
+
score = 0
|
| 177 |
+
product_text = f"{product.get('name', '')} {product.get('description', '')} {' '.join(product.get('tags', []))}".lower()
|
| 178 |
+
|
| 179 |
+
for term in search_terms:
|
| 180 |
+
if term in product_text:
|
| 181 |
+
score += 1
|
| 182 |
+
if term in product.get('name', '').lower():
|
| 183 |
+
score += 2 # Higher weight for name matches
|
| 184 |
+
|
| 185 |
+
if score > 0:
|
| 186 |
+
scored_products.append((product, score))
|
| 187 |
+
|
| 188 |
+
scored_products.sort(key=lambda x: x[1], reverse=True)
|
| 189 |
+
return [product for product, score in scored_products[:limit]]
|
| 190 |
+
|
| 191 |
+
async def _get_embedding(self, text: str) -> List[float]:
|
| 192 |
+
"""Get embedding for text using available methods"""
|
| 193 |
+
try:
|
| 194 |
+
embedding = await gemini_service.generate_embedding(text)
|
| 195 |
+
if embedding and len(embedding) > 10:
|
| 196 |
+
return embedding
|
| 197 |
+
except:
|
| 198 |
+
pass
|
| 199 |
+
|
| 200 |
+
embedding = self.embedding_model.encode(text)
|
| 201 |
+
return embedding.tolist()
|
| 202 |
+
|
| 203 |
+
def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
|
| 204 |
+
"""Calculate cosine similarity between two vectors"""
|
| 205 |
+
vec1 = np.array(vec1)
|
| 206 |
+
vec2 = np.array(vec2)
|
| 207 |
+
|
| 208 |
+
dot_product = np.dot(vec1, vec2)
|
| 209 |
+
norm1 = np.linalg.norm(vec1)
|
| 210 |
+
norm2 = np.linalg.norm(vec2)
|
| 211 |
+
|
| 212 |
+
if norm1 == 0 or norm2 == 0:
|
| 213 |
+
return 0.0
|
| 214 |
+
|
| 215 |
+
return dot_product / (norm1 * norm2)
|
| 216 |
+
|
| 217 |
+
    async def build_context(self, query: str, relevant_products: List[Dict]) -> str:
        """Build the plain-text context block handed to the LLM.

        Args:
            query: The user query (currently unused in the body; kept so the
                call signature stays stable for callers).
            relevant_products: Transformed product dicts, as produced by
                `_transform_product_doc`.

        Returns:
            A text block listing each product's key fields, or a generic
            fallback sentence when no products were retrieved.
        """
        if not relevant_products:
            return "No specific product information available. Use general knowledge about e-commerce and shopping."

        context_parts = ["Relevant Products Information:"]

        for i, product in enumerate(relevant_products, 1):
            # One template entry per product; every field falls back to a
            # safe default so a sparse product dict cannot raise here.
            context_parts.append(f"""
Product {i}:
- Name: {product.get('name', 'N/A')}
- Description: {product.get('description', 'N/A')}
- Price: ${product.get('price', 'N/A')}
- Category: {product.get('category', 'N/A')}
- In Stock: {'Yes' if product.get('in_stock') else 'No'}
- Features: {', '.join(product.get('features', []))}
- Tags: {', '.join(product.get('tags', []))}
""")

        return "\n".join(context_parts)
|
| 237 |
+
|
| 238 |
+
    async def generate_chat_response(self, user_message: str, conversation_history: List[Dict] = None) -> Dict[str, Any]:
        """Run the full RAG pipeline for one user message.

        Pipeline: classify intent -> retrieve top-3 products -> build LLM
        context -> generate the reply -> derive follow-up suggestions.

        Args:
            user_message: Raw text from the user.
            conversation_history: Not referenced in this body; accepted for
                interface compatibility with callers that pass prior turns.

        Returns:
            Dict with keys: response, relevant_products, suggested_questions,
            intent, confidence.
        """
        # Classify intent
        intent_result = await gemini_service.classify_intent(user_message)

        # Retrieve relevant products (already transformed)
        relevant_products = await self.retrieve_relevant_products(user_message, limit=3)

        # Build context
        context = await self.build_context(user_message, relevant_products)

        # Generate response using Gemini with context
        response = await gemini_service.generate_response(user_message, context)

        # Generate suggested questions based on intent
        suggested_questions = await self._generate_suggested_questions(intent_result['intent'], relevant_products)

        return {
            "response": response,
            "relevant_products": relevant_products,
            "suggested_questions": suggested_questions,
            "intent": intent_result['intent'],
            "confidence": intent_result['confidence']
        }
|
| 262 |
+
|
| 263 |
+
async def _generate_suggested_questions(self, intent: str, products: List[Dict]) -> List[str]:
|
| 264 |
+
"""Generate context-aware suggested questions"""
|
| 265 |
+
base_questions = {
|
| 266 |
+
"product_inquiry": ["Show me more products", "What are your best sellers?", "Any current deals?"],
|
| 267 |
+
"pricing": ["Do you offer discounts?", "What's the return policy?", "Any bundle deals?"],
|
| 268 |
+
"shipping": ["How long does delivery take?", "Do you ship internationally?", "What are shipping costs?"],
|
| 269 |
+
"returns": ["How do I return an item?", "What's your warranty policy?", "Do you offer exchanges?"],
|
| 270 |
+
"support": ["Contact customer service", "Store locations", "Business hours"],
|
| 271 |
+
"default": ["Best sellers", "Current deals", "Shipping info", "Return policy"]
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
questions = base_questions.get(intent, base_questions["default"])
|
| 275 |
+
|
| 276 |
+
# Add product-specific questions if we have products
|
| 277 |
+
if products and intent == "product_inquiry":
|
| 278 |
+
categories = list(set(str(p.get('category')) for p in products))
|
| 279 |
+
if categories:
|
| 280 |
+
questions = [f"Show me more {cat} products" for cat in categories[:2]] + questions
|
| 281 |
+
|
| 282 |
+
return questions[:4]
|
| 283 |
+
|
| 284 |
+
# Module-level singleton RAG system instance, imported and shared by the
# rest of the application.
rag_system = RAGSystem()
|
requirements.txt
ADDED
|
Binary file (3 kB). View file
|
|
|
run.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import uvicorn
from config import settings

# Development entry point: serves the FastAPI app defined in main.py.
if __name__ == "__main__":
    uvicorn.run(
        "main:app",
        host="0.0.0.0",
        # NOTE(review): Settings must declare a `port` field — the visible
        # part of config.py does not show one; confirm before running.
        port=settings.port,
        reload=True,  # Enable auto-reload during development
        log_level="info"
    )
|