# (Page banner captured together with this file — Spaces status: Sleeping)
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import FileResponse | |
| from pydantic import BaseModel | |
| from openai import OpenAI | |
| from datetime import datetime | |
| from pathlib import Path | |
| import requests | |
| import json | |
| import re | |
| import os | |
# Application object. The description is rendered in the generated API docs,
# so the arrows are written as real "→" characters rather than garbled bytes.
app = FastAPI(
    title="Store Product Search API",
    description=(
        "Three-step RAG pipeline for natural-language product search:\n"
        " 1. User message → NVIDIA Llama 3.1 70B (SQL generator)\n"
        " 2. Generated SQL → PHP DB layer (execution)\n"
        " 3. Query results + user message → NVIDIA Llama 3.1 70B (natural language reply)"
    ),
)
# ── CORS ──────────────────────────────────────────────────────────────────────
# The FastAPI CORS documentation states that allow_origins cannot be ["*"]
# when allow_credentials is True.  Allowed origins are therefore read from the
# optional ALLOWED_ORIGINS environment variable (comma-separated, default "*"),
# and credentials are only enabled when every origin is listed explicitly.
_cors_origins = [
    origin.strip()
    for origin in os.environ.get("ALLOWED_ORIGINS", "*").split(",")
    if origin.strip()
] or ["*"]

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials="*" not in _cors_origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# ── Configuration ─────────────────────────────────────────────────────────────
# Secrets are read from the environment only.  A usable API key must never be
# committed as a fallback default: anyone who can read the source can use it,
# and the "key is not set" check below could never trigger.
# NOTE(review): the key that used to be hard-coded here should be treated as
# leaked and revoked with the provider.
NVIDIA_API_KEY = os.environ.get("NVIDIA_API_KEY", "").strip()  # Set as env/secret
PHP_DB_URL = os.environ.get("PHP_DB_URL", "https://ctkart.com/api/db_api.php")
# NOTE(review): the fallback is a placeholder; set INTERNAL_SECRET in any real deployment.
INTERNAL_SECRET = os.environ.get("INTERNAL_SECRET", "change_this_secret_in_production")

# Output directory for SQL query results
OUTPUT_DIR = Path("./sql_results")
OUTPUT_DIR.mkdir(exist_ok=True)

# NVIDIA LLM Configuration (for SQL generation and natural language replies)
LLM_URL = "https://integrate.api.nvidia.com/v1"
LLM_MODEL = "meta/llama-3.1-70b-instruct"

# Validate API key: fail fast at import time instead of on the first request.
if not NVIDIA_API_KEY:
    raise RuntimeError(
        "❌ NVIDIA_API_KEY is not set!\n"
        "Set it with: export NVIDIA_API_KEY='your_actual_key'\n"
        "Or add to .env: NVIDIA_API_KEY=your_actual_key"
    )

# Initialize OpenAI client with NVIDIA base URL (OpenAI-compatible API)
client = OpenAI(
    base_url=LLM_URL,
    api_key=NVIDIA_API_KEY,
)
print(f"✓ NVIDIA API configured with model: {LLM_MODEL}")

# Anthropic kept for reference (can be removed if fully migrating to NVIDIA)
# ANTHROPIC_URL = "https://api.anthropic.com/v1/messages"
# ANTHROPIC_MODEL = "claude-sonnet-4-20250514"
# ANTHROPIC_HEADERS = {...}

