File size: 4,838 Bytes
45ecbbd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
"""NLP Parser - Extract structured search parameters from natural language."""
import json
from huggingface_hub import InferenceClient
from config import HF_TOKEN, LLM_MODEL


def parse_user_request(text):
    """
    Parse a natural language shopping request into structured parameters.

    Args:
        text: User's plain-English request. None, empty, or whitespace-only
            input short-circuits to an empty result without an API call.

    Returns:
        dict: {
            "searches": [{"query": str, "category": str|None,
                          "min_price": float|None, "max_price": float|None,
                          "sort_by": str|None, "brand": str|None,
                          "store": str|None}, ...],
            "requirements": [str, ...]
        }
        On any LLM or parsing failure, falls back to a single search that
        uses the (stripped) raw text as the query.
    """
    # Guard against None as well as blank input — the original
    # `text.strip()` would raise AttributeError on None.
    if not text or not text.strip():
        return {"searches": [], "requirements": []}

    system_prompt = """You are an expert shopping assistant parser. Given the user's natural language request, return JSON with two keys:

"searches": a list of objects, one per distinct product the user wants. Each object has:
  - query: search keywords (str, required)
  - category: one of [Electronics, Clothing & Apparel, Home & Garden, Health & Beauty, Sports & Outdoors, Toys & Games, Books & Media, Office & School, Food & Grocery, Auto & Parts] or null
  - min_price: number or null
  - max_price: number or null
  - sort_by: "relevance"|"price_low"|"price_high"|"rating"|null
  - brand: str or null
  - store: str or null

"requirements": a list of strings — specific criteria the user mentioned that go BEYOND standard filters. These are things you would need to read a product description or spec sheet to verify. Examples:
  - "espresso only — not drip or pour-over"
  - "manufactured in USA or Italy"
  - "burr grinder, not blade"
  - "BPA-free materials"
  - "compatible with K-cups"
  - "must have HDMI 2.1 port"
  - "vibration pump"
  - "water reservoir at least 1 liter"

Do NOT include price or brand here (those are already in the search object). Only include requirements that need spec-sheet verification.

Return ONLY valid JSON, no commentary."""

    user_message = f"User request: {text}"
    
    try:
        client = InferenceClient(token=HF_TOKEN)
        
        # Call the LLM (low temperature for more deterministic extraction)
        response = client.chat_completion(
            model=LLM_MODEL,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_message}
            ],
            max_tokens=1000,
            temperature=0.3,
        )
        
        # Extract the response text
        response_text = response.choices[0].message.content.strip()
        
        # The model may wrap JSON in commentary/fences; extract defensively
        parsed = _extract_json(response_text)
        
        # Normalize to the canonical shape, filling missing fields with None
        validated = _validate_response(parsed)
        
        print(f"NLP Parser extracted: {json.dumps(validated, indent=2)}")
        return validated
        
    except Exception as e:
        # Broad catch is deliberate: any failure (network, auth, malformed
        # LLM output) degrades to a plain keyword search instead of
        # crashing the caller.
        print(f"Error in NLP parsing: {e}")
        import traceback
        traceback.print_exc()
        
        # Fallback: treat the (stripped) text as a simple search query
        return {
            "searches": [{"query": text.strip(), "category": None, "min_price": None,
                         "max_price": None, "sort_by": None, "brand": None, "store": None}],
            "requirements": []
        }


def _extract_json(text):
    """Extract JSON from LLM response that might have extra text."""
    # Try to find JSON block
    start = text.find('{')
    end = text.rfind('}') + 1
    
    if start != -1 and end > start:
        json_str = text[start:end]
        try:
            return json.loads(json_str)
        except json.JSONDecodeError:
            pass
    
    # If that fails, try the whole text
    try:
        return json.loads(text)
    except json.JSONDecodeError:
        return {}


def _validate_response(data):
    """Validate and fill in missing fields."""
    if not isinstance(data, dict):
        return {"searches": [], "requirements": []}
    
    # Ensure searches is a list
    searches = data.get("searches", [])
    if not isinstance(searches, list):
        searches = []
    
    # Validate each search object
    validated_searches = []
    for search in searches:
        if not isinstance(search, dict):
            continue
        
        validated_searches.append({
            "query": search.get("query", ""),
            "category": search.get("category"),
            "min_price": search.get("min_price"),
            "max_price": search.get("max_price"),
            "sort_by": search.get("sort_by"),
            "brand": search.get("brand"),
            "store": search.get("store"),
        })
    
    # Ensure requirements is a list of strings
    requirements = data.get("requirements", [])
    if not isinstance(requirements, list):
        requirements = []
    
    requirements = [str(r) for r in requirements if r]
    
    return {
        "searches": validated_searches,
        "requirements": requirements
    }