Crcs1225 commited on
Commit
1bc6a5b
·
0 Parent(s):

Initial commit of RAG FastAPI project

Browse files
Files changed (10) hide show
  1. .gitignore +10 -0
  2. Dockerfile +21 -0
  3. config.py +16 -0
  4. database.py +117 -0
  5. gemini_service.py +106 -0
  6. main.py +217 -0
  7. models.py +100 -0
  8. rag_system.py +285 -0
  9. requirements.txt +0 -0
  10. run.py +11 -0
.gitignore ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ .env
2
+ env
3
+ # Environment variables file
4
+ .env.local
5
+ .env.*.local
6
+
7
+ __pycache__
8
+ *.pyc
9
+ *.pyo
10
+ venv
Dockerfile ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Use a slim Python base
2
+ FROM python:3.11-slim
3
+
4
+ # Set working directory
5
+ WORKDIR /code
6
+
7
+ # Install system deps (if you need Mongo client, etc.)
8
+ RUN apt-get update && apt-get install -y build-essential
9
+
10
+ # Copy requirements and install
11
+ COPY requirements.txt .
12
+ RUN pip install --no-cache-dir -r requirements.txt
13
+
14
+ # Copy app code
15
+ COPY ./app /code/app
16
+
17
+ # Expose port
18
+ EXPOSE 7860
19
+
20
+ # Run FastAPI with uvicorn
21
+ CMD ["uvicorn", "run:app", "--host", "0.0.0.0", "--port", "7860"]
config.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic_settings import BaseSettings
2
+ from typing import Optional
3
+
4
+ class Settings(BaseSettings):
5
+ mongodb_atlas_uri: str
6
+ gemini_api_key: str
7
+ database_name: str = "product"
8
+ products_collection: str = "marketplace"
9
+ conversations_collection: str = "conversations"
10
+ embeddings_collection: str = "embeddings"
11
+ port: int = 8000
12
+
13
+ class Config:
14
+ env_file = ".env"
15
+
16
+ settings = Settings()
database.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datetime import datetime
2
+ import motor.motor_asyncio
3
+ from bson import ObjectId
4
+ from typing import List, Optional, Dict, Any
5
+ from config import settings
6
+ import json
7
+
8
+ class MongoDB:
9
+ def __init__(self):
10
+ self.client = motor.motor_asyncio.AsyncIOMotorClient(settings.mongodb_atlas_uri)
11
+ self.db = self.client[settings.database_name]
12
+ self.products = self.db[settings.products_collection]
13
+ self.conversations = self.db[settings.conversations_collection]
14
+ self.embeddings = self.db[settings.embeddings_collection]
15
+
16
+ async def get_product(self, product_id: str) -> Optional[Dict]:
17
+ return await self.products.find_one({"_id": ObjectId(product_id)})
18
+
19
+ async def get_products_by_category(self, category: str, limit: int = 10) -> List[Dict]:
20
+ # Your database might have different field names for category
21
+ cursor = self.products.find({"category": category}).limit(limit)
22
+ return await cursor.to_list(length=limit)
23
+
24
+ async def search_products(self, query: Dict, limit: int = 10) -> List[Dict]:
25
+ cursor = self.products.find(query).limit(limit)
26
+ return await cursor.to_list(length=limit)
27
+
28
+ async def get_all_products(self, limit: int = 50) -> List[Dict]:
29
+ cursor = self.products.find().limit(limit)
30
+ return await cursor.to_list(length=limit)
31
+
32
+ async def create_product(self, product_data: Dict) -> str:
33
+ result = await self.products.insert_one(product_data)
34
+ return str(result.inserted_id)
35
+
36
+ async def update_product(self, product_id: str, update_data: Dict) -> bool:
37
+ result = await self.products.update_one(
38
+ {"_id": ObjectId(product_id)},
39
+ {"$set": update_data}
40
+ )
41
+ return result.modified_count > 0
42
+
43
+ async def delete_product(self, product_id: str) -> bool:
44
+ result = await self.products.delete_one({"_id": ObjectId(product_id)})
45
+ return result.deleted_count > 0
46
+
47
+ async def get_conversation(self, conversation_id: str) -> Optional[Dict]:
48
+ return await self.conversations.find_one({"_id": ObjectId(conversation_id)})
49
+
50
+ async def create_conversation(self, user_id: str) -> str:
51
+ conversation_data = {
52
+ "user_id": user_id,
53
+ "messages": [],
54
+ "created_at": datetime.now(),
55
+ "updated_at": datetime.now()
56
+ }
57
+ result = await self.conversations.insert_one(conversation_data)
58
+ return str(result.inserted_id)
59
+
60
+ async def add_message_to_conversation(self, conversation_id: str, message: Dict) -> bool:
61
+ result = await self.conversations.update_one(
62
+ {"_id": ObjectId(conversation_id)},
63
+ {
64
+ "$push": {"messages": message},
65
+ "$set": {"updated_at": datetime.now()}
66
+ }
67
+ )
68
+ return result.modified_count > 0
69
+
70
+ async def get_user_conversations(self, user_id: str, limit: int = 10) -> List[Dict]:
71
+ cursor = self.conversations.find({"user_id": user_id}).sort("updated_at", -1).limit(limit)
72
+ return await cursor.to_list(length=limit)
73
+
74
+ async def store_embedding(self, text: str, embedding: List[float], metadata: Dict) -> str:
75
+ doc = {
76
+ "text": text,
77
+ "embedding": embedding,
78
+ "metadata": metadata,
79
+ "created_at": datetime.now()
80
+ }
81
+ result = await self.embeddings.insert_one(doc)
82
+ return str(result.inserted_id)
83
+
84
+ async def find_similar_embeddings(self, embedding: List[float], limit: int = 5) -> List[Dict]:
85
+ # This is a simplified version - in production, you'd use vector search
86
+ pipeline = [
87
+ {
88
+ "$addFields": {
89
+ "similarity": {
90
+ "$sqrt": {
91
+ "$sum": {
92
+ "$map": {
93
+ "input": {"$range": [0, {"$size": "$embedding"}]},
94
+ "as": "idx",
95
+ "in": {
96
+ "$pow": [
97
+ {"$subtract": [
98
+ {"$arrayElemAt": ["$embedding", "$$idx"]},
99
+ {"$arrayElemAt": [embedding, "$$idx"]}
100
+ ]},
101
+ 2
102
+ ]
103
+ }
104
+ }
105
+ }
106
+ }
107
+ }
108
+ }
109
+ },
110
+ {"$sort": {"similarity": 1}},
111
+ {"$limit": limit}
112
+ ]
113
+ cursor = self.embeddings.aggregate(pipeline)
114
+ return await cursor.to_list(length=limit)
115
+
116
+ # Database instance
117
+ db = MongoDB()
gemini_service.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import google.generativeai as genai
2
+ from typing import List, Dict, Any, Optional
3
+ import asyncio
4
+ import aiohttp
5
+ import json
6
+ from config import settings
7
+
8
+ class GeminiService:
9
+ def __init__(self):
10
+ genai.configure(api_key=settings.gemini_api_key)
11
+ self.model = genai.GenerativeModel('gemini-2.5-flash')
12
+
13
+ async def generate_response(self, prompt: str, context: str = "") -> str:
14
+ """Generate response using Gemini API with context"""
15
+ try:
16
+ full_prompt = f"""
17
+ Context Information:
18
+ {context}
19
+
20
+ User Question: {prompt}
21
+
22
+ You are a helpful shopping assistant for Daddy's Shop. Use the context information above to answer the user's question accurately and helpfully. If the context doesn't contain relevant information, use your general knowledge but be honest about limitations.
23
+
24
+ Provide a friendly, professional response focused on helping with shopping needs.
25
+ """
26
+
27
+ # Run in thread pool since Gemini doesn't have native async support
28
+ loop = asyncio.get_event_loop()
29
+ response = await loop.run_in_executor(
30
+ None,
31
+ lambda: self.model.generate_content(full_prompt)
32
+ )
33
+
34
+ return response.text
35
+
36
+ except Exception as e:
37
+ print(f"Gemini API error: {e}")
38
+ return "I apologize, but I'm having trouble processing your request right now. Please try again later."
39
+
40
+ async def generate_embedding(self, text: str) -> List[float]:
41
+ """Generate embeddings for text using Gemini"""
42
+ try:
43
+ # Note: Gemini doesn't have direct embedding API, so we'll use a workaround
44
+ # For production, consider using SentenceTransformers or another embedding service
45
+ embedding_model = genai.GenerativeModel('embedding-001')
46
+
47
+ loop = asyncio.get_event_loop()
48
+ result = await loop.run_in_executor(
49
+ None,
50
+ lambda: genai.embed_content(
51
+ model=embedding_model,
52
+ content=text,
53
+ task_type="retrieval_document"
54
+ )
55
+ )
56
+
57
+ return result['embedding']
58
+
59
+ except Exception as e:
60
+ print(f"Embedding generation error: {e}")
61
+ # Fallback to simple embedding (in production, use proper embedding model)
62
+ return [0.0] * 384 # Default dimension
63
+
64
+ async def classify_intent(self, message: str) -> Dict[str, Any]:
65
+ """Classify user intent using Gemini"""
66
+ prompt = f"""
67
+ Classify the following user message into one of these intents:
68
+ - product_inquiry: Questions about products, features, availability
69
+ - pricing: Questions about costs, discounts, prices
70
+ - shipping: Questions about delivery, shipping costs, timelines
71
+ - returns: Questions about returns, refunds, exchanges
72
+ - support: General customer support, contact information
73
+ - greeting: Hello, hi, greetings
74
+ - unknown: Cannot classify
75
+
76
+ Message: "{message}"
77
+
78
+ Return ONLY a JSON object with:
79
+ {{
80
+ "intent": "classified_intent",
81
+ "confidence": 0.95,
82
+ "entities": ["extracted_entities", "if_any"]
83
+ }}
84
+ """
85
+
86
+ try:
87
+ loop = asyncio.get_event_loop()
88
+ response = await loop.run_in_executor(
89
+ None,
90
+ lambda: self.model.generate_content(prompt)
91
+ )
92
+
93
+ # Parse JSON response
94
+ import re
95
+ json_match = re.search(r'\{.*\}', response.text, re.DOTALL)
96
+ if json_match:
97
+ return json.loads(json_match.group())
98
+ else:
99
+ return {"intent": "unknown", "confidence": 0.0, "entities": []}
100
+
101
+ except Exception as e:
102
+ print(f"Intent classification error: {e}")
103
+ return {"intent": "unknown", "confidence": 0.0, "entities": []}
104
+
105
+ # Gemini service instance
106
+ gemini_service = GeminiService()
main.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, HTTPException, Depends
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from typing import List, Optional
4
+ import uuid
5
+ from datetime import datetime
6
+
7
+ from config import settings
8
+ from database import db
9
+ from rag_system import rag_system
10
+ from models import (
11
+ Product, ProductCreate, ChatRequest, ChatResponse,
12
+ SearchRequest, Conversation, ChatMessage
13
+ )
14
+
15
+ app = FastAPI(
16
+ title="Daddy's Shop RAG Chatbot API",
17
+ description="AI-powered shopping assistant with RAG using MongoDB and Gemini",
18
+ version="1.0.0"
19
+ )
20
+
21
+ # CORS middleware
22
+ app.add_middleware(
23
+ CORSMiddleware,
24
+ allow_origins=["http://localhost:3000", "http://127.0.0.1:3000", "https://yourdomain.com"],
25
+ allow_credentials=True,
26
+ allow_methods=["*"],
27
+ allow_headers=["*"],
28
+ )
29
+
30
+ @app.get("/")
31
+ async def root():
32
+ """Health check endpoint"""
33
+ return {
34
+ "message": "Daddy's Shop RAG Chatbot API is running!",
35
+ "version": "1.0.0",
36
+ "status": "healthy"
37
+ }
38
+
39
+ @app.post("/chat", response_model=ChatResponse)
40
+ async def chat_endpoint(request: ChatRequest):
41
+ """
42
+ Main chatbot endpoint with RAG capabilities
43
+ """
44
+ try:
45
+ # Create new conversation if no ID provided
46
+ if not request.conversation_id:
47
+ conversation_id = await db.create_conversation(request.user_id)
48
+ else:
49
+ conversation_id = request.conversation_id
50
+ # Verify conversation exists
51
+ conversation = await db.get_conversation(conversation_id)
52
+ if not conversation:
53
+ conversation_id = await db.create_conversation(request.user_id)
54
+
55
+ # Add user message to conversation
56
+ user_message = ChatMessage(
57
+ sender="user",
58
+ text=request.message
59
+ )
60
+ await db.add_message_to_conversation(conversation_id, user_message.dict())
61
+
62
+ # Get conversation history for context
63
+ conversation = await db.get_conversation(conversation_id)
64
+ history = conversation.get('messages', []) if conversation else []
65
+
66
+ # Generate response using RAG system
67
+ rag_result = await rag_system.generate_chat_response(request.message, history)
68
+
69
+ # Add bot response to conversation
70
+ bot_message = ChatMessage(
71
+ sender="bot",
72
+ text=rag_result["response"]
73
+ )
74
+ await db.add_message_to_conversation(conversation_id, bot_message.dict())
75
+
76
+ return ChatResponse(
77
+ response=rag_result["response"],
78
+ conversation_id=conversation_id,
79
+ suggested_questions=rag_result["suggested_questions"],
80
+ relevant_products=rag_result["relevant_products"],
81
+ intent=rag_result["intent"],
82
+ confidence=rag_result["confidence"]
83
+ )
84
+
85
+ except Exception as e:
86
+ raise HTTPException(status_code=500, detail=f"Chat processing error: {str(e)}")
87
+
88
+ @app.get("/products", response_model=List[Product])
89
+ async def get_products(category: Optional[str] = None, limit: int = 20):
90
+ """Get products with optional category filter"""
91
+ try:
92
+ if category:
93
+ products = await db.get_products_by_category(category, limit)
94
+ else:
95
+ products = await db.get_all_products(limit)
96
+ return products
97
+ except Exception as e:
98
+ raise HTTPException(status_code=500, detail=f"Error fetching products: {str(e)}")
99
+
100
+ @app.get("/products/{product_id}", response_model=Product)
101
+ async def get_product(product_id: str):
102
+ """Get specific product by ID"""
103
+ product = await db.get_product(product_id)
104
+ if not product:
105
+ raise HTTPException(status_code=404, detail="Product not found")
106
+ return product
107
+
108
+ @app.post("/products", response_model=dict)
109
+ async def create_product(product: ProductCreate):
110
+ """Create a new product"""
111
+ try:
112
+ product_data = product.dict()
113
+ product_data["_id"] = str(uuid.uuid4())[:8] # Simple ID generation
114
+ product_id = await db.create_product(product_data)
115
+ return {"message": "Product created successfully", "product_id": product_id}
116
+ except Exception as e:
117
+ raise HTTPException(status_code=500, detail=f"Error creating product: {str(e)}")
118
+
119
+ @app.post("/search", response_model=List[Product])
120
+ async def search_products(request: SearchRequest):
121
+ """Semantic search for products"""
122
+ try:
123
+ products = await rag_system.retrieve_relevant_products(
124
+ request.query,
125
+ request.category,
126
+ request.max_results
127
+ )
128
+ return products
129
+ except Exception as e:
130
+ raise HTTPException(status_code=500, detail=f"Search error: {str(e)}")
131
+
132
+ @app.get("/conversations/{user_id}", response_model=List[Conversation])
133
+ async def get_user_conversations(user_id: str, limit: int = 10):
134
+ """Get user's conversation history"""
135
+ try:
136
+ conversations = await db.get_user_conversations(user_id, limit)
137
+ return conversations
138
+ except Exception as e:
139
+ raise HTTPException(status_code=500, detail=f"Error fetching conversations: {str(e)}")
140
+
141
+ @app.get("/intents")
142
+ async def get_available_intents():
143
+ """Get information about available intents"""
144
+ return {
145
+ "intents": [
146
+ "product_inquiry", "pricing", "shipping",
147
+ "returns", "support", "greeting", "unknown"
148
+ ],
149
+ "description": "Intent classification for user messages"
150
+ }
151
+ @app.get("/test-products")
152
+ async def test_products(limit: int = 3):
153
+ """Test endpoint to see transformed products"""
154
+ try:
155
+ products = await db.get_all_products(limit)
156
+ transformed_products = [rag_system._transform_product_doc(product) for product in products]
157
+ return {
158
+ "original_count": len(products),
159
+ "transformed_count": len(transformed_products),
160
+ "original_sample": products[0] if products else {},
161
+ "transformed_sample": transformed_products[0] if transformed_products else {},
162
+ "all_transformed": transformed_products
163
+ }
164
+ except Exception as e:
165
+ raise HTTPException(status_code=500, detail=f"Test error: {str(e)}")
166
+ # Sample data initialization
167
+ @app.on_event("startup")
168
+ async def startup_event():
169
+ """Initialize sample data if needed"""
170
+ try:
171
+ # Check if we have any products
172
+ products = await db.get_all_products(1)
173
+ if not products:
174
+ await _initialize_sample_data()
175
+ except Exception as e:
176
+ print(f"Startup initialization error: {e}")
177
+
178
+ async def _initialize_sample_data():
179
+ """Initialize sample product data"""
180
+ sample_products = [
181
+ {
182
+ "name": "Wireless Bluetooth Earbuds",
183
+ "description": "High-quality wireless earbuds with noise cancellation and 24-hour battery life.",
184
+ "price": 79.99,
185
+ "category": "electronics",
186
+ "in_stock": True,
187
+ "tags": ["wireless", "bluetooth", "audio", "noise-cancellation"],
188
+ "features": ["Noise Cancellation", "24h Battery", "Water Resistant", "Touch Controls"]
189
+ },
190
+ {
191
+ "name": "Smart Fitness Watch",
192
+ "description": "Advanced fitness tracker with heart rate monitoring, GPS, and smartphone connectivity.",
193
+ "price": 199.99,
194
+ "category": "electronics",
195
+ "in_stock": True,
196
+ "tags": ["fitness", "smartwatch", "health", "tracking"],
197
+ "features": ["Heart Rate Monitor", "GPS", "Sleep Tracking", "Waterproof"]
198
+ },
199
+ {
200
+ "name": "Organic Cotton T-Shirt",
201
+ "description": "Comfortable and sustainable organic cotton t-shirt available in multiple colors.",
202
+ "price": 24.99,
203
+ "category": "clothing",
204
+ "in_stock": True,
205
+ "tags": ["cotton", "organic", "sustainable", "casual"],
206
+ "features": ["Organic Cotton", "Machine Washable", "Multiple Colors"]
207
+ }
208
+ ]
209
+
210
+ for product in sample_products:
211
+ await db.create_product(product)
212
+
213
+ print("Sample data initialized successfully")
214
+
215
+ if __name__ == "__main__":
216
+ import uvicorn
217
+ uvicorn.run(app, host="0.0.0.0", port=settings.port)
models.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel, Field, validator
2
+ from typing import List, Optional, Dict, Any
3
+ from datetime import datetime
4
+ from enum import Enum
5
+
6
+ class ProductCategory(str, Enum):
7
+ ELECTRONICS = "electronics"
8
+ CLOTHING = "clothing"
9
+ HOME = "home"
10
+ BEAUTY = "beauty"
11
+ SPORTS = "sports"
12
+ BOOKS = "books"
13
+ OTHER = "other"
14
+
15
+ class ProductCreate(BaseModel):
16
+ name: str
17
+ description: str = ""
18
+ price: float = 0.0
19
+ category: ProductCategory = ProductCategory.OTHER
20
+ in_stock: bool = True
21
+ tags: List[str] = []
22
+ features: List[str] = []
23
+ image_url: Optional[str] = None
24
+
25
+ @validator('price', pre=True)
26
+ def validate_price(cls, v):
27
+ if v is None:
28
+ return 0.0
29
+ try:
30
+ return float(v)
31
+ except (TypeError, ValueError):
32
+ return 0.0
33
+
34
+ @validator('category', pre=True)
35
+ def validate_category(cls, v):
36
+ if isinstance(v, ProductCategory):
37
+ return v
38
+ try:
39
+ return ProductCategory(v)
40
+ except ValueError:
41
+ return ProductCategory.OTHER
42
+
43
+ class Product(BaseModel):
44
+ id: str
45
+ name: str
46
+ description: str = ""
47
+ price: float = 0.0
48
+ category: ProductCategory = ProductCategory.OTHER
49
+ in_stock: bool = True
50
+ tags: List[str] = []
51
+ features: List[str] = []
52
+ image_url: Optional[str] = None
53
+
54
+ @validator('price', pre=True)
55
+ def validate_price(cls, v):
56
+ if v is None:
57
+ return 0.0
58
+ try:
59
+ return float(v)
60
+ except (TypeError, ValueError):
61
+ return 0.0
62
+
63
+ @validator('category', pre=True)
64
+ def validate_category(cls, v):
65
+ if isinstance(v, ProductCategory):
66
+ return v
67
+ try:
68
+ return ProductCategory(v)
69
+ except ValueError:
70
+ return ProductCategory.OTHER
71
+
72
+ class ChatMessage(BaseModel):
73
+ sender: str
74
+ text: str
75
+ timestamp: datetime = Field(default_factory=datetime.now)
76
+
77
+ class ChatRequest(BaseModel):
78
+ message: str
79
+ conversation_id: Optional[str] = None
80
+ user_id: Optional[str] = "anonymous"
81
+
82
+ class ChatResponse(BaseModel):
83
+ response: str
84
+ conversation_id: str
85
+ suggested_questions: List[str] = []
86
+ relevant_products: List[Product] = []
87
+ intent: str
88
+ confidence: float
89
+
90
+ class SearchRequest(BaseModel):
91
+ query: str
92
+ category: Optional[ProductCategory] = None
93
+ max_results: int = 5
94
+
95
+ class Conversation(BaseModel):
96
+ id: str
97
+ user_id: str
98
+ messages: List[ChatMessage] = []
99
+ created_at: datetime = Field(default_factory=datetime.now)
100
+ updated_at: datetime = Field(default_factory=datetime.now)
rag_system.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import List, Dict, Any, Optional
2
+ from database import db
3
+ from gemini_service import gemini_service
4
+ from models import Product, ProductCategory
5
+ import numpy as np
6
+ from sentence_transformers import SentenceTransformer
7
+ import asyncio
8
+
9
+ class RAGSystem:
10
+ def __init__(self):
11
+ self.embedding_model = SentenceTransformer('all-MiniLM-L6-v2')
12
+
13
+ def _transform_product_doc(self, product_doc: Dict) -> Dict:
14
+ """Transform MongoDB document to match Product model using actual field names"""
15
+ try:
16
+ # Extract fields from your actual database schema
17
+ product_id = str(product_doc.get('_id', '')) or str(product_doc.get('product_id', ''))
18
+
19
+ # Map to our expected fields
20
+ transformed = {
21
+ "id": product_id,
22
+ "name": product_doc.get('title', 'Unnamed Product'),
23
+ "description": product_doc.get('product_description', 'No description available'),
24
+ "price": float(product_doc.get('final_price', 0.0)),
25
+ "category": self._map_category(product_doc.get('category', 'other')),
26
+ "in_stock": True, # Default to True since we don't have this field
27
+ "tags": self._extract_tags(product_doc),
28
+ "features": self._extract_features(product_doc),
29
+ "image_url": self._get_first_image(product_doc.get('images', '')),
30
+ }
31
+
32
+ return transformed
33
+
34
+ except Exception as e:
35
+ print(f"Error transforming product {product_doc.get('_id')}: {e}")
36
+ # Return a safe default product
37
+ return {
38
+ "id": str(product_doc.get('_id', 'unknown')),
39
+ "name": product_doc.get('title', 'Unnamed Product'),
40
+ "description": "Product information unavailable",
41
+ "price": 0.0,
42
+ "category": ProductCategory.OTHER,
43
+ "in_stock": True,
44
+ "tags": [],
45
+ "features": [],
46
+ }
47
+
48
+ def _map_category(self, raw_category: str) -> ProductCategory:
49
+ """Map raw category string to ProductCategory enum"""
50
+ if not raw_category:
51
+ return ProductCategory.OTHER
52
+
53
+ category_lower = raw_category.lower()
54
+
55
+ # Map based on your actual category values
56
+ if any(keyword in category_lower for keyword in ['electronic', 'tech', 'computer', 'phone']):
57
+ return ProductCategory.ELECTRONICS
58
+ elif any(keyword in category_lower for keyword in ['cloth', 'fashion', 'wear', 'top', 'dress', 'shirt', 'jeans']):
59
+ return ProductCategory.CLOTHING
60
+ elif any(keyword in category_lower for keyword in ['home', 'garden', 'furniture', 'decor']):
61
+ return ProductCategory.HOME
62
+ elif any(keyword in category_lower for keyword in ['beauty', 'cosmetic', 'skin', 'hair', 'cream', 'mask', 'makeup']):
63
+ return ProductCategory.BEAUTY
64
+ elif any(keyword in category_lower for keyword in ['sport', 'fitness', 'exercise', 'gym']):
65
+ return ProductCategory.SPORTS
66
+ elif any(keyword in category_lower for keyword in ['book', 'literature']):
67
+ return ProductCategory.BOOKS
68
+ else:
69
+ return ProductCategory.OTHER
70
+
71
+ def _extract_tags(self, product_doc: Dict) -> List[str]:
72
+ """Extract tags from product details"""
73
+ tags = []
74
+
75
+ # Add category as a tag
76
+ category = product_doc.get('category')
77
+ if category:
78
+ tags.append(category)
79
+
80
+ # Extract from product details if available
81
+ details = product_doc.get('product_details', '')
82
+ if 'Cotton' in details:
83
+ tags.append('Cotton')
84
+ if 'Polyester' in details:
85
+ tags.append('Polyester')
86
+
87
+ return tags
88
+
89
+ def _extract_features(self, product_doc: Dict) -> List[str]:
90
+ """Extract features from product details"""
91
+ features = []
92
+
93
+ # Extract key features from product details
94
+ details = product_doc.get('product_details', '')
95
+
96
+ # Look for common features
97
+ feature_keywords = [
98
+ 'Machine wash', 'Hand wash', 'Dry clean', 'Cotton',
99
+ 'Polyester', 'Elastane', 'Spaghetti', 'Sleeveless',
100
+ 'Solid pattern', 'Sweetheart neck'
101
+ ]
102
+
103
+ for keyword in feature_keywords:
104
+ if keyword in details:
105
+ features.append(keyword)
106
+
107
+ # Add rating as a feature
108
+ rating = product_doc.get('rating')
109
+ if rating:
110
+ features.append(f'{rating}★ Rating')
111
+
112
+ # Add delivery options
113
+ delivery = product_doc.get('delivery_options', '')
114
+ if 'Pay on delivery' in delivery:
115
+ features.append('Pay on Delivery')
116
+ if 'Easy returns' in delivery:
117
+ features.append('Easy Returns')
118
+
119
+ return features
120
+
121
+ def _get_first_image(self, images_str: str) -> Optional[str]:
122
+ """Extract first image URL from comma-separated string"""
123
+ if not images_str:
124
+ return None
125
+ images = images_str.split(',')
126
+ return images[0].strip() if images else None
127
+
128
+ # ... rest of your existing methods remain the same
129
+ async def retrieve_relevant_products(self, query: str, category: Optional[str] = None, limit: int = 5) -> List[Dict]:
130
+ """Retrieve and transform relevant products"""
131
+ try:
132
+ # Generate query embedding
133
+ query_embedding = await self._get_embedding(query)
134
+
135
+ # Get all products or filtered by category
136
+ if category:
137
+ products = await db.get_products_by_category(category, limit=50)
138
+ else:
139
+ products = await db.get_all_products(limit=100)
140
+
141
+ # Transform products first
142
+ transformed_products = [self._transform_product_doc(product) for product in products]
143
+
144
+ # Calculate similarity scores on transformed products
145
+ scored_products = []
146
+ for product in transformed_products:
147
+ product_text = f"{product.get('name', '')} {product.get('description', '')} {' '.join(product.get('tags', []))}"
148
+ product_embedding = await self._get_embedding(product_text)
149
+
150
+ similarity = self._cosine_similarity(query_embedding, product_embedding)
151
+ scored_products.append((product, similarity))
152
+
153
+ # Sort by similarity and return top results
154
+ scored_products.sort(key=lambda x: x[1], reverse=True)
155
+ return [product for product, score in scored_products[:limit]]
156
+
157
+ except Exception as e:
158
+ print(f"Product retrieval error: {e}")
159
+ # Fallback to simple keyword search with transformation
160
+ return await self._keyword_search_products(query, category, limit)
161
+
162
+ async def _keyword_search_products(self, query: str, category: Optional[str], limit: int) -> List[Dict]:
163
+ """Fallback keyword-based product search with transformation"""
164
+ search_terms = query.lower().split()
165
+
166
+ if category:
167
+ products = await db.get_products_by_category(category, limit=50)
168
+ else:
169
+ products = await db.get_all_products(limit=100)
170
+
171
+ # Transform all products first
172
+ transformed_products = [self._transform_product_doc(product) for product in products]
173
+
174
+ scored_products = []
175
+ for product in transformed_products:
176
+ score = 0
177
+ product_text = f"{product.get('name', '')} {product.get('description', '')} {' '.join(product.get('tags', []))}".lower()
178
+
179
+ for term in search_terms:
180
+ if term in product_text:
181
+ score += 1
182
+ if term in product.get('name', '').lower():
183
+ score += 2 # Higher weight for name matches
184
+
185
+ if score > 0:
186
+ scored_products.append((product, score))
187
+
188
+ scored_products.sort(key=lambda x: x[1], reverse=True)
189
+ return [product for product, score in scored_products[:limit]]
190
+
191
+ async def _get_embedding(self, text: str) -> List[float]:
192
+ """Get embedding for text using available methods"""
193
+ try:
194
+ embedding = await gemini_service.generate_embedding(text)
195
+ if embedding and len(embedding) > 10:
196
+ return embedding
197
+ except:
198
+ pass
199
+
200
+ embedding = self.embedding_model.encode(text)
201
+ return embedding.tolist()
202
+
203
+ def _cosine_similarity(self, vec1: List[float], vec2: List[float]) -> float:
204
+ """Calculate cosine similarity between two vectors"""
205
+ vec1 = np.array(vec1)
206
+ vec2 = np.array(vec2)
207
+
208
+ dot_product = np.dot(vec1, vec2)
209
+ norm1 = np.linalg.norm(vec1)
210
+ norm2 = np.linalg.norm(vec2)
211
+
212
+ if norm1 == 0 or norm2 == 0:
213
+ return 0.0
214
+
215
+ return dot_product / (norm1 * norm2)
216
+
217
+ async def build_context(self, query: str, relevant_products: List[Dict]) -> str:
218
+ """Build context string from relevant products for the LLM"""
219
+ if not relevant_products:
220
+ return "No specific product information available. Use general knowledge about e-commerce and shopping."
221
+
222
+ context_parts = ["Relevant Products Information:"]
223
+
224
+ for i, product in enumerate(relevant_products, 1):
225
+ context_parts.append(f"""
226
+ Product {i}:
227
+ - Name: {product.get('name', 'N/A')}
228
+ - Description: {product.get('description', 'N/A')}
229
+ - Price: ${product.get('price', 'N/A')}
230
+ - Category: {product.get('category', 'N/A')}
231
+ - In Stock: {'Yes' if product.get('in_stock') else 'No'}
232
+ - Features: {', '.join(product.get('features', []))}
233
+ - Tags: {', '.join(product.get('tags', []))}
234
+ """)
235
+
236
+ return "\n".join(context_parts)
237
+
238
+ async def generate_chat_response(self, user_message: str, conversation_history: List[Dict] = None) -> Dict[str, Any]:
239
+ """Generate response using RAG pipeline"""
240
+ # Classify intent
241
+ intent_result = await gemini_service.classify_intent(user_message)
242
+
243
+ # Retrieve relevant products (already transformed)
244
+ relevant_products = await self.retrieve_relevant_products(user_message, limit=3)
245
+
246
+ # Build context
247
+ context = await self.build_context(user_message, relevant_products)
248
+
249
+ # Generate response using Gemini with context
250
+ response = await gemini_service.generate_response(user_message, context)
251
+
252
+ # Generate suggested questions based on intent
253
+ suggested_questions = await self._generate_suggested_questions(intent_result['intent'], relevant_products)
254
+
255
+ return {
256
+ "response": response,
257
+ "relevant_products": relevant_products,
258
+ "suggested_questions": suggested_questions,
259
+ "intent": intent_result['intent'],
260
+ "confidence": intent_result['confidence']
261
+ }
262
+
263
+ async def _generate_suggested_questions(self, intent: str, products: List[Dict]) -> List[str]:
264
+ """Generate context-aware suggested questions"""
265
+ base_questions = {
266
+ "product_inquiry": ["Show me more products", "What are your best sellers?", "Any current deals?"],
267
+ "pricing": ["Do you offer discounts?", "What's the return policy?", "Any bundle deals?"],
268
+ "shipping": ["How long does delivery take?", "Do you ship internationally?", "What are shipping costs?"],
269
+ "returns": ["How do I return an item?", "What's your warranty policy?", "Do you offer exchanges?"],
270
+ "support": ["Contact customer service", "Store locations", "Business hours"],
271
+ "default": ["Best sellers", "Current deals", "Shipping info", "Return policy"]
272
+ }
273
+
274
+ questions = base_questions.get(intent, base_questions["default"])
275
+
276
+ # Add product-specific questions if we have products
277
+ if products and intent == "product_inquiry":
278
+ categories = list(set(str(p.get('category')) for p in products))
279
+ if categories:
280
+ questions = [f"Show me more {cat} products" for cat in categories[:2]] + questions
281
+
282
+ return questions[:4]
283
+
284
+ # RAG system instance
285
+ rag_system = RAGSystem()
requirements.txt ADDED
Binary file (3 kB). View file
 
run.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import uvicorn
2
+ from config import settings
3
+
4
+ if __name__ == "__main__":
5
+ uvicorn.run(
6
+ "main:app",
7
+ host="0.0.0.0",
8
+ port=settings.port,
9
+ reload=True, # Enable auto-reload during development
10
+ log_level="info"
11
+ )