# Headers sent with every request to the PHP DB layer.
PHP_HEADERS = {
    "Content-Type": "application/json",
    "X-Internal-Secret": INTERNAL_SECRET,
}
# ── Database schema (injected into the SQL-generation prompt) ─────────────────
# Plain-text description of the tables the SQL generator may query.  The text
# is sent to the model verbatim, so arrows and dashes are real characters.
DB_SCHEMA = """
Table: item_master (main product table)
Columns:
id INT Primary key, auto-increment
item_name VARCHAR(1000) Product name (e.g. 'Bag-B27-BLACK', 'LADIES PURSE B-55', 'WS-253 BLACK PU-SLIPPER')
category_id INT FK → category (see known category IDs below)
subcategory_id INT FK → subcategory
BrandID INT FK → brands
VendorID INT FK → vendor
store_id INT Always 1 in this dataset
mrp DOUBLE MRP price in INR (e.g. 496, 410, 599)
hsn VARCHAR(50) HSN/tax code (e.g. '4202' for bags, '6405' for footwear, '61112000' for kids clothes, '5407' for sarees)
size_dimension VARCHAR(45) Physical dimensions or clothing size (e.g. 'W-24 X H-17.5 X B-6', '4, 5, 6, 7', '26, 28, 30')
weight DECIMAL(11,2) Weight in grams
color VARCHAR(45) Hex color code (e.g. '#000000' for black, '#ff0000' for red, '#1b1b18' for near-black)
packingtype VARCHAR(30) Packing type: 'PCS', 'Box', 'Other'
packingtime INT Days to pack
tax_p DOUBLE Tax percentage (18 for bags/footwear, 5 for clothing)
reorder_qty INT Reorder quantity threshold
Qty INT Current stock quantity (filter Qty >= 0 for in-stock)
stock_movement VARCHAR(50) Always 'FIFO'
saleprice DOUBLE Actual selling price in INR (usually mrp/2; use this as the display price)
dis_p DOUBLE Discount percentage (0–100; if > 0 item is on discount)
description TEXT Product description
status VARCHAR(20) 'Active' or 'Inactive' — always filter WHERE status = 'Active'
ongoing_offer VARCHAR(11) 'yes' or 'no'
discount_percentage VARCHAR(11) Additional discount label
Known category_id values (approximate — use LIKE on item_name for category filtering):
1 = Girls' Frocks / Baby Frocks
3 = Sarees (chiffon, dola silk)
9 = Men's Footwear (slippers)
10 = Women's Footwear (sandals, slippers, heels, flip-flops)
11 = Kids' Sets / Combo (boy/girl outfit sets, kids' purses)
13 = Bags (backpacks, shoulder bags, sling bags, combo bags)
19 = Shoulder Bags / Sling Bags (women)
20 = Ladies Purses / Clutches / Wallets
21 = Mobile Side Bags / Cross Body Bags
22 = Sling Bags / Cross Body Bags
23 = Handbags / Clutches / Ladies Sling Bags
Color notes: color is stored as hex (#000000 = Black, #ffffff = White).
For color search by name, use item_name LIKE '%COLOR%' (e.g. item_name LIKE '%BLACK%').
Table: product_images (product photos)
Columns:
image_id INT Primary key
product_id INT FK → item_master.id
path_url VARCHAR(100) Filename (e.g. '83_1769753521_01.avif') — prepend your base image URL
default_img VARCHAR(1) 'y' = primary/default image, 'n' = additional image
img_seq INT Display order (1 = first shown)
To get the images for a product:
LEFT JOIN product_images ON item_master.id = product_images.product_id
Table: brands (brand information)
Columns:
BrandID INT Primary key
Brand VARCHAR(100) Brand name (e.g. 'Nike', 'Adidas', 'Reebok')
Table: category (product categories)
Columns:
category_id INT Primary key
category_name VARCHAR(100) Category name (e.g. 'Men's Footwear', '')
Table: subcategory (product subcategories)
Columns:
subcategoryid INT Primary key
subcategory VARCHAR(100) Subcategory name (e.g. 'Slippers', 'Sarees', 'Kids Sets')
Table: vendor (vendor information)
Columns:
ID INT Primary key
Name VARCHAR(100) Vendor name (e.g. 'Vendor A', 'Vendor B')
"""
# ── System prompts ────────────────────────────────────────────────────────────
# Prompt for step 1 (SQL generation).  Every column it names exists in
# DB_SCHEMA; aliases keep the result keys (name, price, stock, ...) that the
# earlier wording asked for.  The earlier wording also requested gender and
# rating columns, which the schema does not define, and told the model both to
# filter on stock and not to use Qty in WHERE; only the latter rule is kept.
SQL_SYSTEM_PROMPT = f"""You are an expert MySQL query generator for an online fashion store.
DATABASE SCHEMA:
{DB_SCHEMA}
Your task:
- Analyse the user's natural language product search request.
- Generate a single valid MySQL SELECT query that retrieves matching products.
- Use ONLY the tables and columns listed in the schema above. There is no gender column and no rating column.
- ALWAYS select these columns with these aliases: item_master.id AS id, item_master.item_name AS name, category.category_name AS category, brands.Brand AS brand, item_master.color AS color, item_master.size_dimension AS size, item_master.saleprice AS price, item_master.mrp AS mrp, item_master.dis_p AS discount_pct, item_master.Qty AS stock, item_master.description AS description
- Join lookup tables with LEFT JOIN category ON item_master.category_id = category.category_id and LEFT JOIN brands ON item_master.BrandID = brands.BrandID.
- Apply WHERE filters based on what the user asked (price, color, size, category, brand, etc.)
- Always filter item_master.status = 'Active'.
- Use LIKE for partial text matches on item_name, category_name, Brand.
- For price constraints compare item_master.saleprice: "within 500", "under 500", "below 500" → saleprice <= 500; "above 500" → saleprice >= 500; "between X and Y" → saleprice BETWEEN X AND Y
- For size: size_dimension may hold a list such as '26, 28, 30', so match with LIKE (e.g. size_dimension LIKE '%28%'). Map "extra large" → 'XL', etc.
- For color: the color column stores hex codes, so search by color name with item_name LIKE '%BLUE%'.
- For product image: use LEFT JOIN product_images ON item_master.id = product_images.product_id AND product_images.default_img = 'y' and select product_images.path_url AS image.
- Add ORDER BY item_master.saleprice ASC to surface best value first.
- LIMIT results to 10 rows maximum.
- Return ONLY the raw SQL query — no markdown fences, no explanation, no preamble. Just the SQL.
Rules:
- Never use DROP, DELETE, INSERT, UPDATE, ALTER, TRUNCATE, or any DML/DDL.
- Only SELECT from the item_master table and joined tables.
- If the user's request is vague (e.g. "I want a nice dress"), generate a broad query that returns relevant items (e.g. SELECT ... WHERE category.category_name LIKE '%dress%' OR item_master.item_name LIKE '%dress%' ORDER BY item_master.saleprice ASC LIMIT 10).
- Don't include Qty column of item_master in WHERE clause.
- If the request is ambiguous, generate a broad but relevant query.
"""
# Prompt for step 3 (turning product rows into a conversational reply).
# The currency sign, bullet and dashes are real characters so the model is
# not instructed to emit garbled text.
RESPONSE_SYSTEM_PROMPT = """You are a helpful, friendly online fashion store assistant named "ShopBot".
Your task:
- You receive the user's original search query and a JSON array of matching products from the database.
- Respond in a warm, conversational tone — like a knowledgeable sales assistant.
- Summarize what was found, highlight 2–3 standout products with specific details (name, price, color, size, rating).
- Only quote details that are present in the product data; never invent values for missing fields.
- If no products are found, suggest alternatives or ask clarifying questions.
- Keep responses concise (3–5 sentences max unless listing products).
- Format product listings clearly: use "•" bullet style.
- Always mention price in ₹ (Indian Rupees).
- Do NOT mention SQL, databases, or internal systems to the user.
- End with a helpful follow-up question or offer (e.g., "Want me to filter by a specific color?").
"""
# ── Request / Response models ─────────────────────────────────────────────────
class SearchRequest(BaseModel):
    """Request body for a natural-language product search."""

    # The user's search text.
    message: str
    # Optional prior turns, passed on to the reply-generation step.
    conversation_history: list = []  # list of {"role": "user"|"assistant", "content": "..."}
