Crcs1225 commited on
Commit
1333c38
·
1 Parent(s): ccabb90

new rag llm

Browse files
Files changed (7) hide show
  1. config.py +14 -9
  2. database.py +80 -109
  3. gemini_service.py +0 -106
  4. main.py +193 -174
  5. models.py +45 -81
  6. rag_system.py +139 -262
  7. run.py +2 -8
config.py CHANGED
@@ -2,15 +2,20 @@ from pydantic_settings import BaseSettings
2
  from typing import Optional
3
 
4
  class Settings(BaseSettings):
5
- mongodb_atlas_uri: str
6
- gemini_api_key: str
7
- database_name: str = "product"
8
- products_collection: str = "marketplace"
9
- conversations_collection: str = "conversations"
10
- embeddings_collection: str = "embeddings"
11
- port: int = 7860
 
 
 
 
12
 
13
  class Config:
14
- env_file = None
 
 
15
 
16
- settings = Settings()
 
2
  from typing import Optional
3
 
4
  class Settings(BaseSettings):
5
+ # MongoDB Atlas
6
+ MONGODB_URI: str
7
+ DATABASE_NAME: str = "marketplace"
8
+ COLLECTION_NAME: str = "product"
9
+
10
+ # Gemini API
11
+ GEMINI_API_KEY: str
12
+
13
+ # Server
14
+ HOST: str = "0.0.0.0"
15
+ PORT: int = 7860
16
 
17
  class Config:
18
+ env_file = ".env"
19
+
20
+ settings = Settings()
21
 
 
database.py CHANGED
@@ -1,117 +1,88 @@
1
- from datetime import datetime
2
  import motor.motor_asyncio
3
  from bson import ObjectId
4
- from typing import List, Optional, Dict, Any
 
5
  from config import settings
6
- import json
7
 
8
- class MongoDB:
9
  def __init__(self):
