Spaces:
Sleeping
Sleeping
Crcs1225 commited on
Commit ·
1333c38
1
Parent(s): ccabb90
new rag llm
Browse files- config.py +14 -9
- database.py +80 -109
- gemini_service.py +0 -106
- main.py +193 -174
- models.py +45 -81
- rag_system.py +139 -262
- run.py +2 -8
config.py
CHANGED
|
@@ -2,15 +2,20 @@ from pydantic_settings import BaseSettings
|
|
| 2 |
from typing import Optional
|
| 3 |
|
| 4 |
class Settings(BaseSettings):
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
class Config:
|
| 14 |
-
env_file =
|
|
|
|
|
|
|
| 15 |
|
| 16 |
-
settings = Settings()
|
|
|
|
| 2 |
from typing import Optional
|
| 3 |
|
| 4 |
class Settings(BaseSettings):
|
| 5 |
+
# MongoDB Atlas
|
| 6 |
+
MONGODB_URI: str
|
| 7 |
+
DATABASE_NAME: str = "marketplace"
|
| 8 |
+
COLLECTION_NAME: str = "product"
|
| 9 |
+
|
| 10 |
+
# Gemini API
|
| 11 |
+
GEMINI_API_KEY: str
|
| 12 |
+
|
| 13 |
+
# Server
|
| 14 |
+
HOST: str = "0.0.0.0"
|
| 15 |
+
PORT: int = 7860
|
| 16 |
|
| 17 |
class Config:
|
| 18 |
+
env_file = ".env"
|
| 19 |
+
|
| 20 |
+
settings = Settings()
|
| 21 |
|
|
|
database.py
CHANGED
|
@@ -1,117 +1,88 @@
|
|
| 1 |
-
from datetime import datetime
|
| 2 |
import motor.motor_asyncio
|
| 3 |
from bson import ObjectId
|
| 4 |
-
from typing import List,
|
|
|
|
| 5 |
from config import settings
|
| 6 |
-
import json
|
| 7 |
|
| 8 |
-
class
|
| 9 |
def __init__(self):
|
| 10 |
-
self.client =
|
| 11 |
-
self.db =
|
| 12 |
-
self.
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
async def
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
async def create_product(self, product_data: Dict) -> str:
|
| 33 |
-
result = await self.products.insert_one(product_data)
|
| 34 |
-
return str(result.inserted_id)
|
| 35 |
-
|
| 36 |
-
async def update_product(self, product_id: str, update_data: Dict) -> bool:
|
| 37 |
-
result = await self.products.update_one(
|
| 38 |
-
{"_id": ObjectId(product_id)},
|
| 39 |
-
{"$set": update_data}
|
| 40 |
-
)
|
| 41 |
-
return result.modified_count > 0
|
| 42 |
-
|
| 43 |
-
async def delete_product(self, product_id: str) -> bool:
|
| 44 |
-
result = await self.products.delete_one({"_id": ObjectId(product_id)})
|
| 45 |
-
return result.deleted_count > 0
|
| 46 |
-
|
| 47 |
-
async def get_conversation(self, conversation_id: str) -> Optional[Dict]:
|
| 48 |
-
return await self.conversations.find_one({"_id": ObjectId(conversation_id)})
|
| 49 |
-
|
| 50 |
-
async def create_conversation(self, user_id: str) -> str:
|
| 51 |
-
conversation_data = {
|
| 52 |
-
"user_id": user_id,
|
| 53 |
-
"messages": [],
|
| 54 |
-
"created_at": datetime.now(),
|
| 55 |
-
"updated_at": datetime.now()
|
| 56 |
-
}
|
| 57 |
-
result = await self.conversations.insert_one(conversation_data)
|
| 58 |
-
return str(result.inserted_id)
|
| 59 |
-
|
| 60 |
-
async def add_message_to_conversation(self, conversation_id: str, message: Dict) -> bool:
|
| 61 |
-
result = await self.conversations.update_one(
|
| 62 |
-
{"_id": ObjectId(conversation_id)},
|
| 63 |
-
{
|
| 64 |
-
"$push": {"messages": message},
|
| 65 |
-
"$set": {"updated_at": datetime.now()}
|
| 66 |
-
}
|
| 67 |
-
)
|
| 68 |
-
return result.modified_count > 0
|
| 69 |
-
|
| 70 |
-
async def get_user_conversations(self, user_id: str, limit: int = 10) -> List[Dict]:
|
| 71 |
-
cursor = self.conversations.find({"user_id": user_id}).sort("updated_at", -1).limit(limit)
|
| 72 |
-
return await cursor.to_list(length=limit)
|
| 73 |
-
|
| 74 |
-
async def store_embedding(self, text: str, embedding: List[float], metadata: Dict) -> str:
|
| 75 |
-
doc = {
|
| 76 |
-
"text": text,
|
| 77 |
-
"embedding": embedding,
|
| 78 |
-
"metadata": metadata,
|
| 79 |
-
"created_at": datetime.now()
|
| 80 |
-
}
|
| 81 |
-
result = await self.embeddings.insert_one(doc)
|
| 82 |
-
return str(result.inserted_id)
|
| 83 |
-
|
| 84 |
-
async def find_similar_embeddings(self, embedding: List[float], limit: int = 5) -> List[Dict]:
|
| 85 |
-
# This is a simplified version - in production, you'd use vector search
|
| 86 |
-
pipeline = [
|
| 87 |
-
{
|
| 88 |
-
"$addFields": {
|
| 89 |
-
"similarity": {
|
| 90 |
-
"$sqrt": {
|
| 91 |
-
"$sum": {
|
| 92 |
-
"$map": {
|
| 93 |
-
"input": {"$range": [0, {"$size": "$embedding"}]},
|
| 94 |
-
"as": "idx",
|
| 95 |
-
"in": {
|
| 96 |
-
"$pow": [
|
| 97 |
-
{"$subtract": [
|
| 98 |
-
{"$arrayElemAt": ["$embedding", "$$idx"]},
|
| 99 |
-
{"$arrayElemAt": [embedding, "$$idx"]}
|
| 100 |
-
]},
|
| 101 |
-
2
|
| 102 |
-
]
|
| 103 |
-
}
|
| 104 |
-
}
|
| 105 |
-
}
|
| 106 |
-
}
|
| 107 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
}
|
| 109 |
-
}
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
|
|
|
| 115 |
|
| 116 |
-
#
|
| 117 |
-
db =
|
|
|
|
|
|
|
| 1 |
import motor.motor_asyncio
|
| 2 |
from bson import ObjectId
|
| 3 |
+
from typing import List, Dict, Any
|
| 4 |
+
import numpy as np
|
| 5 |
from config import settings
|
|
|
|
| 6 |
|
| 7 |
+
class Database:
|
| 8 |
def __init__(self):
|
| 9 |
+
self.client = None
|
| 10 |
+
self.db = None
|
| 11 |
+
self.collection = None
|
| 12 |
+
|
| 13 |
+
async def connect(self):
|
| 14 |
+
self.client = motor.motor_asyncio.AsyncIOMotorClient(settings.MONGODB_URI)
|
| 15 |
+
self.db = self.client[settings.DATABASE_NAME]
|
| 16 |
+
self.collection = self.db[settings.COLLECTION_NAME]
|
| 17 |
+
|
| 18 |
+
async def similarity_search(self, query_embedding: List[float], limit: int = 3) -> List[Dict]:
|
| 19 |
+
"""Search for similar products using vector similarity"""
|
| 20 |
+
try:
|
| 21 |
+
# First try vector search if index exists
|
| 22 |
+
pipeline = [
|
| 23 |
+
{
|
| 24 |
+
"$vectorSearch": {
|
| 25 |
+
"index": "vector_index", # Your vector index name
|
| 26 |
+
"path": "embedding",
|
| 27 |
+
"queryVector": query_embedding,
|
| 28 |
+
"numCandidates": 100,
|
| 29 |
+
"limit": limit
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
}
|
| 31 |
+
},
|
| 32 |
+
{
|
| 33 |
+
"$project": {
|
| 34 |
+
"_id": 1,
|
| 35 |
+
"title": 1,
|
| 36 |
+
"category": 1,
|
| 37 |
+
"product_description": 1,
|
| 38 |
+
"final_price": 1,
|
| 39 |
+
"score": {"$meta": "vectorSearchScore"}
|
| 40 |
+
}
|
| 41 |
+
}
|
| 42 |
+
]
|
| 43 |
+
|
| 44 |
+
cursor = self.collection.aggregate(pipeline)
|
| 45 |
+
results = []
|
| 46 |
+
async for doc in cursor:
|
| 47 |
+
results.append({
|
| 48 |
+
"id": str(doc["_id"]),
|
| 49 |
+
"content": f"Product: {doc.get('title', 'N/A')}. Description: {doc.get('product_description', 'N/A')}. Category: {doc.get('category', 'N/A')}. Price: {doc.get('final_price', 'N/A')}.",
|
| 50 |
+
"source": doc.get('title', 'product_database'),
|
| 51 |
+
"metadata": {
|
| 52 |
+
"category": doc.get('category', 'N/A'),
|
| 53 |
+
"price": doc.get('final_price', 'N/A'),
|
| 54 |
+
"similarity_score": doc.get('score', 0)
|
| 55 |
+
}
|
| 56 |
+
})
|
| 57 |
+
return results
|
| 58 |
+
except Exception as e:
|
| 59 |
+
print(f"Vector search failed, falling back to text search: {e}")
|
| 60 |
+
# Fallback to text search if vector search fails
|
| 61 |
+
return await self.search_by_category("tops", limit)
|
| 62 |
+
|
| 63 |
+
async def search_by_category(self, category: str, limit: int = 5) -> List[Dict]:
|
| 64 |
+
"""Search products by category (fallback if vector search fails)"""
|
| 65 |
+
cursor = self.collection.find(
|
| 66 |
+
{"category": {"$regex": category, "$options": "i"}}
|
| 67 |
+
).limit(limit)
|
| 68 |
+
|
| 69 |
+
results = []
|
| 70 |
+
async for doc in cursor:
|
| 71 |
+
results.append({
|
| 72 |
+
"id": str(doc["_id"]),
|
| 73 |
+
"content": f"Product: {doc.get('title', 'N/A')}. Description: {doc.get('product_description', 'N/A')}. Category: {doc.get('category', 'N/A')}. Price: {doc.get('final_price', 'N/A')}.",
|
| 74 |
+
"source": doc.get('title', 'product_database'),
|
| 75 |
+
"metadata": {
|
| 76 |
+
"category": doc.get('category', 'N/A'),
|
| 77 |
+
"price": doc.get('final_price', 'N/A')
|
| 78 |
}
|
| 79 |
+
})
|
| 80 |
+
return results
|
| 81 |
+
|
| 82 |
+
async def insert_documents(self, documents: List[Dict]) -> List[str]:
|
| 83 |
+
"""Insert documents into the collection"""
|
| 84 |
+
result = await self.collection.insert_many(documents)
|
| 85 |
+
return [str(id) for id in result.inserted_ids]
|
| 86 |
|
| 87 |
+
# Global database instance
|
| 88 |
+
db = Database()
|
gemini_service.py
DELETED
|
@@ -1,106 +0,0 @@
|
|
| 1 |
-
import google.generativeai as genai
|
| 2 |
-
from typing import List, Dict, Any, Optional
|
| 3 |
-
import asyncio
|
| 4 |
-
import aiohttp
|
| 5 |
-
import json
|
| 6 |
-
from config import settings
|
| 7 |
-
|
| 8 |
-
class GeminiService:
|
| 9 |
-
def __init__(self):
|
| 10 |
-
genai.configure(api_key=settings.gemini_api_key)
|
| 11 |
-
self.model = genai.GenerativeModel('gemini-2.5-flash')
|
| 12 |
-
|
| 13 |
-
async def generate_response(self, prompt: str, context: str = "") -> str:
|
| 14 |
-
"""Generate response using Gemini API with context"""
|
| 15 |
-
try:
|
| 16 |
-
full_prompt = f"""
|
| 17 |
-
Context Information:
|
| 18 |
-
{context}
|
| 19 |
-
|
| 20 |
-
User Question: {prompt}
|
| 21 |
-
|
| 22 |
-
You are a helpful shopping assistant for Daddy's Shop. Use the context information above to answer the user's question accurately and helpfully. If the context doesn't contain relevant information, use your general knowledge but be honest about limitations.
|
| 23 |
-
|
| 24 |
-
Provide a friendly, professional response focused on helping with shopping needs.
|
| 25 |
-
"""
|
| 26 |
-
|
| 27 |
-
# Run in thread pool since Gemini doesn't have native async support
|
| 28 |
-
loop = asyncio.get_event_loop()
|
| 29 |
-
response = await loop.run_in_executor(
|
| 30 |
-
None,
|
| 31 |
-
lambda: self.model.generate_content(full_prompt)
|
| 32 |
-
)
|
| 33 |
-
|
| 34 |
-
return response.text
|
| 35 |
-
|
| 36 |
-
except Exception as e:
|
| 37 |
-
print(f"Gemini API error: {e}")
|
| 38 |
-
return "I apologize, but I'm having trouble processing your request right now. Please try again later."
|
| 39 |
-
|
| 40 |
-
async def generate_embedding(self, text: str) -> List[float]:
|
| 41 |
-
"""Generate embeddings for text using Gemini"""
|
| 42 |
-
try:
|
| 43 |
-
# Note: Gemini doesn't have direct embedding API, so we'll use a workaround
|
| 44 |
-
# For production, consider using SentenceTransformers or another embedding service
|
| 45 |
-
embedding_model = genai.GenerativeModel('embedding-001')
|
| 46 |
-
|
| 47 |
-
loop = asyncio.get_event_loop()
|
| 48 |
-
result = await loop.run_in_executor(
|
| 49 |
-
None,
|
| 50 |
-
lambda: genai.embed_content(
|
| 51 |
-
model=embedding_model,
|
| 52 |
-
content=text,
|
| 53 |
-
task_type="retrieval_document"
|
| 54 |
-
)
|
| 55 |
-
)
|
| 56 |
-
|
| 57 |
-
return result['embedding']
|
| 58 |
-
|
| 59 |
-
except Exception as e:
|
| 60 |
-
print(f"Embedding generation error: {e}")
|
| 61 |
-
# Fallback to simple embedding (in production, use proper embedding model)
|
| 62 |
-
return [0.0] * 384 # Default dimension
|
| 63 |
-
|
| 64 |
-
async def classify_intent(self, message: str) -> Dict[str, Any]:
|
| 65 |
-
"""Classify user intent using Gemini"""
|
| 66 |
-
prompt = f"""
|
| 67 |
-
Classify the following user message into one of these intents:
|
| 68 |
-
- product_inquiry: Questions about products, features, availability
|
| 69 |
-
- pricing: Questions about costs, discounts, prices
|
| 70 |
-
- shipping: Questions about delivery, shipping costs, timelines
|
| 71 |
-
- returns: Questions about returns, refunds, exchanges
|
| 72 |
-
- support: General customer support, contact information
|
| 73 |
-
- greeting: Hello, hi, greetings
|
| 74 |
-
- unknown: Cannot classify
|
| 75 |
-
|
| 76 |
-
Message: "{message}"
|
| 77 |
-
|
| 78 |
-
Return ONLY a JSON object with:
|
| 79 |
-
{{
|
| 80 |
-
"intent": "classified_intent",
|
| 81 |
-
"confidence": 0.95,
|
| 82 |
-
"entities": ["extracted_entities", "if_any"]
|
| 83 |
-
}}
|
| 84 |
-
"""
|
| 85 |
-
|
| 86 |
-
try:
|
| 87 |
-
loop = asyncio.get_event_loop()
|
| 88 |
-
response = await loop.run_in_executor(
|
| 89 |
-
None,
|
| 90 |
-
lambda: self.model.generate_content(prompt)
|
| 91 |
-
)
|
| 92 |
-
|
| 93 |
-
# Parse JSON response
|
| 94 |
-
import re
|
| 95 |
-
json_match = re.search(r'\{.*\}', response.text, re.DOTALL)
|
| 96 |
-
if json_match:
|
| 97 |
-
return json.loads(json_match.group())
|
| 98 |
-
else:
|
| 99 |
-
return {"intent": "unknown", "confidence": 0.0, "entities": []}
|
| 100 |
-
|
| 101 |
-
except Exception as e:
|
| 102 |
-
print(f"Intent classification error: {e}")
|
| 103 |
-
return {"intent": "unknown", "confidence": 0.0, "entities": []}
|
| 104 |
-
|
| 105 |
-
# Gemini service instance
|
| 106 |
-
gemini_service = GeminiService()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
main.py
CHANGED
|
@@ -1,217 +1,236 @@
|
|
| 1 |
-
|
|
|
|
| 2 |
from fastapi.middleware.cors import CORSMiddleware
|
| 3 |
-
from
|
| 4 |
import uuid
|
| 5 |
-
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
|
| 8 |
from database import db
|
| 9 |
-
from
|
| 10 |
-
from
|
| 11 |
-
|
| 12 |
-
SearchRequest, Conversation, ChatMessage
|
| 13 |
-
)
|
| 14 |
|
| 15 |
app = FastAPI(
|
| 16 |
-
title="
|
| 17 |
-
description="
|
| 18 |
version="1.0.0"
|
| 19 |
)
|
| 20 |
|
| 21 |
# CORS middleware
|
| 22 |
app.add_middleware(
|
| 23 |
CORSMiddleware,
|
| 24 |
-
allow_origins=["
|
| 25 |
allow_credentials=True,
|
| 26 |
allow_methods=["*"],
|
| 27 |
allow_headers=["*"],
|
| 28 |
)
|
|
|
|
| 29 |
|
| 30 |
-
@app.get("/")
|
| 31 |
-
async def root():
|
| 32 |
-
"""Health check endpoint"""
|
| 33 |
-
return {
|
| 34 |
-
"message": "Daddy's Shop RAG Chatbot API is running!",
|
| 35 |
-
"version": "1.0.0",
|
| 36 |
-
"status": "healthy"
|
| 37 |
-
}
|
| 38 |
|
| 39 |
-
@app.
|
| 40 |
-
async def
|
| 41 |
-
"""
|
| 42 |
-
|
| 43 |
-
"""
|
| 44 |
try:
|
| 45 |
-
|
| 46 |
-
if not request.conversation_id:
|
| 47 |
-
conversation_id = await db.create_conversation(request.user_id)
|
| 48 |
-
else:
|
| 49 |
-
conversation_id = request.conversation_id
|
| 50 |
-
# Verify conversation exists
|
| 51 |
-
conversation = await db.get_conversation(conversation_id)
|
| 52 |
-
if not conversation:
|
| 53 |
-
conversation_id = await db.create_conversation(request.user_id)
|
| 54 |
-
|
| 55 |
-
# Add user message to conversation
|
| 56 |
-
user_message = ChatMessage(
|
| 57 |
-
sender="user",
|
| 58 |
-
text=request.message
|
| 59 |
-
)
|
| 60 |
-
await db.add_message_to_conversation(conversation_id, user_message.dict())
|
| 61 |
|
| 62 |
-
#
|
| 63 |
-
|
| 64 |
-
history = conversation.get('messages', []) if conversation else []
|
| 65 |
|
| 66 |
-
#
|
| 67 |
-
|
|
|
|
| 68 |
|
| 69 |
-
|
| 70 |
-
bot_message = ChatMessage(
|
| 71 |
-
sender="bot",
|
| 72 |
-
text=rag_result["response"]
|
| 73 |
-
)
|
| 74 |
-
await db.add_message_to_conversation(conversation_id, bot_message.dict())
|
| 75 |
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
except Exception as e:
|
| 86 |
-
|
|
|
|
| 87 |
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
"""Get products with optional category filter"""
|
| 91 |
try:
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 97 |
except Exception as e:
|
| 98 |
-
|
| 99 |
|
| 100 |
-
@app.get("/
|
| 101 |
-
async def
|
| 102 |
-
"""
|
| 103 |
-
product = await db.get_product(product_id)
|
| 104 |
-
if not product:
|
| 105 |
-
raise HTTPException(status_code=404, detail="Product not found")
|
| 106 |
-
return product
|
| 107 |
|
| 108 |
-
@app.
|
| 109 |
-
async def
|
| 110 |
-
"""
|
| 111 |
-
try:
|
| 112 |
-
product_data = product.dict()
|
| 113 |
-
product_data["_id"] = str(uuid.uuid4())[:8] # Simple ID generation
|
| 114 |
-
product_id = await db.create_product(product_data)
|
| 115 |
-
return {"message": "Product created successfully", "product_id": product_id}
|
| 116 |
-
except Exception as e:
|
| 117 |
-
raise HTTPException(status_code=500, detail=f"Error creating product: {str(e)}")
|
| 118 |
|
| 119 |
-
@app.post("/
|
| 120 |
-
async def
|
| 121 |
-
"""
|
| 122 |
try:
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 127 |
)
|
| 128 |
-
return products
|
| 129 |
except Exception as e:
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
|
|
|
| 135 |
try:
|
| 136 |
-
|
| 137 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 138 |
except Exception as e:
|
| 139 |
-
|
| 140 |
|
| 141 |
-
@app.get("/
|
| 142 |
-
async def
|
| 143 |
-
"""Get
|
| 144 |
-
return {
|
| 145 |
-
"intents": [
|
| 146 |
-
"product_inquiry", "pricing", "shipping",
|
| 147 |
-
"returns", "support", "greeting", "unknown"
|
| 148 |
-
],
|
| 149 |
-
"description": "Intent classification for user messages"
|
| 150 |
-
}
|
| 151 |
-
@app.get("/test-products")
|
| 152 |
-
async def test_products(limit: int = 3):
|
| 153 |
-
"""Test endpoint to see transformed products"""
|
| 154 |
try:
|
| 155 |
-
|
| 156 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
return {
|
| 158 |
-
"
|
| 159 |
-
"
|
| 160 |
-
"
|
| 161 |
-
"transformed_sample": transformed_products[0] if transformed_products else {},
|
| 162 |
-
"all_transformed": transformed_products
|
| 163 |
}
|
| 164 |
except Exception as e:
|
| 165 |
-
|
| 166 |
-
# Sample data initialization
|
| 167 |
-
@app.on_event("startup")
|
| 168 |
-
async def startup_event():
|
| 169 |
-
"""Initialize sample data if needed"""
|
| 170 |
-
try:
|
| 171 |
-
# Check if we have any products
|
| 172 |
-
products = await db.get_all_products(1)
|
| 173 |
-
if not products:
|
| 174 |
-
await _initialize_sample_data()
|
| 175 |
-
except Exception as e:
|
| 176 |
-
print(f"Startup initialization error: {e}")
|
| 177 |
-
|
| 178 |
-
async def _initialize_sample_data():
|
| 179 |
-
"""Initialize sample product data"""
|
| 180 |
-
sample_products = [
|
| 181 |
-
{
|
| 182 |
-
"name": "Wireless Bluetooth Earbuds",
|
| 183 |
-
"description": "High-quality wireless earbuds with noise cancellation and 24-hour battery life.",
|
| 184 |
-
"price": 79.99,
|
| 185 |
-
"category": "electronics",
|
| 186 |
-
"in_stock": True,
|
| 187 |
-
"tags": ["wireless", "bluetooth", "audio", "noise-cancellation"],
|
| 188 |
-
"features": ["Noise Cancellation", "24h Battery", "Water Resistant", "Touch Controls"]
|
| 189 |
-
},
|
| 190 |
-
{
|
| 191 |
-
"name": "Smart Fitness Watch",
|
| 192 |
-
"description": "Advanced fitness tracker with heart rate monitoring, GPS, and smartphone connectivity.",
|
| 193 |
-
"price": 199.99,
|
| 194 |
-
"category": "electronics",
|
| 195 |
-
"in_stock": True,
|
| 196 |
-
"tags": ["fitness", "smartwatch", "health", "tracking"],
|
| 197 |
-
"features": ["Heart Rate Monitor", "GPS", "Sleep Tracking", "Waterproof"]
|
| 198 |
-
},
|
| 199 |
-
{
|
| 200 |
-
"name": "Organic Cotton T-Shirt",
|
| 201 |
-
"description": "Comfortable and sustainable organic cotton t-shirt available in multiple colors.",
|
| 202 |
-
"price": 24.99,
|
| 203 |
-
"category": "clothing",
|
| 204 |
-
"in_stock": True,
|
| 205 |
-
"tags": ["cotton", "organic", "sustainable", "casual"],
|
| 206 |
-
"features": ["Organic Cotton", "Machine Washable", "Multiple Colors"]
|
| 207 |
-
}
|
| 208 |
-
]
|
| 209 |
-
|
| 210 |
-
for product in sample_products:
|
| 211 |
-
await db.create_product(product)
|
| 212 |
|
| 213 |
-
print("Sample data initialized successfully")
|
| 214 |
-
|
| 215 |
if __name__ == "__main__":
|
| 216 |
-
|
| 217 |
-
uvicorn.run(app, host="0.0.0.0", port=settings.port)
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
from fastapi import FastAPI, HTTPException, Query
|
| 3 |
from fastapi.middleware.cors import CORSMiddleware
|
| 4 |
+
from fastapi.responses import JSONResponse
|
| 5 |
import uuid
|
| 6 |
+
import uvicorn
|
| 7 |
+
from typing import List, Optional
|
| 8 |
+
import traceback
|
| 9 |
|
| 10 |
+
# Import your existing modules
|
| 11 |
from database import db
|
| 12 |
+
from models import ChatMessage, ChatRequest, ChatResponse, Product, SearchRequest, Conversation, KnowledgeDocument, Document, SourceInfo
|
| 13 |
+
from config import settings
|
| 14 |
+
from rag_system import rag_pipeline
|
|
|
|
|
|
|
| 15 |
|
| 16 |
app = FastAPI(
|
| 17 |
+
title="RAG Chatbot API",
|
| 18 |
+
description="Lightweight RAG Chatbot using MongoDB Atlas and Gemini",
|
| 19 |
version="1.0.0"
|
| 20 |
)
|
| 21 |
|
| 22 |
# CORS middleware
|
| 23 |
app.add_middleware(
|
| 24 |
CORSMiddleware,
|
| 25 |
+
allow_origins=["*"],
|
| 26 |
allow_credentials=True,
|
| 27 |
allow_methods=["*"],
|
| 28 |
allow_headers=["*"],
|
| 29 |
)
|
| 30 |
+
embeddings_generated = False
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
+
@app.on_event("startup")
|
| 34 |
+
async def startup_event():
|
| 35 |
+
"""Run on application startup"""
|
| 36 |
+
global embeddings_generated
|
|
|
|
| 37 |
try:
|
| 38 |
+
print("🚀 Starting RAG Chatbot API...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
+
# Initialize database connection
|
| 41 |
+
await db.connect()
|
|
|
|
| 42 |
|
| 43 |
+
# Check if we need to generate embeddings
|
| 44 |
+
total_docs = await db.collection.count_documents({})
|
| 45 |
+
docs_with_embeddings = await db.collection.count_documents({"embedding": {"$exists": True}})
|
| 46 |
|
| 47 |
+
print(f"📊 Database status: {total_docs} total documents, {docs_with_embeddings} with embeddings")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
|
| 49 |
+
# If we have documents but no embeddings, generate them
|
| 50 |
+
if total_docs > 0 and docs_with_embeddings == 0:
|
| 51 |
+
print("🔄 No embeddings found. Starting automatic embedding generation...")
|
| 52 |
+
await generate_embeddings_on_startup()
|
| 53 |
+
embeddings_generated = True
|
| 54 |
+
elif docs_with_embeddings > 0:
|
| 55 |
+
print(f"✅ Embeddings already exist for {docs_with_embeddings} documents")
|
| 56 |
+
embeddings_generated = True
|
| 57 |
+
else:
|
| 58 |
+
print("ℹ️ No documents found in database")
|
| 59 |
+
|
| 60 |
+
print("✅ RAG Chatbot API is ready!")
|
| 61 |
|
| 62 |
except Exception as e:
|
| 63 |
+
print(f"❌ Startup error: {e}")
|
| 64 |
+
raise
|
| 65 |
|
| 66 |
+
async def generate_embeddings_on_startup():
|
| 67 |
+
"""Generate embeddings for all documents on startup"""
|
|
|
|
| 68 |
try:
|
| 69 |
+
# Find all documents without embeddings
|
| 70 |
+
cursor = db.collection.find({"embedding": {"$exists": False}})
|
| 71 |
+
documents_without_embeddings = []
|
| 72 |
+
async for doc in cursor:
|
| 73 |
+
documents_without_embeddings.append(doc)
|
| 74 |
+
|
| 75 |
+
if not documents_without_embeddings:
|
| 76 |
+
print("✅ All documents already have embeddings")
|
| 77 |
+
return
|
| 78 |
+
|
| 79 |
+
print(f"🔄 Generating embeddings for {len(documents_without_embeddings)} documents...")
|
| 80 |
+
|
| 81 |
+
updated_count = 0
|
| 82 |
+
errors = 0
|
| 83 |
+
|
| 84 |
+
# Process in smaller batches to avoid timeout
|
| 85 |
+
batch_size = 50
|
| 86 |
+
for i in range(0, len(documents_without_embeddings), batch_size):
|
| 87 |
+
batch = documents_without_embeddings[i:i + batch_size]
|
| 88 |
+
|
| 89 |
+
for doc in batch:
|
| 90 |
+
try:
|
| 91 |
+
# Create meaningful content for embedding
|
| 92 |
+
content_parts = []
|
| 93 |
+
|
| 94 |
+
# Include all relevant text fields
|
| 95 |
+
if doc.get('title'):
|
| 96 |
+
content_parts.append(f"Product: {doc['title']}")
|
| 97 |
+
if doc.get('product_description'):
|
| 98 |
+
content_parts.append(f"Description: {doc['product_description']}")
|
| 99 |
+
if doc.get('category'):
|
| 100 |
+
content_parts.append(f"Category: {doc['category']}")
|
| 101 |
+
|
| 102 |
+
content = ". ".join(content_parts)
|
| 103 |
+
|
| 104 |
+
if content.strip():
|
| 105 |
+
# Generate embedding
|
| 106 |
+
embedding = await rag_pipeline.get_embeddings([content])
|
| 107 |
+
|
| 108 |
+
# Update document with embedding
|
| 109 |
+
await db.collection.update_one(
|
| 110 |
+
{"_id": doc["_id"]},
|
| 111 |
+
{"$set": {"embedding": embedding[0]}}
|
| 112 |
+
)
|
| 113 |
+
updated_count += 1
|
| 114 |
+
|
| 115 |
+
# Progress update every 50 documents
|
| 116 |
+
if updated_count % 50 == 0:
|
| 117 |
+
print(f"✅ Processed {updated_count}/{len(documents_without_embeddings)} documents...")
|
| 118 |
+
|
| 119 |
+
except Exception as e:
|
| 120 |
+
errors += 1
|
| 121 |
+
print(f"❌ Error processing document: {e}")
|
| 122 |
+
continue
|
| 123 |
+
|
| 124 |
+
# Small delay between batches
|
| 125 |
+
await asyncio.sleep(1)
|
| 126 |
+
|
| 127 |
+
# Final status
|
| 128 |
+
final_with_embeddings = await db.collection.count_documents({"embedding": {"$exists": True}})
|
| 129 |
+
print(f"🎉 Embedding generation completed! {final_with_embeddings} documents now have embeddings")
|
| 130 |
+
|
| 131 |
except Exception as e:
|
| 132 |
+
print(f"❌ Embedding generation failed: {e}")
|
| 133 |
|
| 134 |
+
@app.get("/")
|
| 135 |
+
async def root():
|
| 136 |
+
return {"message": "RAG Chatbot API is running!", "status": "healthy"}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
+
@app.get("/health")
|
| 139 |
+
async def health_check():
|
| 140 |
+
return {"status": "healthy", "service": "rag-chatbot"}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
|
| 142 |
+
@app.post("/chat")
|
| 143 |
+
async def chat_with_assistant(request: ChatRequest):
|
| 144 |
+
"""Main chat endpoint for product queries"""
|
| 145 |
try:
|
| 146 |
+
print(f"💬 Received chat request: {request.message}")
|
| 147 |
+
response, sources = await rag_pipeline.generate_response(request.message)
|
| 148 |
+
|
| 149 |
+
suggested_questions = rag_pipeline.generate_followup_questions(
|
| 150 |
+
request.message,
|
| 151 |
+
sources
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
# Convert to SourceInfo objects
|
| 155 |
+
source_objects = []
|
| 156 |
+
for product in sources:
|
| 157 |
+
source_objects.append(SourceInfo(
|
| 158 |
+
id=product.get("id", ""),
|
| 159 |
+
name=product.get("source", "Product"),
|
| 160 |
+
category=product.get("metadata", {}).get("category", "N/A"),
|
| 161 |
+
price=str(product.get("metadata", {}).get("price", "N/A")),
|
| 162 |
+
similarity_score=product.get("metadata", {}).get("similarity_score", 0)
|
| 163 |
+
))
|
| 164 |
+
|
| 165 |
+
return ChatResponse(
|
| 166 |
+
response=response,
|
| 167 |
+
sources=source_objects,
|
| 168 |
+
suggested_questions=suggested_questions,
|
| 169 |
+
conversation_id=request.conversation_id
|
| 170 |
)
|
|
|
|
| 171 |
except Exception as e:
|
| 172 |
+
print(f"❌ Error in /chat endpoint: {traceback.format_exc()}")
|
| 173 |
+
raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
|
| 174 |
+
|
| 175 |
+
@app.get("/debug/vector-results")
|
| 176 |
+
async def debug_vector_results(query: str = "tops"):
|
| 177 |
+
"""See exactly what vector search returns"""
|
| 178 |
try:
|
| 179 |
+
# Get embeddings for the query
|
| 180 |
+
query_embedding = await rag_pipeline.get_embeddings([query])
|
| 181 |
+
print(f"🔍 Testing vector search for: '{query}'")
|
| 182 |
+
print(f"📐 Embedding dimensions: {len(query_embedding[0])}")
|
| 183 |
+
|
| 184 |
+
# Perform vector search
|
| 185 |
+
results = await db.similarity_search(query_embedding[0], limit=5)
|
| 186 |
+
|
| 187 |
+
response_data = {
|
| 188 |
+
"query": query,
|
| 189 |
+
"embedding_dimensions": len(query_embedding[0]),
|
| 190 |
+
"results_found": len(results),
|
| 191 |
+
"raw_results": []
|
| 192 |
+
}
|
| 193 |
+
|
| 194 |
+
for i, doc in enumerate(results):
|
| 195 |
+
response_data["raw_results"].append({
|
| 196 |
+
"rank": i + 1,
|
| 197 |
+
"id": doc["id"],
|
| 198 |
+
"content": doc["content"],
|
| 199 |
+
"source": doc.get("source", "unknown"),
|
| 200 |
+
"metadata": doc.get("metadata", {})
|
| 201 |
+
})
|
| 202 |
+
print(f"📄 Result {i+1}: {doc['content'][:100]}...")
|
| 203 |
+
|
| 204 |
+
return response_data
|
| 205 |
+
|
| 206 |
except Exception as e:
|
| 207 |
+
return {"error": str(e), "traceback": traceback.format_exc()}
|
| 208 |
|
| 209 |
+
@app.get("/debug/sample-products")
|
| 210 |
+
async def debug_sample_products(category: str = "tops", limit: int = 5):
|
| 211 |
+
"""Get sample products to see what content is available"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
try:
|
| 213 |
+
cursor = db.collection.find({"category": {"$regex": category, "$options": "i"}}).limit(limit)
|
| 214 |
+
products = []
|
| 215 |
+
async for doc in cursor:
|
| 216 |
+
product_info = {
|
| 217 |
+
"id": str(doc["_id"]),
|
| 218 |
+
"name": doc.get("title", "N/A"),
|
| 219 |
+
"category": doc.get("category", "N/A"),
|
| 220 |
+
"description": doc.get("product_description", "N/A"),
|
| 221 |
+
"price": doc.get("final_price", "N/A"),
|
| 222 |
+
"has_embedding": "embedding" in doc,
|
| 223 |
+
"content_used_for_embedding": f"{doc.get('title', '')} {doc.get('product_description', '')} {doc.get('category', '')}"
|
| 224 |
+
}
|
| 225 |
+
products.append(product_info)
|
| 226 |
+
|
| 227 |
return {
|
| 228 |
+
"category": category,
|
| 229 |
+
"products_found": len(products),
|
| 230 |
+
"products": products
|
|
|
|
|
|
|
| 231 |
}
|
| 232 |
except Exception as e:
|
| 233 |
+
return {"error": str(e)}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 234 |
|
|
|
|
|
|
|
| 235 |
if __name__ == "__main__":
|
| 236 |
+
uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
|
|
|
models.py
CHANGED
|
@@ -1,100 +1,64 @@
|
|
| 1 |
-
from pydantic import BaseModel, Field
|
| 2 |
from typing import List, Optional, Dict, Any
|
| 3 |
from datetime import datetime
|
| 4 |
-
from enum import Enum
|
| 5 |
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
CLOTHING = "clothing"
|
| 9 |
-
HOME = "home"
|
| 10 |
-
BEAUTY = "beauty"
|
| 11 |
-
SPORTS = "sports"
|
| 12 |
-
BOOKS = "books"
|
| 13 |
-
OTHER = "other"
|
| 14 |
-
|
| 15 |
-
class ProductCreate(BaseModel):
|
| 16 |
-
name: str
|
| 17 |
-
description: str = ""
|
| 18 |
-
price: float = 0.0
|
| 19 |
-
category: ProductCategory = ProductCategory.OTHER
|
| 20 |
-
in_stock: bool = True
|
| 21 |
-
tags: List[str] = []
|
| 22 |
-
features: List[str] = []
|
| 23 |
-
image_url: Optional[str] = None
|
| 24 |
-
|
| 25 |
-
@validator('price', pre=True)
|
| 26 |
-
def validate_price(cls, v):
|
| 27 |
-
if v is None:
|
| 28 |
-
return 0.0
|
| 29 |
-
try:
|
| 30 |
-
return float(v)
|
| 31 |
-
except (TypeError, ValueError):
|
| 32 |
-
return 0.0
|
| 33 |
-
|
| 34 |
-
@validator('category', pre=True)
|
| 35 |
-
def validate_category(cls, v):
|
| 36 |
-
if isinstance(v, ProductCategory):
|
| 37 |
-
return v
|
| 38 |
-
try:
|
| 39 |
-
return ProductCategory(v)
|
| 40 |
-
except ValueError:
|
| 41 |
-
return ProductCategory.OTHER
|
| 42 |
-
|
| 43 |
-
class Product(BaseModel):
|
| 44 |
id: str
|
| 45 |
name: str
|
| 46 |
-
|
| 47 |
-
price:
|
| 48 |
-
|
| 49 |
-
in_stock: bool = True
|
| 50 |
-
tags: List[str] = []
|
| 51 |
-
features: List[str] = []
|
| 52 |
-
image_url: Optional[str] = None
|
| 53 |
-
|
| 54 |
-
@validator('price', pre=True)
|
| 55 |
-
def validate_price(cls, v):
|
| 56 |
-
if v is None:
|
| 57 |
-
return 0.0
|
| 58 |
-
try:
|
| 59 |
-
return float(v)
|
| 60 |
-
except (TypeError, ValueError):
|
| 61 |
-
return 0.0
|
| 62 |
-
|
| 63 |
-
@validator('category', pre=True)
|
| 64 |
-
def validate_category(cls, v):
|
| 65 |
-
if isinstance(v, ProductCategory):
|
| 66 |
-
return v
|
| 67 |
-
try:
|
| 68 |
-
return ProductCategory(v)
|
| 69 |
-
except ValueError:
|
| 70 |
-
return ProductCategory.OTHER
|
| 71 |
-
|
| 72 |
-
class ChatMessage(BaseModel):
|
| 73 |
-
sender: str
|
| 74 |
-
text: str
|
| 75 |
-
timestamp: datetime = Field(default_factory=datetime.now)
|
| 76 |
|
|
|
|
| 77 |
class ChatRequest(BaseModel):
|
| 78 |
message: str
|
| 79 |
conversation_id: Optional[str] = None
|
| 80 |
-
user_id: Optional[str] = "anonymous"
|
| 81 |
|
| 82 |
class ChatResponse(BaseModel):
|
| 83 |
response: str
|
| 84 |
-
|
| 85 |
-
suggested_questions: List[str]
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 89 |
|
| 90 |
class SearchRequest(BaseModel):
|
| 91 |
query: str
|
| 92 |
-
category: Optional[
|
| 93 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
|
| 95 |
class Conversation(BaseModel):
|
| 96 |
id: str
|
| 97 |
user_id: str
|
| 98 |
-
messages: List[ChatMessage]
|
| 99 |
-
created_at: datetime
|
| 100 |
-
updated_at: datetime
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel, Field
|
| 2 |
from typing import List, Optional, Dict, Any
|
| 3 |
from datetime import datetime
|
|
|
|
| 4 |
|
| 5 |
+
# Source information model
|
| 6 |
+
class SourceInfo(BaseModel):
    """Product metadata cited as a source alongside a chat answer."""
    id: str
    name: str
    category: str
    price: str  # display string rather than a number — NOTE(review): confirm intended
    similarity_score: float  # relevance score from the vector search
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
+
# Chat request and response models
|
| 14 |
class ChatRequest(BaseModel):
    """Incoming chat message from the client."""
    message: str
    conversation_id: Optional[str] = None  # absent on the first message of a thread
|
|
|
|
| 17 |
|
| 18 |
class ChatResponse(BaseModel):
    """Assistant reply plus supporting metadata returned to the client."""
    response: str
    sources: List[SourceInfo]  # Changed from List[str] to List[SourceInfo]
    suggested_questions: List[str]  # follow-up prompts the UI can offer
    conversation_id: Optional[str] = None  # Make this optional
|
| 23 |
+
|
| 24 |
+
# Product models
|
| 25 |
+
class Product(BaseModel):
    """Catalog product as exposed by the public API."""
    id: str
    name: str
    category: str
    description: str
    price: float
    image_url: Optional[str] = None
    tags: List[str] = []  # pydantic deep-copies mutable defaults per instance
|
| 33 |
|
| 34 |
class SearchRequest(BaseModel):
    """Parameters for a product search query."""
    query: str
    category: Optional[str] = None  # restrict results to one category when set
    limit: int = 20
|
| 38 |
+
|
| 39 |
+
# Conversation models
|
| 40 |
+
class ChatMessage(BaseModel):
    """Single message within a stored conversation."""
    role: str  # "user" or "assistant"
    content: str
    timestamp: datetime  # NOTE(review): naive vs UTC-aware not enforced here — confirm at write site
|
| 44 |
|
| 45 |
class Conversation(BaseModel):
    """Persisted chat thread belonging to one user."""
    id: str
    user_id: str
    messages: List[ChatMessage]
    created_at: datetime
    updated_at: datetime
|
| 51 |
+
|
| 52 |
+
# Knowledge base models
|
| 53 |
+
class Document(BaseModel):
    """Raw document submitted to the knowledge base, prior to embedding."""
    content: str
    metadata: Dict[str, Any] = {}
    source: str = "upload"  # provenance label, e.g. "upload"
|
| 57 |
+
|
| 58 |
+
class KnowledgeDocument(BaseModel):
    """Embedded document as stored in the knowledge base collection."""
    id: str
    content: str
    embedding: List[float]  # dense vector produced by the embedding model
    metadata: Dict[str, Any]
    source: str
    created_at: datetime
|
rag_system.py
CHANGED
|
@@ -1,285 +1,162 @@
|
|
| 1 |
-
|
| 2 |
-
from database import db
|
| 3 |
-
from gemini_service import gemini_service
|
| 4 |
-
from models import Product, ProductCategory
|
| 5 |
-
import numpy as np
|
| 6 |
from sentence_transformers import SentenceTransformer
|
|
|
|
| 7 |
import asyncio
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
class
|
| 10 |
def __init__(self):
|
|
|
|
| 11 |
self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
def
|
| 14 |
-
"""
|
| 15 |
try:
|
| 16 |
-
#
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
| 18 |
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
"in_stock": True, # Default to True since we don't have this field
|
| 27 |
-
"tags": self._extract_tags(product_doc),
|
| 28 |
-
"features": self._extract_features(product_doc),
|
| 29 |
-
"image_url": self._get_first_image(product_doc.get('images', '')),
|
| 30 |
-
}
|
| 31 |
-
|
| 32 |
-
return transformed
|
| 33 |
|
|
|
|
| 34 |
except Exception as e:
|
| 35 |
-
print(f"Error
|
| 36 |
-
#
|
| 37 |
-
return
|
| 38 |
-
"id": str(product_doc.get('_id', 'unknown')),
|
| 39 |
-
"name": product_doc.get('title', 'Unnamed Product'),
|
| 40 |
-
"description": "Product information unavailable",
|
| 41 |
-
"price": 0.0,
|
| 42 |
-
"category": ProductCategory.OTHER,
|
| 43 |
-
"in_stock": True,
|
| 44 |
-
"tags": [],
|
| 45 |
-
"features": [],
|
| 46 |
-
}
|
| 47 |
-
|
| 48 |
-
def _map_category(self, raw_category: str) -> ProductCategory:
|
| 49 |
-
"""Map raw category string to ProductCategory enum"""
|
| 50 |
-
if not raw_category:
|
| 51 |
-
return ProductCategory.OTHER
|
| 52 |
-
|
| 53 |
-
category_lower = raw_category.lower()
|
| 54 |
-
|
| 55 |
-
# Map based on your actual category values
|
| 56 |
-
if any(keyword in category_lower for keyword in ['electronic', 'tech', 'computer', 'phone']):
|
| 57 |
-
return ProductCategory.ELECTRONICS
|
| 58 |
-
elif any(keyword in category_lower for keyword in ['cloth', 'fashion', 'wear', 'top', 'dress', 'shirt', 'jeans']):
|
| 59 |
-
return ProductCategory.CLOTHING
|
| 60 |
-
elif any(keyword in category_lower for keyword in ['home', 'garden', 'furniture', 'decor']):
|
| 61 |
-
return ProductCategory.HOME
|
| 62 |
-
elif any(keyword in category_lower for keyword in ['beauty', 'cosmetic', 'skin', 'hair', 'cream', 'mask', 'makeup']):
|
| 63 |
-
return ProductCategory.BEAUTY
|
| 64 |
-
elif any(keyword in category_lower for keyword in ['sport', 'fitness', 'exercise', 'gym']):
|
| 65 |
-
return ProductCategory.SPORTS
|
| 66 |
-
elif any(keyword in category_lower for keyword in ['book', 'literature']):
|
| 67 |
-
return ProductCategory.BOOKS
|
| 68 |
-
else:
|
| 69 |
-
return ProductCategory.OTHER
|
| 70 |
|
| 71 |
-
def
|
| 72 |
-
"""Extract
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
tags.append('Cotton')
|
| 84 |
-
if 'Polyester' in details:
|
| 85 |
-
tags.append('Polyester')
|
| 86 |
|
| 87 |
-
return
|
| 88 |
|
| 89 |
-
def
|
| 90 |
-
"""
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
return
|
| 120 |
|
| 121 |
-
def
|
| 122 |
-
"""
|
| 123 |
-
if not images_str:
|
| 124 |
-
return None
|
| 125 |
-
images = images_str.split(',')
|
| 126 |
-
return images[0].strip() if images else None
|
| 127 |
-
|
| 128 |
-
# ... rest of your existing methods remain the same
|
| 129 |
-
async def retrieve_relevant_products(self, query: str, category: Optional[str] = None, limit: int = 5) -> List[Dict]:
|
| 130 |
-
"""Retrieve and transform relevant products"""
|
| 131 |
try:
|
| 132 |
-
#
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
# Get all products or filtered by category
|
| 136 |
-
if category:
|
| 137 |
-
products = await db.get_products_by_category(category, limit=50)
|
| 138 |
-
else:
|
| 139 |
-
products = await db.get_all_products(limit=100)
|
| 140 |
|
| 141 |
-
#
|
| 142 |
-
|
| 143 |
|
| 144 |
-
#
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
product_text = f"{product.get('name', '')} {product.get('description', '')} {' '.join(product.get('tags', []))}"
|
| 148 |
-
product_embedding = await self._get_embedding(product_text)
|
| 149 |
-
|
| 150 |
-
similarity = self._cosine_similarity(query_embedding, product_embedding)
|
| 151 |
-
scored_products.append((product, similarity))
|
| 152 |
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
return [product for product, score in scored_products[:limit]]
|
| 156 |
|
| 157 |
except Exception as e:
|
| 158 |
-
print(f"
|
| 159 |
-
|
| 160 |
-
return
|
| 161 |
|
| 162 |
-
|
| 163 |
-
"""
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
# Transform all products first
|
| 172 |
-
transformed_products = [self._transform_product_doc(product) for product in products]
|
| 173 |
-
|
| 174 |
-
scored_products = []
|
| 175 |
-
for product in transformed_products:
|
| 176 |
-
score = 0
|
| 177 |
-
product_text = f"{product.get('name', '')} {product.get('description', '')} {' '.join(product.get('tags', []))}".lower()
|
| 178 |
-
|
| 179 |
-
for term in search_terms:
|
| 180 |
-
if term in product_text:
|
| 181 |
-
score += 1
|
| 182 |
-
if term in product.get('name', '').lower():
|
| 183 |
-
score += 2 # Higher weight for name matches
|
| 184 |
-
|
| 185 |
-
if score > 0:
|
| 186 |
-
scored_products.append((product, score))
|
| 187 |
-
|
| 188 |
-
scored_products.sort(key=lambda x: x[1], reverse=True)
|
| 189 |
-
return [product for product, score in scored_products[:limit]]
|
| 190 |
-
|
| 191 |
-
async def _get_embedding(self, text: str) -> List[float]:
|
| 192 |
-
"""Get embedding for text using available methods"""
|
| 193 |
-
try:
|
| 194 |
-
embedding = await gemini_service.generate_embedding(text)
|
| 195 |
-
if embedding and len(embedding) > 10:
|
| 196 |
-
return embedding
|
| 197 |
-
except:
|
| 198 |
-
pass
|
| 199 |
-
|
| 200 |
-
embedding = self.embedding_model.encode(text)
|
| 201 |
-
return embedding.tolist()
|
| 202 |
-
|
| 203 |
-
def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
|
| 204 |
-
"""Calculate cosine similarity between two vectors"""
|
| 205 |
-
vec1 = np.array(vec1)
|
| 206 |
-
vec2 = np.array(vec2)
|
| 207 |
-
|
| 208 |
-
dot_product = np.dot(vec1, vec2)
|
| 209 |
-
norm1 = np.linalg.norm(vec1)
|
| 210 |
-
norm2 = np.linalg.norm(vec2)
|
| 211 |
-
|
| 212 |
-
if norm1 == 0 or norm2 == 0:
|
| 213 |
-
return 0.0
|
| 214 |
-
|
| 215 |
-
return dot_product / (norm1 * norm2)
|
| 216 |
-
|
| 217 |
-
async def build_context(self, query: str, relevant_products: List[Dict]) -> str:
|
| 218 |
-
"""Build context string from relevant products for the LLM"""
|
| 219 |
-
if not relevant_products:
|
| 220 |
-
return "No specific product information available. Use general knowledge about e-commerce and shopping."
|
| 221 |
-
|
| 222 |
-
context_parts = ["Relevant Products Information:"]
|
| 223 |
-
|
| 224 |
-
for i, product in enumerate(relevant_products, 1):
|
| 225 |
-
context_parts.append(f"""
|
| 226 |
-
Product {i}:
|
| 227 |
-
- Name: {product.get('name', 'N/A')}
|
| 228 |
-
- Description: {product.get('description', 'N/A')}
|
| 229 |
-
- Price: ${product.get('price', 'N/A')}
|
| 230 |
-
- Category: {product.get('category', 'N/A')}
|
| 231 |
-
- In Stock: {'Yes' if product.get('in_stock') else 'No'}
|
| 232 |
-
- Features: {', '.join(product.get('features', []))}
|
| 233 |
-
- Tags: {', '.join(product.get('tags', []))}
|
| 234 |
-
""")
|
| 235 |
-
|
| 236 |
-
return "\n".join(context_parts)
|
| 237 |
-
|
| 238 |
-
async def generate_chat_response(self, user_message: str, conversation_history: List[Dict] = None) -> Dict[str, Any]:
|
| 239 |
-
"""Generate response using RAG pipeline"""
|
| 240 |
-
# Classify intent
|
| 241 |
-
intent_result = await gemini_service.classify_intent(user_message)
|
| 242 |
-
|
| 243 |
-
# Retrieve relevant products (already transformed)
|
| 244 |
-
relevant_products = await self.retrieve_relevant_products(user_message, limit=3)
|
| 245 |
-
|
| 246 |
-
# Build context
|
| 247 |
-
context = await self.build_context(user_message, relevant_products)
|
| 248 |
-
|
| 249 |
-
# Generate response using Gemini with context
|
| 250 |
-
response = await gemini_service.generate_response(user_message, context)
|
| 251 |
-
|
| 252 |
-
# Generate suggested questions based on intent
|
| 253 |
-
suggested_questions = await self._generate_suggested_questions(intent_result['intent'], relevant_products)
|
| 254 |
-
|
| 255 |
-
return {
|
| 256 |
-
"response": response,
|
| 257 |
-
"relevant_products": relevant_products,
|
| 258 |
-
"suggested_questions": suggested_questions,
|
| 259 |
-
"intent": intent_result['intent'],
|
| 260 |
-
"confidence": intent_result['confidence']
|
| 261 |
-
}
|
| 262 |
-
|
| 263 |
-
async def _generate_suggested_questions(self, intent: str, products: List[Dict]) -> List[str]:
|
| 264 |
-
"""Generate context-aware suggested questions"""
|
| 265 |
-
base_questions = {
|
| 266 |
-
"product_inquiry": ["Show me more products", "What are your best sellers?", "Any current deals?"],
|
| 267 |
-
"pricing": ["Do you offer discounts?", "What's the return policy?", "Any bundle deals?"],
|
| 268 |
-
"shipping": ["How long does delivery take?", "Do you ship internationally?", "What are shipping costs?"],
|
| 269 |
-
"returns": ["How do I return an item?", "What's your warranty policy?", "Do you offer exchanges?"],
|
| 270 |
-
"support": ["Contact customer service", "Store locations", "Business hours"],
|
| 271 |
-
"default": ["Best sellers", "Current deals", "Shipping info", "Return policy"]
|
| 272 |
-
}
|
| 273 |
-
|
| 274 |
-
questions = base_questions.get(intent, base_questions["default"])
|
| 275 |
-
|
| 276 |
-
# Add product-specific questions if we have products
|
| 277 |
-
if products and intent == "product_inquiry":
|
| 278 |
-
categories = list(set(str(p.get('category')) for p in products))
|
| 279 |
-
if categories:
|
| 280 |
-
questions = [f"Show me more {cat} products" for cat in categories[:2]] + questions
|
| 281 |
-
|
| 282 |
-
return questions[:4]
|
| 283 |
|
| 284 |
-
#
|
| 285 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import google.generativeai as genai
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
from sentence_transformers import SentenceTransformer
|
| 3 |
+
from typing import List, Tuple, Dict, Any
|
| 4 |
import asyncio
|
| 5 |
+
from database import db
|
| 6 |
+
from config import settings
|
| 7 |
|
| 8 |
+
class ProductRAGPipeline:
    """RAG pipeline answering shopping queries from the product catalog.

    Retrieval: sentence-transformer embeddings + MongoDB vector search, with
    a keyword-based category fallback. Generation: Gemini, prompted with a
    shopping-assistant persona plus the retrieved product context.
    """

    def __init__(self):
        # Embedding model used for query vectors (must match the dimension of
        # the stored product embeddings).
        self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')

        # Configure the Gemini client once per pipeline instance.
        genai.configure(api_key=settings.GEMINI_API_KEY)
        self.gemini_model = genai.GenerativeModel('gemini-2.5-flash')

        # Persona text prepended to every generation prompt.
        self.personality_traits = """
        You are a friendly, knowledgeable shopping assistant for a fashion e-commerce store. Your personality traits:
        - Warm, approachable, and enthusiastic about fashion
        - Helpful and patient with customer queries
        - Knowledgeable about products, styles, and fashion trends
        - Casual but professional tone, like a friendly store assistant
        - Use emojis occasionally to express emotion (but don't overdo it)
        - Ask follow-up questions to better understand customer needs
        - Be concise but thorough in product recommendations
        - Always mention key product features, price, and benefits
        - If you suggest multiple products, compare them briefly
        """

    async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Embed ``texts`` without blocking the event loop.

        Returns plain ``list[list[float]]`` (JSON/Mongo friendly).
        """
        # Fix: get_running_loop() instead of the deprecated get_event_loop()
        # inside a coroutine. encode() is CPU-bound, so it runs in an executor.
        loop = asyncio.get_running_loop()
        embeddings = await loop.run_in_executor(
            None, self.embedding_model.encode, texts
        )
        return embeddings.tolist()

    async def retrieve_relevant_products(self, query: str, limit: int = 3) -> List[Dict]:
        """Retrieve relevant products via vector search, with fallbacks.

        Order: vector similarity search -> category keyword search -> a
        generic category search if everything else raises.
        """
        try:
            # Primary path: semantic similarity over stored embeddings.
            query_embedding = await self.get_embeddings([query])
            print(f"🔍 Performing vector search with embedding dim: {len(query_embedding[0])}")
            relevant_docs = await db.similarity_search(query_embedding[0], limit=limit)
            print(f"✅ Vector search returned {len(relevant_docs)} results")

            if not relevant_docs:
                print("🔄 No results from vector search, trying category-based search")
                # Fallback: infer a category from the query wording.
                category_keywords = self._extract_category_from_query(query)
                if category_keywords:
                    relevant_docs = await db.search_by_category(category_keywords[0], limit=limit)
                    print(f"✅ Category search returned {len(relevant_docs)} results")

            return relevant_docs
        except Exception as e:
            print(f"❌ Error in vector search: {e}")
            # Last-resort fallback so the chat endpoint still gets products.
            # NOTE(review): hard-codes "tops" — confirm this default is intended.
            return await db.search_by_category("tops", limit=limit)

    def _extract_category_from_query(self, query: str) -> List[str]:
        """Map free-text query wording to catalog category names.

        Returns every matching category (insertion order of the mapping);
        defaults to ``['tops']`` when nothing matches.
        """
        query_lower = query.lower()
        categories = []

        category_mapping = {
            'tops': ['top', 'shirt', 'blouse', 't-shirt', 'tshirt', 'crop top', 'spaghetti'],
            'bottoms': ['pant', 'jeans', 'trouser', 'leggings', 'skirt', 'short'],
            'dresses': ['dress', 'gown', 'frock'],
            'outerwear': ['jacket', 'sweater', 'hoodie', 'cardigan', 'coat'],
            'accessories': ['bag', 'jewelry', 'scarf', 'hat', 'belt']
        }

        # Substring match is intentional so "t-shirts" still hits "shirt".
        for category, keywords in category_mapping.items():
            if any(keyword in query_lower for keyword in keywords):
                categories.append(category)

        return categories if categories else ['tops']  # Default to tops

    def create_product_prompt(self, query: str, products: List[Dict]) -> str:
        """Build the Gemini prompt: persona + product context + user query."""
        if products:
            context = "AVAILABLE PRODUCTS:\n"
            for i, product in enumerate(products, 1):
                # Each retrieved doc exposes its display text under "content".
                context += f"{i}. {product['content']}\n"
        else:
            context = "No specific product information available at the moment."

        prompt = f"""
        {self.personality_traits}

        {context}

        USER QUESTION: {query}

        INSTRUCTIONS:
        1. Answer based primarily on the provided product information
        2. If suggesting products, mention:
           - Key features and benefits
           - Price (if available)
           - Why it might suit the user's needs
        3. Be conversational and helpful
        4. If the exact answer isn't in the products, use your general knowledge but be honest about limitations
        5. Keep responses concise but complete (2-4 sentences usually)
        6. Always maintain a friendly, shopping assistant tone
        7. If multiple products are relevant, compare them briefly

        SHOPPING ASSISTANT RESPONSE:
        """
        return prompt

    async def generate_response(self, query: str) -> Tuple[str, List[Dict]]:
        """Run the full RAG pipeline: retrieve, prompt, generate.

        Returns ``(response_text, relevant_products)``; on any failure,
        returns a friendly fallback message and an empty product list.
        """
        try:
            relevant_products = await self.retrieve_relevant_products(query)
            print(f"📦 Retrieved {len(relevant_products)} relevant products")

            prompt = self.create_product_prompt(query, relevant_products)

            # Fix: generate_content() is a blocking network call; run it in an
            # executor so the event loop stays responsive under load.
            loop = asyncio.get_running_loop()
            response = await loop.run_in_executor(
                None, self.gemini_model.generate_content, prompt
            )
            response_text = response.text.strip()

            print(f"🤖 Generated response: {response_text[:100]}...")
            return response_text, relevant_products

        except Exception as e:
            print(f"❌ Error generating response: {e}")
            fallback_msg = "I apologize, but I'm having trouble accessing our product information right now. Please try again in a moment or contact our customer service for immediate assistance. 😊"
            return fallback_msg, []

    def generate_followup_questions(self, query: str, products: List[Dict]) -> List[str]:
        """Generate context-aware follow-up questions (at most 5).

        Bug fix: the original appended context-specific questions AFTER five
        base questions and then sliced ``[:5]``, so the contextual questions
        could never be returned. Contextual questions now come first.
        """
        base_questions = [
            "Tell me more about this product",
            "What are the alternatives in different colors?",
            "Do you have similar items in different price ranges?",
            "What's the sizing like for these products?",
            "Are any of these currently on sale?"
        ]

        contextual: List[str] = []
        query_lower = query.lower()
        if any(word in query_lower for word in ['price', 'cost', 'expensive', 'cheap']):
            contextual.extend([
                "What's the price range for similar items?",
                "Are there any ongoing discounts?"
            ])

        if any(word in query_lower for word in ['color', 'colour', 'pattern']):
            contextual.extend([
                "What other colors are available?",
                "Do you have this in solid colors vs patterns?"
            ])

        # Contextual questions take priority over the generic ones.
        return (contextual + base_questions)[:5]
|
| 160 |
+
|
| 161 |
+
# Global RAG pipeline instance
# Created at import time so all request handlers share one embedding model
# and one Gemini client (both are expensive to construct).
rag_pipeline = ProductRAGPipeline()
|
run.py
CHANGED
|
@@ -1,11 +1,5 @@
|
|
|
|
|
| 1 |
import uvicorn
|
| 2 |
-
from config import settings
|
| 3 |
|
| 4 |
if __name__ == "__main__":
|
| 5 |
-
uvicorn.run(
|
| 6 |
-
"main:app",
|
| 7 |
-
host="0.0.0.0",
|
| 8 |
-
port=settings.port,
|
| 9 |
-
reload=True, # Enable auto-reload during development
|
| 10 |
-
log_level="info"
|
| 11 |
-
)
|
|
|
|
| 1 |
+
from main import app
import uvicorn

if __name__ == "__main__":
    # Fix: the import string must match the real module path. The app lives
    # in main.py (imported above as "main"), so "app.main:app" would make
    # uvicorn's reloader fail to import the application.
    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|