class SearchResponse(BaseModel):
    """Response body returned by the product search handler."""

    # Natural-language answer produced by the reply-generation step.
    reply: str
    # Product rows as returned by the DB layer.
    products: list
    # The SQL statement that was generated and executed for this request.
    generated_sql: str
    # Row count reported by the DB layer.
    row_count: int
# ── Helpers ───────────────────────────────────────────────────────────────────
# Matches single- or double-quoted SQL string literals (with '' / \' escapes).
_SQL_LITERAL_RE = re.compile(
    r"'(?:[^'\\]|\\.|'')*'|\"(?:[^\"\\]|\\.|\"\")*\"",
    re.DOTALL,
)

# Statements/clauses that must never reach the database from this service.
_SQL_FORBIDDEN_RE = re.compile(
    r"\b(?:INSERT|UPDATE|DELETE|DROP|ALTER|TRUNCATE|CREATE|GRANT|REVOKE)\b"
    r"|\bINTO\s+(?:OUTFILE|DUMPFILE)\b",
    re.IGNORECASE,
)


def _strip_sql_literals(sql: str) -> str:
    """Return *sql* with quoted string literals replaced by empty literals."""
    return _SQL_LITERAL_RE.sub("''", sql)


def generate_sql(user_message: str) -> str:
    """Step 1: Ask NVIDIA Llama to generate SQL from the user's natural language query.

    Args:
        user_message: The user's search text.

    Returns:
        A single SELECT statement without a trailing semicolon.

    Raises:
        HTTPException: 502 if the LLM call fails, if the output is not a
            SELECT statement, or if it contains stacked statements or
            write/DDL keywords outside of string literals.
    """
    try:
        response = client.chat.completions.create(
            model=LLM_MODEL,
            messages=[
                {"role": "system", "content": SQL_SYSTEM_PROMPT},
                {"role": "user", "content": f"User search query: {user_message}\n\nGenerate the SQL SELECT query."},
            ],
            temperature=0.2,
            top_p=0.7,
            max_tokens=512,
        )
        # content may be None; treat that as an empty answer.
        raw = (response.choices[0].message.content or "").strip()
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"NVIDIA SQL-gen error: {str(e)}") from e
    print("SQL-gen response:", raw)

    # If the model wrapped the query in a markdown fence (possibly with a
    # preamble around it), keep only the fenced content.
    fenced = re.search(r"```(?:sql)?\s*(.*?)```", raw, flags=re.IGNORECASE | re.DOTALL)
    if fenced:
        raw = fenced.group(1)
    # Remove any remaining (e.g. unterminated) fence markers.
    raw = re.sub(r"```(?:sql)?", "", raw, flags=re.IGNORECASE)
    sql = raw.strip().rstrip(";").strip()

    if not sql.upper().startswith("SELECT"):
        raise HTTPException(
            status_code=502,
            detail=f"SQL generator did not return a SELECT statement. Got: {sql[:200]}",
        )

    # The query text is derived from untrusted user input, so a leading SELECT
    # is not enough.  Literals are blanked first so that search terms such as
    # '%drop shoulder%' do not trigger a false rejection.
    checked = _strip_sql_literals(sql)
    if ";" in checked or _SQL_FORBIDDEN_RE.search(checked):
        raise HTTPException(
            status_code=502,
            detail="SQL generator returned a statement that is not a single read-only SELECT.",
        )

    print("Generated SQL:", sql)
    return sql