10
- self.client = motor.motor_asyncio.AsyncIOMotorClient(settings.mongodb_atlas_uri)
11
- self.db = self.client[settings.database_name]
12
- self.products = self.db[settings.products_collection]
13
- self.conversations = self.db[settings.conversations_collection]
14
- self.embeddings = self.db[settings.embeddings_collection]
15
-
16
- async def get_product(self, product_id: str) -> Optional[Dict]:
17
- return await self.products.find_one({"_id": ObjectId(product_id)})
18
-
19
- async def get_products_by_category(self, category: str, limit: int = 10) -> List[Dict]:
20
- # Your database might have different field names for category
21
- cursor = self.products.find({"category": category}).limit(limit)
22
- return await cursor.to_list(length=limit)
23
-
24
- async def search_products(self, query: Dict, limit: int = 10) -> List[Dict]:
25
- cursor = self.products.find(query).limit(limit)
26
- return await cursor.to_list(length=limit)
27
-
28
- async def get_all_products(self, limit: int = 50) -> List[Dict]:
29
- cursor = self.products.find().limit(limit)
30
- return await cursor.to_list(length=limit)
31
-
32
- async def create_product(self, product_data: Dict) -> str:
33
- result = await self.products.insert_one(product_data)
34
- return str(result.inserted_id)
35
-
36
- async def update_product(self, product_id: str, update_data: Dict) -> bool:
37
- result = await self.products.update_one(
38
- {"_id": ObjectId(product_id)},
39
- {"$set": update_data}
40
- )
41
- return result.modified_count > 0
42
-
43
- async def delete_product(self, product_id: str) -> bool:
44
- result = await self.products.delete_one({"_id": ObjectId(product_id)})
45
- return result.deleted_count > 0
46
-
47
- async def get_conversation(self, conversation_id: str) -> Optional[Dict]:
48
- return await self.conversations.find_one({"_id": ObjectId(conversation_id)})
49
-
50
- async def create_conversation(self, user_id: str) -> str:
51
- conversation_data = {
52
- "user_id": user_id,
53
- "messages": [],
54
- "created_at": datetime.now(),
55
- "updated_at": datetime.now()
56
- }
57
- result = await self.conversations.insert_one(conversation_data)
58
- return str(result.inserted_id)
59
-
60
- async def add_message_to_conversation(self, conversation_id: str, message: Dict) -> bool:
61
- result = await self.conversations.update_one(
62
- {"_id": ObjectId(conversation_id)},
63
- {
64
- "$push": {"messages": message},
65
- "$set": {"updated_at": datetime.now()}
66
- }
67
- )
68
- return result.modified_count > 0
69
-
70
- async def get_user_conversations(self, user_id: str, limit: int = 10) -> List[Dict]:
71
- cursor = self.conversations.find({"user_id": user_id}).sort("updated_at", -1).limit(limit)
72
- return await cursor.to_list(length=limit)
73
-
74
- async def store_embedding(self, text: str, embedding: List[float], metadata: Dict) -> str:
75
- doc = {
76
- "text": text,
77
- "embedding": embedding,
78
- "metadata": metadata,
79
- "created_at": datetime.now()
80
- }
81
- result = await self.embeddings.insert_one(doc)
82
- return str(result.inserted_id)
83
-
84
- async def find_similar_embeddings(self, embedding: List[float], limit: int = 5) -> List[Dict]:
85
- # This is a simplified version - in production, you'd use vector search
86
- pipeline = [
87
- {
88
- "$addFields": {
89
- "similarity": {
90
- "$sqrt": {
91
- "$sum": {
92
- "$map": {
93
- "input": {"$range": [0, {"$size": "$embedding"}]},
94
- "as": "idx",
95
- "in": {
96
- "$pow": [
97
- {"$subtract": [
98
- {"$arrayElemAt": ["$embedding", "$$idx"]},
99
- {"$arrayElemAt": [embedding, "$$idx"]}
100
- ]},
101
- 2
102
- ]
103
- }
104
- }
105
- }
106
- }
107
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  }
109
- },
110
- {"$sort": {"similarity": 1}},
111
- {"$limit": limit}
112
- ]
113
- cursor = self.embeddings.aggregate(pipeline)
114
- return await cursor.to_list(length=limit)
 
115
 
116
- # Database instance
117
- db = MongoDB()
 
 
1
  import motor.motor_asyncio
2
  from bson import ObjectId
3
+ from typing import List, Dict, Any
4
+ import numpy as np
5
  from config import settings
 
6
 
7
+ class Database:
8
  def __init__(self):
9
+ self.client = None
10
+ self.db = None
11
+ self.collection = None
12
+
13
+ async def connect(self):
14
+ self.client = motor.motor_asyncio.AsyncIOMotorClient(settings.MONGODB_URI)
15
+ self.db = self.client[settings.DATABASE_NAME]
16
+ self.collection = self.db[settings.COLLECTION_NAME]
17
+
18
+ async def similarity_search(self, query_embedding: List[float], limit: int = 3) -> List[Dict]:
19
+ """Search for similar products using vector similarity"""
20
+ try:
21
+ # First try vector search if index exists
22
+ pipeline = [
23
+ {
24
+ "$vectorSearch": {
25
+ "index": "vector_index", # Your vector index name
26
+ "path": "embedding",
27
+ "queryVector": query_embedding,
28
+ "numCandidates": 100,
29
+ "limit": limit
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  }
31
+ },
32
+ {
33
+ "$project": {
34
+ "_id": 1,
35
+ "title": 1,
36
+ "category": 1,
37
+ "product_description": 1,
38
+ "final_price": 1,
39
+ "score": {"$meta": "vectorSearchScore"}
40
+ }
41
+ }
42
+ ]
43
+
44
+ cursor = self.collection.aggregate(pipeline)
45
+ results = []
46
+ async for doc in cursor:
47
+ results.append({
48
+ "id": str(doc["_id"]),
49
+ "content": f"Product: {doc.get('title', 'N/A')}. Description: {doc.get('product_description', 'N/A')}. Category: {doc.get('category', 'N/A')}. Price: {doc.get('final_price', 'N/A')}.",
50
+ "source": doc.get('title', 'product_database'),
51
+ "metadata": {
52
+ "category": doc.get('category', 'N/A'),
53
+ "price": doc.get('final_price', 'N/A'),
54
+ "similarity_score": doc.get('score', 0)
55
+ }
56
+ })
57
+ return results
58
+ except Exception as e:
59
+ print(f"Vector search failed, falling back to text search: {e}")
60
+ # Fallback to text search if vector search fails
61
+ return await self.search_by_category("tops", limit)
62
+
63
+ async def search_by_category(self, category: str, limit: int = 5) -> List[Dict]:
64
+ """Search products by category (fallback if vector search fails)"""
65
+ cursor = self.collection.find(
66
+ {"category": {"$regex": category, "$options": "i"}}
67
+ ).limit(limit)
68
+
69
+ results = []
70
+ async for doc in cursor:
71
+ results.append({
72
+ "id": str(doc["_id"]),
73
+ "content": f"Product: {doc.get('title', 'N/A')}. Description: {doc.get('product_description', 'N/A')}. Category: {doc.get('category', 'N/A')}. Price: {doc.get('final_price', 'N/A')}.",
74
+ "source": doc.get('title', 'product_database'),
75
+ "metadata": {
76
+ "category": doc.get('category', 'N/A'),
77
+ "price": doc.get('final_price', 'N/A')
78
  }
79
+ })
80
+ return results
81
+
82
+ async def insert_documents(self, documents: List[Dict]) -> List[str]:
83
+ """Insert documents into the collection"""
84
+ result = await self.collection.insert_many(documents)
85
+ return [str(id) for id in result.inserted_ids]
86
 
87
+ # Global database instance
88
+ db = Database()
gemini_service.py DELETED
@@ -1,106 +0,0 @@
1
- import google.generativeai as genai
2
- from typing import List, Dict, Any, Optional
3
- import asyncio
4
- import aiohttp
5
- import json
6
- from config import settings
7
-
8
- class GeminiService:
9
- def __init__(self):
10
- genai.configure(api_key=settings.gemini_api_key)
11
- self.model = genai.GenerativeModel('gemini-2.5-flash')
12
-
13
- async def generate_response(self, prompt: str, context: str = "") -> str:
14
- """Generate response using Gemini API with context"""
15
- try:
16
- full_prompt = f"""
17
- Context Information:
18
- {context}
19
-
20
- User Question: {prompt}
21
-
22
- You are a helpful shopping assistant for Daddy's Shop. Use the context information above to answer the user's question accurately and helpfully. If the context doesn't contain relevant information, use your general knowledge but be honest about limitations.
23
-
24
- Provide a friendly, professional response focused on helping with shopping needs.
25
- """
26
-
27
- # Run in thread pool since Gemini doesn't have native async support
28
- loop = asyncio.get_event_loop()
29
- response = await loop.run_in_executor(
30
- None,
31
- lambda: self.model.generate_content(full_prompt)
32
- )
33
-
34
- return response.text
35
-
36
- except Exception as e:
37
- print(f"Gemini API error: {e}")
38
- return "I apologize, but I'm having trouble processing your request right now. Please try again later."
39
-
40
- async def generate_embedding(self, text: str) -> List[float]:
41
- """Generate embeddings for text using Gemini"""
42
- try:
43
- # Note: Gemini doesn't have direct embedding API, so we'll use a workaround
44
- # For production, consider using SentenceTransformers or another embedding service
45
- embedding_model = genai.GenerativeModel('embedding-001')
46
-
47
- loop = asyncio.get_event_loop()
48
- result = await loop.run_in_executor(
49
- None,
50
- lambda: genai.embed_content(
51
- model=embedding_model,
52
- content=text,
53
- task_type="retrieval_document"
54
- )
55
- )
56
-
57
- return result['embedding']
58
-
59
- except Exception as e:
60
- print(f"Embedding generation error: {e}")
61
- # Fallback to simple embedding (in production, use proper embedding model)
62
- return [0.0] * 384 # Default dimension
63
-
64
- async def classify_intent(self, message: str) -> Dict[str, Any]:
65
- """Classify user intent using Gemini"""
66
- prompt = f"""
67
- Classify the following user message into one of these intents:
68
- - product_inquiry: Questions about products, features, availability
69
- - pricing: Questions about costs, discounts, prices
70
- - shipping: Questions about delivery, shipping costs, timelines
71
- - returns: Questions about returns, refunds, exchanges
72
- - support: General customer support, contact information
73
- - greeting: Hello, hi, greetings
74
- - unknown: Cannot classify
75
-
76
- Message: "{message}"
77
-
78
- Return ONLY a JSON object with:
79
- {{
80
- "intent": "classified_intent",
81
- "confidence": 0.95,
82
- "entities": ["extracted_entities", "if_any"]
83
- }}
84
- """
85
-
86
- try:
87
- loop = asyncio.get_event_loop()
88
- response = await loop.run_in_executor(
89
- None,
90
- lambda: self.model.generate_content(prompt)
91
- )
92
-
93
- # Parse JSON response
94
- import re
95
- json_match = re.search(r'\{.*\}', response.text, re.DOTALL)
96
- if json_match:
97
- return json.loads(json_match.group())
98
- else:
99
- return {"intent": "unknown", "confidence": 0.0, "entities": []}
100
-
101
- except Exception as e:
102
- print(f"Intent classification error: {e}")
103
- return {"intent": "unknown", "confidence": 0.0, "entities": []}
104
-
105
- # Gemini service instance
106
- gemini_service = GeminiService()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
main.py CHANGED
@@ -1,217 +1,236 @@
1
- from fastapi import FastAPI, HTTPException, Depends
 
2
  from fastapi.middleware.cors import CORSMiddleware
3
- from typing import List, Optional
4
  import uuid
5
- from datetime import datetime
 
 
6
 
7
- from config import settings
8
  from database import db
9
- from rag_system import rag_system
10
- from models import (
11
- Product, ProductCreate, ChatRequest, ChatResponse,
12
- SearchRequest, Conversation, ChatMessage
13
- )
14
 
15
  app = FastAPI(
16
- title="Daddy's Shop RAG Chatbot API",
17
- description="AI-powered shopping assistant with RAG using MongoDB and Gemini",
18
  version="1.0.0"
19
  )
20
 
21
  # CORS middleware
22
  app.add_middleware(
23
  CORSMiddleware,
24
- allow_origins=["http://localhost:3000", "http://127.0.0.1:3000", "https://yourdomain.com"],
25
  allow_credentials=True,
26
  allow_methods=["*"],
27
  allow_headers=["*"],
28
  )
 
29
 
30
- @app.get("/")
31
- async def root():
32
- """Health check endpoint"""
33
- return {
34
- "message": "Daddy's Shop RAG Chatbot API is running!",
35
- "version": "1.0.0",
36
- "status": "healthy"
37
- }
38
 
39
- @app.post("/chat", response_model=ChatResponse)
40
- async def chat_endpoint(request: ChatRequest):
41
- """
42
- Main chatbot endpoint with RAG capabilities
43
- """
44
  try:
45
- # Create new conversation if no ID provided
46
- if not request.conversation_id:
47
- conversation_id = await db.create_conversation(request.user_id)
48
- else:
49
- conversation_id = request.conversation_id
50
- # Verify conversation exists
51
- conversation = await db.get_conversation(conversation_id)
52
- if not conversation:
53
- conversation_id = await db.create_conversation(request.user_id)
54
-
55
- # Add user message to conversation
56
- user_message = ChatMessage(
57
- sender="user",
58
- text=request.message
59
- )
60
- await db.add_message_to_conversation(conversation_id, user_message.dict())
61
 
62
- # Get conversation history for context
63
- conversation = await db.get_conversation(conversation_id)
64
- history = conversation.get('messages', []) if conversation else []
65
 
66
- # Generate response using RAG system
67
- rag_result = await rag_system.generate_chat_response(request.message, history)
 
68
 
69
- # Add bot response to conversation
70
- bot_message = ChatMessage(
71
- sender="bot",
72
- text=rag_result["response"]
73
- )
74
- await db.add_message_to_conversation(conversation_id, bot_message.dict())
75
 
76
- return ChatResponse(
77
- response=rag_result["response"],
78
- conversation_id=conversation_id,
79
- suggested_questions=rag_result["suggested_questions"],
80
- relevant_products=rag_result["relevant_products"],
81
- intent=rag_result["intent"],
82
- confidence=rag_result["confidence"]
83
- )
 
 
 
 
84
 
85
  except Exception as e:
86
- raise HTTPException(status_code=500, detail=f"Chat processing error: {str(e)}")
 
87
 
88
- @app.get("/products", response_model=List[Product])
89
- async def get_products(category: Optional[str] = None, limit: int = 20):
90
- """Get products with optional category filter"""
91
  try:
92
- if category:
93
- products = await db.get_products_by_category(category, limit)
94
- else:
95
- products = await db.get_all_products(limit)
96
- return products
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  except Exception as e:
98
- raise HTTPException(status_code=500, detail=f"Error fetching products: {str(e)}")
99
 
100
- @app.get("/products/{product_id}", response_model=Product)
101
- async def get_product(product_id: str):
102
- """Get specific product by ID"""
103
- product = await db.get_product(product_id)
104
- if not product:
105
- raise HTTPException(status_code=404, detail="Product not found")
106
- return product
107
 
108
- @app.post("/products", response_model=dict)
109
- async def create_product(product: ProductCreate):
110
- """Create a new product"""
111
- try:
112
- product_data = product.dict()
113
- product_data["_id"] = str(uuid.uuid4())[:8] # Simple ID generation
114
- product_id = await db.create_product(product_data)
115
- return {"message": "Product created successfully", "product_id": product_id}
116
- except Exception as e:
117
- raise HTTPException(status_code=500, detail=f"Error creating product: {str(e)}")
118
 
119
- @app.post("/search", response_model=List[Product])
120
- async def search_products(request: SearchRequest):
121
- """Semantic search for products"""
122
  try:
123
- products = await rag_system.retrieve_relevant_products(
124
- request.query,
125
- request.category,
126
- request.max_results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
127
  )
128
- return products
129
  except Exception as e:
130
- raise HTTPException(status_code=500, detail=f"Search error: {str(e)}")
131
-
132
- @app.get("/conversations/{user_id}", response_model=List[Conversation])
133
- async def get_user_conversations(user_id: str, limit: int = 10):
134
- """Get user's conversation history"""
 
135
  try:
136
- conversations = await db.get_user_conversations(user_id, limit)
137
- return conversations
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  except Exception as e:
139
- raise HTTPException(status_code=500, detail=f"Error fetching conversations: {str(e)}")
140
 
141
- @app.get("/intents")
142
- async def get_available_intents():
143
- """Get information about available intents"""
144
- return {
145
- "intents": [
146
- "product_inquiry", "pricing", "shipping",
147
- "returns", "support", "greeting", "unknown"
148
- ],
149
- "description": "Intent classification for user messages"
150
- }
151
- @app.get("/test-products")
152
- async def test_products(limit: int = 3):
153
- """Test endpoint to see transformed products"""
154
  try:
155
- products = await db.get_all_products(limit)
156
- transformed_products = [rag_system._transform_product_doc(product) for product in products]
 
 
 
 
 
 
 
 
 
 
 
 
157
  return {
158
- "original_count": len(products),
159
- "transformed_count": len(transformed_products),
160
- "original_sample": products[0] if products else {},
161
- "transformed_sample": transformed_products[0] if transformed_products else {},
162
- "all_transformed": transformed_products
163
  }
164
  except Exception as e:
165
- raise HTTPException(status_code=500, detail=f"Test error: {str(e)}")
166
- # Sample data initialization
167
- @app.on_event("startup")
168
- async def startup_event():
169
- """Initialize sample data if needed"""
170
- try:
171
- # Check if we have any products
172
- products = await db.get_all_products(1)
173
- if not products:
174
- await _initialize_sample_data()
175
- except Exception as e:
176
- print(f"Startup initialization error: {e}")
177
-
178
- async def _initialize_sample_data():
179
- """Initialize sample product data"""
180
- sample_products = [
181
- {
182
- "name": "Wireless Bluetooth Earbuds",
183
- "description": "High-quality wireless earbuds with noise cancellation and 24-hour battery life.",
184
- "price": 79.99,
185
- "category": "electronics",
186
- "in_stock": True,
187
- "tags": ["wireless", "bluetooth", "audio", "noise-cancellation"],
188
- "features": ["Noise Cancellation", "24h Battery", "Water Resistant", "Touch Controls"]
189
- },
190
- {
191
- "name": "Smart Fitness Watch",
192
- "description": "Advanced fitness tracker with heart rate monitoring, GPS, and smartphone connectivity.",
193
- "price": 199.99,
194
- "category": "electronics",
195
- "in_stock": True,
196
- "tags": ["fitness", "smartwatch", "health", "tracking"],
197
- "features": ["Heart Rate Monitor", "GPS", "Sleep Tracking", "Waterproof"]
198
- },
199
- {
200
- "name": "Organic Cotton T-Shirt",
201
- "description": "Comfortable and sustainable organic cotton t-shirt available in multiple colors.",
202
- "price": 24.99,
203
- "category": "clothing",
204
- "in_stock": True,
205
- "tags": ["cotton", "organic", "sustainable", "casual"],
206
- "features": ["Organic Cotton", "Machine Washable", "Multiple Colors"]
207
- }
208
- ]
209
-
210
- for product in sample_products:
211
- await db.create_product(product)
212
 
213
- print("Sample data initialized successfully")
214
-
215
  if __name__ == "__main__":
216
- import uvicorn
217
- uvicorn.run(app, host="0.0.0.0", port=settings.port)
 
1
+ import asyncio
2
+ from fastapi import FastAPI, HTTPException, Query
3
  from fastapi.middleware.cors import CORSMiddleware
4
+ from fastapi.responses import JSONResponse
5
  import uuid
6
+ import uvicorn
7
+ from typing import List, Optional
8
+ import traceback
9
 
10
+ # Import your existing modules
11
  from database import db
12
+ from models import ChatMessage, ChatRequest, ChatResponse, Product, SearchRequest, Conversation, KnowledgeDocument, Document, SourceInfo
13
+ from config import settings
14
+ from rag_system import rag_pipeline
 
 
15
 
16
  app = FastAPI(
17
+ title="RAG Chatbot API",
18
+ description="Lightweight RAG Chatbot using MongoDB Atlas and Gemini",
19
  version="1.0.0"
20
  )
21
 
22
  # CORS middleware
23
  app.add_middleware(
24
  CORSMiddleware,
25
+ allow_origins=["*"],
26
  allow_credentials=True,
27
  allow_methods=["*"],
28
  allow_headers=["*"],
29
  )
30
+ embeddings_generated = False
31
 
 
 
 
 
 
 
 
 
32
 
33
+ @app.on_event("startup")
34
+ async def startup_event():
35
+ """Run on application startup"""
36
+ global embeddings_generated
 
37
  try:
38
+ print("🚀 Starting RAG Chatbot API...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
+ # Initialize database connection
41
+ await db.connect()
 
42
 
43
+ # Check if we need to generate embeddings
44
+ total_docs = await db.collection.count_documents({})
45
+ docs_with_embeddings = await db.collection.count_documents({"embedding": {"$exists": True}})
46
 
47
+ print(f"📊 Database status: {total_docs} total documents, {docs_with_embeddings} with embeddings")
 
 
 
 
 
48
 
49
+ # If we have documents but no embeddings, generate them
50
+ if total_docs > 0 and docs_with_embeddings == 0:
51
+ print("🔄 No embeddings found. Starting automatic embedding generation...")
52
+ await generate_embeddings_on_startup()
53
+ embeddings_generated = True
54
+ elif docs_with_embeddings > 0:
55
+ print(f"✅ Embeddings already exist for {docs_with_embeddings} documents")
56
+ embeddings_generated = True
57
+ else:
58
+ print("ℹ️ No documents found in database")
59
+
60
+ print("✅ RAG Chatbot API is ready!")
61
 
62
  except Exception as e:
63
+ print(f" Startup error: {e}")
64
+ raise
65
 
66
+ async def generate_embeddings_on_startup():
67
+ """Generate embeddings for all documents on startup"""
 
68
  try:
69
+ # Find all documents without embeddings
70
+ cursor = db.collection.find({"embedding": {"$exists": False}})
71
+ documents_without_embeddings = []
72
+ async for doc in cursor:
73
+ documents_without_embeddings.append(doc)
74
+
75
+ if not documents_without_embeddings:
76
+ print("✅ All documents already have embeddings")
77
+ return
78
+
79
+ print(f"🔄 Generating embeddings for {len(documents_without_embeddings)} documents...")
80
+
81
+ updated_count = 0
82
+ errors = 0
83
+
84
+ # Process in smaller batches to avoid timeout
85
+ batch_size = 50
86
+ for i in range(0, len(documents_without_embeddings), batch_size):
87
+ batch = documents_without_embeddings[i:i + batch_size]
88
+
89
+ for doc in batch:
90
+ try:
91
+ # Create meaningful content for embedding
92
+ content_parts = []
93
+
94
+ # Include all relevant text fields
95
+ if doc.get('title'):
96
+ content_parts.append(f"Product: {doc['title']}")
97
+ if doc.get('product_description'):
98
+ content_parts.append(f"Description: {doc['product_description']}")
99
+ if doc.get('category'):
100
+ content_parts.append(f"Category: {doc['category']}")
101
+
102
+ content = ". ".join(content_parts)
103
+
104
+ if content.strip():
105
+ # Generate embedding
106
+ embedding = await rag_pipeline.get_embeddings([content])
107
+
108
+ # Update document with embedding
109
+ await db.collection.update_one(
110
+ {"_id": doc["_id"]},
111
+ {"$set": {"embedding": embedding[0]}}
112
+ )
113
+ updated_count += 1
114
+
115
+ # Progress update every 50 documents
116
+ if updated_count % 50 == 0:
117
+ print(f"✅ Processed {updated_count}/{len(documents_without_embeddings)} documents...")
118
+
119
+ except Exception as e:
120
+ errors += 1
121
+ print(f"❌ Error processing document: {e}")
122
+ continue
123
+
124
+ # Small delay between batches
125
+ await asyncio.sleep(1)
126
+
127
+ # Final status
128
+ final_with_embeddings = await db.collection.count_documents({"embedding": {"$exists": True}})
129
+ print(f"🎉 Embedding generation completed! {final_with_embeddings} documents now have embeddings")
130
+
131
  except Exception as e:
132
+ print(f" Embedding generation failed: {e}")
133
 
134
+ @app.get("/")
135
+ async def root():
136
+ return {"message": "RAG Chatbot API is running!", "status": "healthy"}
 
 
 
 
137
 
138
+ @app.get("/health")
139
+ async def health_check():
140
+ return {"status": "healthy", "service": "rag-chatbot"}
 
 
 
 
 
 
 
141
 
142
+ @app.post("/chat")
143
+ async def chat_with_assistant(request: ChatRequest):
144
+ """Main chat endpoint for product queries"""
145
  try:
146
+ print(f"💬 Received chat request: {request.message}")
147
+ response, sources = await rag_pipeline.generate_response(request.message)
148
+
149
+ suggested_questions = rag_pipeline.generate_followup_questions(
150
+ request.message,
151
+ sources
152
+ )
153
+
154
+ # Convert to SourceInfo objects
155
+ source_objects = []
156
+ for product in sources:
157
+ source_objects.append(SourceInfo(
158
+ id=product.get("id", ""),
159
+ name=product.get("source", "Product"),
160
+ category=product.get("metadata", {}).get("category", "N/A"),
161
+ price=str(product.get("metadata", {}).get("price", "N/A")),
162
+ similarity_score=product.get("metadata", {}).get("similarity_score", 0)
163
+ ))
164
+
165
+ return ChatResponse(
166
+ response=response,
167
+ sources=source_objects,
168
+ suggested_questions=suggested_questions,
169
+ conversation_id=request.conversation_id
170
  )
 
171
  except Exception as e:
172
+ print(f" Error in /chat endpoint: {traceback.format_exc()}")
173
+ raise HTTPException(status_code=500, detail=f"Error processing request: {str(e)}")
174
+
175
+ @app.get("/debug/vector-results")
176
+ async def debug_vector_results(query: str = "tops"):
177
+ """See exactly what vector search returns"""
178
  try:
179
+ # Get embeddings for the query
180
+ query_embedding = await rag_pipeline.get_embeddings([query])
181
+ print(f"🔍 Testing vector search for: '{query}'")
182
+ print(f"📐 Embedding dimensions: {len(query_embedding[0])}")
183
+
184
+ # Perform vector search
185
+ results = await db.similarity_search(query_embedding[0], limit=5)
186
+
187
+ response_data = {
188
+ "query": query,
189
+ "embedding_dimensions": len(query_embedding[0]),
190
+ "results_found": len(results),
191
+ "raw_results": []
192
+ }
193
+
194
+ for i, doc in enumerate(results):
195
+ response_data["raw_results"].append({
196
+ "rank": i + 1,
197
+ "id": doc["id"],
198
+ "content": doc["content"],
199
+ "source": doc.get("source", "unknown"),
200
+ "metadata": doc.get("metadata", {})
201
+ })
202
+ print(f"📄 Result {i+1}: {doc['content'][:100]}...")
203
+
204
+ return response_data
205
+
206
  except Exception as e:
207
+ return {"error": str(e), "traceback": traceback.format_exc()}
208
 
209
+ @app.get("/debug/sample-products")
210
+ async def debug_sample_products(category: str = "tops", limit: int = 5):
211
+ """Get sample products to see what content is available"""
 
 
 
 
 
 
 
 
 
 
212
  try:
213
+ cursor = db.collection.find({"category": {"$regex": category, "$options": "i"}}).limit(limit)
214
+ products = []
215
+ async for doc in cursor:
216
+ product_info = {
217
+ "id": str(doc["_id"]),
218
+ "name": doc.get("title", "N/A"),
219
+ "category": doc.get("category", "N/A"),
220
+ "description": doc.get("product_description", "N/A"),
221
+ "price": doc.get("final_price", "N/A"),
222
+ "has_embedding": "embedding" in doc,
223
+ "content_used_for_embedding": f"{doc.get('title', '')} {doc.get('product_description', '')} {doc.get('category', '')}"
224
+ }
225
+ products.append(product_info)
226
+
227
  return {
228
+ "category": category,
229
+ "products_found": len(products),
230
+ "products": products
 
 
231
  }
232
  except Exception as e:
233
+ return {"error": str(e)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
234
 
 
 
235
  if __name__ == "__main__":
236
+ uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)
 
models.py CHANGED
@@ -1,100 +1,64 @@
1
- from pydantic import BaseModel, Field, validator
2
  from typing import List, Optional, Dict, Any
3
  from datetime import datetime
4
- from enum import Enum
5
 
6
- class ProductCategory(str, Enum):
7
- ELECTRONICS = "electronics"
8
- CLOTHING = "clothing"
9
- HOME = "home"
10
- BEAUTY = "beauty"
11
- SPORTS = "sports"
12
- BOOKS = "books"
13
- OTHER = "other"
14
-
15
- class ProductCreate(BaseModel):
16
- name: str
17
- description: str = ""
18
- price: float = 0.0
19
- category: ProductCategory = ProductCategory.OTHER
20
- in_stock: bool = True
21
- tags: List[str] = []
22
- features: List[str] = []
23
- image_url: Optional[str] = None
24
-
25
- @validator('price', pre=True)
26
- def validate_price(cls, v):
27
- if v is None:
28
- return 0.0
29
- try:
30
- return float(v)
31
- except (TypeError, ValueError):
32
- return 0.0
33
-
34
- @validator('category', pre=True)
35
- def validate_category(cls, v):
36
- if isinstance(v, ProductCategory):
37
- return v
38
- try:
39
- return ProductCategory(v)
40
- except ValueError:
41
- return ProductCategory.OTHER
42
-
43
- class Product(BaseModel):
44
  id: str
45
  name: str
46
- description: str = ""
47
- price: float = 0.0
48
- category: ProductCategory = ProductCategory.OTHER
49
- in_stock: bool = True
50
- tags: List[str] = []
51
- features: List[str] = []
52
- image_url: Optional[str] = None
53
-
54
- @validator('price', pre=True)
55
- def validate_price(cls, v):
56
- if v is None:
57
- return 0.0
58
- try:
59
- return float(v)
60
- except (TypeError, ValueError):
61
- return 0.0
62
-
63
- @validator('category', pre=True)
64
- def validate_category(cls, v):
65
- if isinstance(v, ProductCategory):
66
- return v
67
- try:
68
- return ProductCategory(v)
69
- except ValueError:
70
- return ProductCategory.OTHER
71
-
72
- class ChatMessage(BaseModel):
73
- sender: str
74
- text: str
75
- timestamp: datetime = Field(default_factory=datetime.now)
76
 
 
77
  class ChatRequest(BaseModel):
78
  message: str
79
  conversation_id: Optional[str] = None
80
- user_id: Optional[str] = "anonymous"
81
 
82
  class ChatResponse(BaseModel):
83
  response: str
84
- conversation_id: str
85
- suggested_questions: List[str] = []
86
- relevant_products: List[Product] = []
87
- intent: str
88
- confidence: float
 
 
 
 
 
 
 
 
89
 
90
  class SearchRequest(BaseModel):
91
  query: str
92
- category: Optional[ProductCategory] = None
93
- max_results: int = 5
 
 
 
 
 
 
94
 
95
  class Conversation(BaseModel):
96
  id: str
97
  user_id: str
98
- messages: List[ChatMessage] = []
99
- created_at: datetime = Field(default_factory=datetime.now)
100
- updated_at: datetime = Field(default_factory=datetime.now)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field
2
  from typing import List, Optional, Dict, Any
3
  from datetime import datetime
 
4
 
5
+ # Source information model
6
+ class SourceInfo(BaseModel):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  id: str
8
  name: str
9
+ category: str
10
+ price: str
11
+ similarity_score: float
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
+ # Chat request and response models
14
  class ChatRequest(BaseModel):
15
  message: str
16
  conversation_id: Optional[str] = None
 
17
 
18
  class ChatResponse(BaseModel):
19
  response: str
20
+ sources: List[SourceInfo] # Changed from List[str] to List[SourceInfo]
21
+ suggested_questions: List[str]
22
+ conversation_id: Optional[str] = None # Make this optional
23
+
24
+ # Product models
25
+ class Product(BaseModel):
26
+ id: str
27
+ name: str
28
+ category: str
29
+ description: str
30
+ price: float
31
+ image_url: Optional[str] = None
32
+ tags: List[str] = []
33
 
34
  class SearchRequest(BaseModel):
35
  query: str
36
+ category: Optional[str] = None
37
+ limit: int = 20
38
+
39
+ # Conversation models
40
+ class ChatMessage(BaseModel):
41
+ role: str # "user" or "assistant"
42
+ content: str
43
+ timestamp: datetime
44
 
45
  class Conversation(BaseModel):
46
  id: str
47
  user_id: str
48
+ messages: List[ChatMessage]
49
+ created_at: datetime
50
+ updated_at: datetime
51
+
52
+ # Knowledge base models
53
+ class Document(BaseModel):
54
+ content: str
55
+ metadata: Dict[str, Any] = {}
56
+ source: str = "upload"
57
+
58
+ class KnowledgeDocument(BaseModel):
59
+ id: str
60
+ content: str
61
+ embedding: List[float]
62
+ metadata: Dict[str, Any]
63
+ source: str
64
+ created_at: datetime
rag_system.py CHANGED
@@ -1,285 +1,162 @@
1
- from typing import List, Dict, Any, Optional
2
- from database import db
3
- from gemini_service import gemini_service
4
- from models import Product, ProductCategory
5
- import numpy as np
6
  from sentence_transformers import SentenceTransformer
 
7
  import asyncio
 
 
8
 
9
- class RAGSystem:
10
  def __init__(self):
 
11
  self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- def _transform_product_doc(self, product_doc: Dict) -> Dict:
14
- """Transform MongoDB document to match Product model using actual field names"""
15
  try:
16
- # Extract fields from your actual database schema
17
- product_id = str(product_doc.get('_id', '')) or str(product_doc.get('product_id', ''))
 
 
 
18
 
19
- # Map to our expected fields
20
- transformed = {
21
- "id": product_id,
22
- "name": product_doc.get('title', 'Unnamed Product'),
23
- "description": product_doc.get('product_description', 'No description available'),
24
- "price": float(product_doc.get('final_price', 0.0)),
25
- "category": self._map_category(product_doc.get('category', 'other')),
26
- "in_stock": True, # Default to True since we don't have this field
27
- "tags": self._extract_tags(product_doc),
28
- "features": self._extract_features(product_doc),
29
- "image_url": self._get_first_image(product_doc.get('images', '')),
30
- }
31
-
32
- return transformed
33
 
 
34
  except Exception as e:
35
- print(f"Error transforming product {product_doc.get('_id')}: {e}")
36
- # Return a safe default product
37
- return {
38
- "id": str(product_doc.get('_id', 'unknown')),
39
- "name": product_doc.get('title', 'Unnamed Product'),
40
- "description": "Product information unavailable",
41
- "price": 0.0,
42
- "category": ProductCategory.OTHER,
43
- "in_stock": True,
44
- "tags": [],
45
- "features": [],
46
- }
47
-
48
- def _map_category(self, raw_category: str) -> ProductCategory:
49
- """Map raw category string to ProductCategory enum"""
50
- if not raw_category:
51
- return ProductCategory.OTHER
52
-
53
- category_lower = raw_category.lower()
54
-
55
- # Map based on your actual category values
56
- if any(keyword in category_lower for keyword in ['electronic', 'tech', 'computer', 'phone']):
57
- return ProductCategory.ELECTRONICS
58
- elif any(keyword in category_lower for keyword in ['cloth', 'fashion', 'wear', 'top', 'dress', 'shirt', 'jeans']):
59
- return ProductCategory.CLOTHING
60
- elif any(keyword in category_lower for keyword in ['home', 'garden', 'furniture', 'decor']):
61
- return ProductCategory.HOME
62
- elif any(keyword in category_lower for keyword in ['beauty', 'cosmetic', 'skin', 'hair', 'cream', 'mask', 'makeup']):
63
- return ProductCategory.BEAUTY
64
- elif any(keyword in category_lower for keyword in ['sport', 'fitness', 'exercise', 'gym']):
65
- return ProductCategory.SPORTS
66
- elif any(keyword in category_lower for keyword in ['book', 'literature']):
67
- return ProductCategory.BOOKS
68
- else:
69
- return ProductCategory.OTHER
70
 
71
- def _extract_tags(self, product_doc: Dict) -> List[str]:
72
- """Extract tags from product details"""
73
- tags = []
74
-
75
- # Add category as a tag
76
- category = product_doc.get('category')
77
- if category:
78
- tags.append(category)
 
 
 
 
79
 
80
- # Extract from product details if available
81
- details = product_doc.get('product_details', '')
82
- if 'Cotton' in details:
83
- tags.append('Cotton')
84
- if 'Polyester' in details:
85
- tags.append('Polyester')
86
 
87
- return tags
88
 
89
- def _extract_features(self, product_doc: Dict) -> List[str]:
90
- """Extract features from product details"""
91
- features = []
92
-
93
- # Extract key features from product details
94
- details = product_doc.get('product_details', '')
95
-
96
- # Look for common features
97
- feature_keywords = [
98
- 'Machine wash', 'Hand wash', 'Dry clean', 'Cotton',
99
- 'Polyester', 'Elastane', 'Spaghetti', 'Sleeveless',
100
- 'Solid pattern', 'Sweetheart neck'
101
- ]
102
-
103
- for keyword in feature_keywords:
104
- if keyword in details:
105
- features.append(keyword)
106
-
107
- # Add rating as a feature
108
- rating = product_doc.get('rating')
109
- if rating:
110
- features.append(f'{rating}★ Rating')
111
-
112
- # Add delivery options
113
- delivery = product_doc.get('delivery_options', '')
114
- if 'Pay on delivery' in delivery:
115
- features.append('Pay on Delivery')
116
- if 'Easy returns' in delivery:
117
- features.append('Easy Returns')
118
-
119
- return features
120
 
121
- def _get_first_image(self, images_str: str) -> Optional[str]:
122
- """Extract first image URL from comma-separated string"""
123
- if not images_str:
124
- return None
125
- images = images_str.split(',')
126
- return images[0].strip() if images else None
127
-
128
- # ... rest of your existing methods remain the same
129
- async def retrieve_relevant_products(self, query: str, category: Optional[str] = None, limit: int = 5) -> List[Dict]:
130
- """Retrieve and transform relevant products"""
131
  try:
132
- # Generate query embedding
133
- query_embedding = await self._get_embedding(query)
134
-
135
- # Get all products or filtered by category
136
- if category:
137
- products = await db.get_products_by_category(category, limit=50)
138
- else:
139
- products = await db.get_all_products(limit=100)
140
 
141
- # Transform products first
142
- transformed_products = [self._transform_product_doc(product) for product in products]
143
 
144
- # Calculate similarity scores on transformed products
145
- scored_products = []
146
- for product in transformed_products:
147
- product_text = f"{product.get('name', '')} {product.get('description', '')} {' '.join(product.get('tags', []))}"
148
- product_embedding = await self._get_embedding(product_text)
149
-
150
- similarity = self._cosine_similarity(query_embedding, product_embedding)
151
- scored_products.append((product, similarity))
152
 
153
- # Sort by similarity and return top results
154
- scored_products.sort(key=lambda x: x[1], reverse=True)
155
- return [product for product, score in scored_products[:limit]]
156
 
157
  except Exception as e:
158
- print(f"Product retrieval error: {e}")
159
- # Fallback to simple keyword search with transformation
160
- return await self._keyword_search_products(query, category, limit)
161
 
162
- async def _keyword_search_products(self, query: str, category: Optional[str], limit: int) -> List[Dict]:
163
- """Fallback keyword-based product search with transformation"""
164
- search_terms = query.lower().split()
165
-
166
- if category:
167
- products = await db.get_products_by_category(category, limit=50)
168
- else:
169
- products = await db.get_all_products(limit=100)
170
-
171
- # Transform all products first
172
- transformed_products = [self._transform_product_doc(product) for product in products]
173
-
174
- scored_products = []
175
- for product in transformed_products:
176
- score = 0
177
- product_text = f"{product.get('name', '')} {product.get('description', '')} {' '.join(product.get('tags', []))}".lower()
178
-
179
- for term in search_terms:
180
- if term in product_text:
181
- score += 1
182
- if term in product.get('name', '').lower():
183
- score += 2 # Higher weight for name matches
184
-
185
- if score > 0:
186
- scored_products.append((product, score))
187
-
188
- scored_products.sort(key=lambda x: x[1], reverse=True)
189
- return [product for product, score in scored_products[:limit]]
190
-
191
- async def _get_embedding(self, text: str) -> List[float]:
192
- """Get embedding for text using available methods"""
193
- try:
194
- embedding = await gemini_service.generate_embedding(text)
195
- if embedding and len(embedding) > 10:
196
- return embedding
197
- except:
198
- pass
199
-
200
- embedding = self.embedding_model.encode(text)
201
- return embedding.tolist()
202
-
203
- def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
204
- """Calculate cosine similarity between two vectors"""
205
- vec1 = np.array(vec1)
206
- vec2 = np.array(vec2)
207
-
208
- dot_product = np.dot(vec1, vec2)
209
- norm1 = np.linalg.norm(vec1)
210
- norm2 = np.linalg.norm(vec2)
211
-
212
- if norm1 == 0 or norm2 == 0:
213
- return 0.0
214
-
215
- return dot_product / (norm1 * norm2)
216
-
217
- async def build_context(self, query: str, relevant_products: List[Dict]) -> str:
218
- """Build context string from relevant products for the LLM"""
219
- if not relevant_products:
220
- return "No specific product information available. Use general knowledge about e-commerce and shopping."
221
-
222
- context_parts = ["Relevant Products Information:"]
223
-
224
- for i, product in enumerate(relevant_products, 1):
225
- context_parts.append(f"""
226
- Product {i}:
227
- - Name: {product.get('name', 'N/A')}
228
- - Description: {product.get('description', 'N/A')}
229
- - Price: ${product.get('price', 'N/A')}
230
- - Category: {product.get('category', 'N/A')}
231
- - In Stock: {'Yes' if product.get('in_stock') else 'No'}
232
- - Features: {', '.join(product.get('features', []))}
233
- - Tags: {', '.join(product.get('tags', []))}
234
- """)
235
-
236
- return "\n".join(context_parts)
237
-
238
- async def generate_chat_response(self, user_message: str, conversation_history: List[Dict] = None) -> Dict[str, Any]:
239
- """Generate response using RAG pipeline"""
240
- # Classify intent
241
- intent_result = await gemini_service.classify_intent(user_message)
242
-
243
- # Retrieve relevant products (already transformed)
244
- relevant_products = await self.retrieve_relevant_products(user_message, limit=3)
245
-
246
- # Build context
247
- context = await self.build_context(user_message, relevant_products)
248
-
249
- # Generate response using Gemini with context
250
- response = await gemini_service.generate_response(user_message, context)
251
-
252
- # Generate suggested questions based on intent
253
- suggested_questions = await self._generate_suggested_questions(intent_result['intent'], relevant_products)
254
-
255
- return {
256
- "response": response,
257
- "relevant_products": relevant_products,
258
- "suggested_questions": suggested_questions,
259
- "intent": intent_result['intent'],
260
- "confidence": intent_result['confidence']
261
- }
262
-
263
- async def _generate_suggested_questions(self, intent: str, products: List[Dict]) -> List[str]:
264
- """Generate context-aware suggested questions"""
265
- base_questions = {
266
- "product_inquiry": ["Show me more products", "What are your best sellers?", "Any current deals?"],
267
- "pricing": ["Do you offer discounts?", "What's the return policy?", "Any bundle deals?"],
268
- "shipping": ["How long does delivery take?", "Do you ship internationally?", "What are shipping costs?"],
269
- "returns": ["How do I return an item?", "What's your warranty policy?", "Do you offer exchanges?"],
270
- "support": ["Contact customer service", "Store locations", "Business hours"],
271
- "default": ["Best sellers", "Current deals", "Shipping info", "Return policy"]
272
- }
273
-
274
- questions = base_questions.get(intent, base_questions["default"])
275
-
276
- # Add product-specific questions if we have products
277
- if products and intent == "product_inquiry":
278
- categories = list(set(str(p.get('category')) for p in products))
279
- if categories:
280
- questions = [f"Show me more {cat} products" for cat in categories[:2]] + questions
281
-
282
- return questions[:4]
283
 
284
- # RAG system instance
285
- rag_system = RAGSystem()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import google.generativeai as genai
 
 
 
 
2
  from sentence_transformers import SentenceTransformer
3
+ from typing import List, Tuple, Dict, Any
4
  import asyncio
5
+ from database import db
6
+ from config import settings
7
 
8
+ class ProductRAGPipeline:
9
  def __init__(self):
10
+ # Initialize embedding model
11
  self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
12
+
13
+ # Initialize Gemini
14
+ genai.configure(api_key=settings.GEMINI_API_KEY)
15
+ self.gemini_model = genai.GenerativeModel('gemini-2.5-flash')
16
+
17
+ # Enhanced personality for shopping assistant
18
+ self.personality_traits = """
19
+ You are a friendly, knowledgeable shopping assistant for a fashion e-commerce store. Your personality traits:
20
+ - Warm, approachable, and enthusiastic about fashion
21
+ - Helpful and patient with customer queries
22
+ - Knowledgeable about products, styles, and fashion trends
23
+ - Casual but professional tone, like a friendly store assistant
24
+ - Use emojis occasionally to express emotion (but don't overdo it)
25
+ - Ask follow-up questions to better understand customer needs
26
+ - Be concise but thorough in product recommendations
27
+ - Always mention key product features, price, and benefits
28
+ - If you suggest multiple products, compare them briefly
29
+ """
30
+
31
+ async def get_embeddings(self, texts: List[str]) -> List[List[float]]:
32
+ """Get embeddings for texts (async wrapper)"""
33
+ loop = asyncio.get_event_loop()
34
+ embeddings = await loop.run_in_executor(
35
+ None, self.embedding_model.encode, texts
36
+ )
37
+ return embeddings.tolist()
38
 
39
+ async def retrieve_relevant_products(self, query: str, limit: int = 3) -> List[Dict]:
40
+ """Retrieve relevant products using vector search with fallback"""
41
  try:
42
+ # Try vector search first
43
+ query_embedding = await self.get_embeddings([query])
44
+ print(f"🔍 Performing vector search with embedding dim: {len(query_embedding[0])}")
45
+ relevant_docs = await db.similarity_search(query_embedding[0], limit=limit)
46
+ print(f"✅ Vector search returned {len(relevant_docs)} results")
47
 
48
+ if not relevant_docs:
49
+ print("🔄 No results from vector search, trying category-based search")
50
+ # Fallback to category-based search
51
+ category_keywords = self._extract_category_from_query(query)
52
+ if category_keywords:
53
+ relevant_docs = await db.search_by_category(category_keywords[0], limit=limit)
54
+ print(f" Category search returned {len(relevant_docs)} results")
 
 
 
 
 
 
 
55
 
56
+ return relevant_docs
57
  except Exception as e:
58
+ print(f"Error in vector search: {e}")
59
+ # Final fallback to generic product search
60
+ return await db.search_by_category("tops", limit=limit)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
+ def _extract_category_from_query(self, query: str) -> List[str]:
63
+ """Extract potential category keywords from user query"""
64
+ query_lower = query.lower()
65
+ categories = []
66
+
67
+ category_mapping = {
68
+ 'tops': ['top', 'shirt', 'blouse', 't-shirt', 'tshirt', 'crop top', 'spaghetti'],
69
+ 'bottoms': ['pant', 'jeans', 'trouser', 'leggings', 'skirt', 'short'],
70
+ 'dresses': ['dress', 'gown', 'frock'],
71
+ 'outerwear': ['jacket', 'sweater', 'hoodie', 'cardigan', 'coat'],
72
+ 'accessories': ['bag', 'jewelry', 'scarf', 'hat', 'belt']
73
+ }
74
 
75
+ for category, keywords in category_mapping.items():
76
+ if any(keyword in query_lower for keyword in keywords):
77
+ categories.append(category)
 
 
 
78
 
79
+ return categories if categories else ['tops'] # Default to tops
80
 
81
+ def create_product_prompt(self, query: str, products: List[Dict]) -> str:
82
+ """Create context-aware prompt with product information"""
83
+ if products:
84
+ context = "AVAILABLE PRODUCTS:\n"
85
+ for i, product in enumerate(products, 1):
86
+ context += f"{i}. {product['content']}\n"
87
+ else:
88
+ context = "No specific product information available at the moment."
89
+
90
+ prompt = f"""
91
+ {self.personality_traits}
92
+
93
+ {context}
94
+
95
+ USER QUESTION: {query}
96
+
97
+ INSTRUCTIONS:
98
+ 1. Answer based primarily on the provided product information
99
+ 2. If suggesting products, mention:
100
+ - Key features and benefits
101
+ - Price (if available)
102
+ - Why it might suit the user's needs
103
+ 3. Be conversational and helpful
104
+ 4. If the exact answer isn't in the products, use your general knowledge but be honest about limitations
105
+ 5. Keep responses concise but complete (2-4 sentences usually)
106
+ 6. Always maintain a friendly, shopping assistant tone
107
+ 7. If multiple products are relevant, compare them briefly
108
+
109
+ SHOPPING ASSISTANT RESPONSE:
110
+ """
111
+ return prompt
112
 
113
+ async def generate_response(self, query: str) -> Tuple[str, List[Dict]]:
114
+ """Generate response using product RAG pipeline"""
 
 
 
 
 
 
 
 
115
  try:
116
+ # Retrieve relevant products
117
+ relevant_products = await self.retrieve_relevant_products(query)
118
+ print(f"📦 Retrieved {len(relevant_products)} relevant products")
 
 
 
 
 
119
 
120
+ # Create context-aware prompt
121
+ prompt = self.create_product_prompt(query, relevant_products)
122
 
123
+ # Generate response using Gemini
124
+ response = self.gemini_model.generate_content(prompt)
125
+ response_text = response.text.strip()
 
 
 
 
 
126
 
127
+ print(f"🤖 Generated response: {response_text[:100]}...")
128
+ return response_text, relevant_products
 
129
 
130
  except Exception as e:
131
+ print(f" Error generating response: {e}")
132
+ fallback_msg = "I apologize, but I'm having trouble accessing our product information right now. Please try again in a moment or contact our customer service for immediate assistance. 😊"
133
+ return fallback_msg, []
134
 
135
+ def generate_followup_questions(self, query: str, products: List[Dict]) -> List[str]:
136
+ """Generate context-aware follow-up questions"""
137
+ base_questions = [
138
+ "Tell me more about this product",
139
+ "What are the alternatives in different colors?",
140
+ "Do you have similar items in different price ranges?",
141
+ "What's the sizing like for these products?",
142
+ "Are any of these currently on sale?"
143
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
 
145
+ # Context-aware questions
146
+ query_lower = query.lower()
147
+ if any(word in query_lower for word in ['price', 'cost', 'expensive', 'cheap']):
148
+ base_questions.extend([
149
+ "What's the price range for similar items?",
150
+ "Are there any ongoing discounts?"
151
+ ])
152
+
153
+ if any(word in query_lower for word in ['color', 'colour', 'pattern']):
154
+ base_questions.extend([
155
+ "What other colors are available?",
156
+ "Do you have this in solid colors vs patterns?"
157
+ ])
158
+
159
+ return base_questions[:5] # Return top 5 questions
160
+
161
+ # Global RAG pipeline instance
162
+ rag_pipeline = ProductRAGPipeline()
run.py CHANGED
@@ -1,11 +1,5 @@
 
1
  import uvicorn
2
- from config import settings
3
 
4
  if __name__ == "__main__":
5
- uvicorn.run(
6
- "main:app",
7
- host="0.0.0.0",
8
- port=settings.port,
9
- reload=True, # Enable auto-reload during development
10
- log_level="info"
11
- )
 
1
+ from main import app
2
  import uvicorn
 
3
 
4
  if __name__ == "__main__":
5
+ uvicorn.run("app.main:app", host="0.0.0.0", port=7860, reload=True)