def execute_sql(sql: str) -> dict:
    """Step 2: Send SQL to the PHP DB layer and get results.

    On success the result is also written, best-effort, to a timestamped JSON
    file in OUTPUT_DIR; a failure to write is logged and does not fail the
    request.

    Args:
        sql: The SQL statement to execute.

    Returns:
        The decoded JSON object returned by the PHP DB layer.

    Raises:
        HTTPException: 502 if the DB layer is unreachable, answers with an
            HTTP error, or returns something that is not a JSON object;
            422 if the DB layer reports ``success`` as false.
    """
    try:
        resp = requests.post(
            PHP_DB_URL,
            headers=PHP_HEADERS,
            json={"sql": sql},
            timeout=15,
        )
        resp.raise_for_status()
        # Decoding is inside the try so a non-JSON body is reported as 502
        # instead of surfacing as an unhandled error.
        result = resp.json()
    except requests.exceptions.RequestException as e:
        raise HTTPException(status_code=502, detail=f"PHP DB layer error: {str(e)}") from e
    except ValueError as e:
        raise HTTPException(status_code=502, detail="PHP DB layer returned a non-JSON response") from e

    if not isinstance(result, dict):
        raise HTTPException(status_code=502, detail="PHP DB layer returned an unexpected payload")
    print("DB result:", result)

    if not result.get("success", False):
        raise HTTPException(
            status_code=422,
            detail=f"Query failed: {result.get('error', 'Unknown DB error')}",
        )

    # Save results to file with timestamp.  Microseconds are included so two
    # queries finishing within the same second do not overwrite each other.
    now = datetime.now()
    stamp = now.strftime("%Y%m%d_%H%M%S_%f")
    filename = OUTPUT_DIR / f"sql_result_{stamp}.json"
    output_data = {
        "timestamp": now.isoformat(),
        "sql_query": sql,
        "results": result.get("results", []),
        "row_count": result.get("row_count", 0),
        # Only successful results reach this point.
        "execution_status": "success",
    }
    try:
        with open(filename, "w", encoding="utf-8") as f:
            json.dump(output_data, f, ensure_ascii=False, indent=2)
        print(f"✓ Results saved to: {filename}")
    except (OSError, TypeError, ValueError) as e:
        # Persisting the result is best-effort only.
        print(f"⚠️ Failed to save results to file: {str(e)}")
    return result
def generate_reply(user_message: str, products: list, conversation_history: list) -> str:
    """Step 3: Ask NVIDIA Llama to turn the query results into a friendly response.

    Args:
        user_message: The user's original search text.
        products: Product rows returned by the DB layer.
        conversation_history: Prior turns supplied by the client.  Only
            entries shaped like {"role": "user"|"assistant", "content": str}
            are forwarded to the model; everything else is dropped.

    Returns:
        The model's reply text.

    Raises:
        HTTPException: 502 if the LLM call fails or returns an empty reply.
    """
    products_json = json.dumps(products, ensure_ascii=False, indent=2)

    # Build messages: system + history + current turn
    messages = [{"role": "system", "content": RESPONSE_SYSTEM_PROMPT}]

    # The history comes straight from the request body, so it is filtered:
    # a client must not be able to inject extra "system" turns, and malformed
    # entries would otherwise make the upstream API reject the whole call.
    for turn in conversation_history or []:
        if not isinstance(turn, dict):
            continue
        role = turn.get("role")
        content = turn.get("content")
        if role in ("user", "assistant") and isinstance(content, str) and content.strip():
            messages.append({"role": role, "content": content})

    # Add current user query
    messages.append({
        "role": "user",
        "content": (
            f"User query: {user_message}\n\n"
            f"Product search results ({len(products)} items found):\n{products_json}\n\n"
            f"Please give a helpful, friendly response to the user."
        ),
    })

    try:
        response = client.chat.completions.create(
            model=LLM_MODEL,
            messages=messages,
            temperature=0.2,
            top_p=0.7,
            max_tokens=1024,
        )
        # content may be None; treat that as an empty reply.
        reply = (response.choices[0].message.content or "").strip()
    except Exception as e:
        raise HTTPException(status_code=502, detail=f"NVIDIA reply-gen error: {str(e)}") from e
    print("Reply response:", reply)

    if not reply:
        raise HTTPException(status_code=502, detail="LLM returned an empty reply")
    return reply
# ── Endpoint ──────────────────────────────────────────────────────────────────
# NOTE(review): no route decorator is visible on this handler in this view of
# the file — confirm it is registered on `app` with the path the front end uses.
async def product_search(req: SearchRequest):
    """
    Natural-language product search endpoint.

    Pipeline:
      1. user message → NVIDIA Llama (SQL generator) → SQL SELECT statement
      2. SQL → PHP DB layer → JSON product rows
      3. products JSON + user message → NVIDIA Llama (ShopBot) → friendly reply

    Body:
      - message: the user's search text (e.g. "I need a blue XXL shirt under 500")
      - conversation_history: optional list of prior turns for context

    Raises:
      HTTPException: 400 if the message is blank; errors raised by the three
        pipeline steps propagate unchanged.
    """
    # Imported here so this handler does not rely on edits to the import block.
    from fastapi.concurrency import run_in_threadpool

    if not req.message.strip():
        raise HTTPException(status_code=400, detail="Message cannot be empty")

    # The three steps perform blocking network I/O.  Running them in the
    # thread pool keeps the event loop free to serve other requests.

    # Step 1: Generate SQL
    sql = await run_in_threadpool(generate_sql, req.message)

    # Step 2: Execute against DB
    db_result = await run_in_threadpool(execute_sql, sql)
    # Guard against an explicit null from the DB layer.
    products = db_result.get("results") or []
    row_count = db_result.get("row_count") or 0

    # Step 3: Generate friendly response
    reply = await run_in_threadpool(
        generate_reply, req.message, products, req.conversation_history
    )

    return SearchResponse(
        reply=reply,
        products=products,
        generated_sql=sql,
        row_count=row_count,
    )
# NOTE(review): no route decorator is visible on this handler in this view of the file.
async def health():
    """Report a fixed "healthy" status together with the configured model and provider."""
    payload = {
        "status": "healthy",
        "model": LLM_MODEL,
        "provider": "NVIDIA",
    }
    return payload
# NOTE(review): no route decorator is visible on this handler in this view of the file.
async def list_results():
    """List saved SQL query results.

    Result files in OUTPUT_DIR are sorted by filename in descending order and
    at most the first 20 are described (name, path, modification time).

    Raises:
        HTTPException: 500 if the directory cannot be listed or a file
            cannot be inspected.
    """
    try:
        ordered = sorted(OUTPUT_DIR.glob("sql_result_*.json"), reverse=True)
        entries = [
            {
                "filename": item.name,
                "path": str(item),
                "created": item.stat().st_mtime,
            }
            for item in ordered[:20]  # Last 20 results
        ]
        return {"status": "success", "results": entries, "total": len(entries)}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to list results: {str(e)}")
# NOTE(review): no route decorator is visible on this handler in this view of the file.
async def get_result(filename: str):
    """Retrieve a specific saved SQL query result.

    Args:
        filename: Bare name of a ``.json`` file inside OUTPUT_DIR.

    Returns:
        A FileResponse streaming the stored JSON file.

    Raises:
        HTTPException: 404 if the name is not a bare file name, does not end
            in ``.json``, or does not refer to a regular file directly inside
            OUTPUT_DIR; 500 if the response cannot be created.
    """
    # `filename` is caller-controlled.  Joining an absolute path or one with
    # ".." segments onto OUTPUT_DIR would escape the results directory, so only
    # a bare name is accepted and the resolved location is verified as well.
    if Path(filename).name != filename:
        raise HTTPException(status_code=404, detail="Result file not found")

    base_dir = OUTPUT_DIR.resolve()
    filepath = (base_dir / filename).resolve()
    if (
        filepath.parent != base_dir
        or filepath.suffix != ".json"
        or not filepath.is_file()
    ):
        raise HTTPException(status_code=404, detail="Result file not found")

    try:
        return FileResponse(filepath, media_type="application/json", filename=filename)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Failed to retrieve result: {str(e)}") from e
# NOTE(review): no route decorator is visible on this handler in this view of the file.
async def root():
    """Return the file ``index.html`` as the response.

    The path is relative, so it is looked up from the process's current
    working directory.
    """
    return FileResponse("index.html")
if __name__ == "__main__":
    # Start a uvicorn server when this file is executed directly.
    # uvicorn is imported here so it is only needed for this entry point.
    import uvicorn
    # "app:app" is an import string: the `app` object in a module named `app`.
    # Listens on all interfaces, port 8000, with auto-reload enabled.
    uvicorn.run("app:app", host="0.0.0.0", port=8000, reload=True)