Narayana02 commited on
Commit
4de92a1
·
verified ·
1 Parent(s): d18ee6a

Update fashion_query.py

Browse files
Files changed (1) hide show
  1. fashion_query.py +41 -39
fashion_query.py CHANGED
@@ -1,10 +1,11 @@
1
- import os
2
  import pandas as pd
3
  import random
 
4
  from transformers import DistilBertTokenizer, DistilBertModel
5
  from huggingface_hub import login
 
6
 
7
- # Check if the token is available
8
  HUGGINGFACE_TOKEN = os.getenv('HUGGINGFACE_TOKEN')
9
  if not HUGGINGFACE_TOKEN:
10
  raise EnvironmentError("HUGGINGFACE_TOKEN environment variable is not set.")
@@ -17,7 +18,16 @@ model_name = "distilbert-base-uncased"
17
  tokenizer = DistilBertTokenizer.from_pretrained(model_name, use_auth_token=HUGGINGFACE_TOKEN)
18
  model = DistilBertModel.from_pretrained(model_name, use_auth_token=HUGGINGFACE_TOKEN)
19
 
20
- # Load the fashion dataset
 
 
 
 
 
 
 
 
 
21
  def load_fashion_dataset():
22
  try:
23
  fashion_df = pd.read_csv('fashion.csv')
@@ -31,41 +41,28 @@ def load_fashion_dataset():
31
  except Exception as e:
32
  raise Exception(f"Error loading fashion dataset: {e}")
33
 
34
- # Define different response styles
35
- fashion_response_templates = [
36
- lambda row: (
37
- f"The {row['ProductName']} is perfect for a stylish {row['Category']}. "
38
- f"Available for ${row['Price']}, it's known for its {row['Description']}. "
39
- f"Would you like more details on the {row['ProductName']}?"
40
- ),
41
- lambda row: (
42
- f"Discover the {row['ProductName']}! Priced at ${row['Price']} and famous for its {row['Description']}. "
43
- f"What are your thoughts on this {row['Category']}?"
44
- ),
45
- # Add more templates as needed
46
- ]
47
-
48
- # Generate a dynamic response
49
  def generate_fashion_response(row):
50
- template = random.choice(fashion_response_templates)
 
 
 
 
 
 
 
51
  return template(row)
52
 
53
- # Extract filters from the query
54
  def extract_fashion_filters(query):
55
  filters = {}
56
  query_lower = query.lower()
57
-
58
  if 'best' in query_lower and 'rating' in query_lower:
59
  filters['Rating'] = 'max'
60
-
61
  if 'dresses' in query_lower:
62
  filters['Category'] = 'dress'
63
  elif 'shoes' in query_lower:
64
  filters['Category'] = 'shoes'
65
-
66
  return filters
67
 
68
- # Apply filters to the fashion DataFrame
69
  def apply_fashion_filters(df, filters):
70
  for key, value in filters.items():
71
  if key == 'Rating' and value == 'max':
@@ -74,30 +71,35 @@ def apply_fashion_filters(df, filters):
74
  df = df[df[key].str.contains(value, case=False, na=False)]
75
  return df
76
 
77
- # Query fashion based on user input
78
  def query_fashion(user_query, n_results=5):
79
  fashion_df = load_fashion_dataset()
80
  filtered_df = apply_fashion_filters(fashion_df, extract_fashion_filters(user_query))
81
-
82
- # Check if 'Rating' column exists before sorting
83
  if 'Rating' in filtered_df.columns:
84
  sorted_df = filtered_df.sort_values(by='Rating', ascending=False)
85
  else:
86
- sorted_df = filtered_df # Skip sorting if 'Rating' is not available
87
-
88
- # Return the top N results
89
  return sorted_df.head(n_results)
90
 
91
- # Function to integrate the chatbot responses
92
- def fashion_chatbot_response(user_query):
93
- # Query the fashion data
94
- fashion_results = query_fashion(user_query)
 
 
 
 
 
 
 
 
 
95
 
 
 
96
  if not fashion_results.empty:
97
- response = ""
98
  for _, row in fashion_results.iterrows():
99
- response += generate_fashion_response(row) + "\n"
100
- return response
 
101
 
102
- # If no results, generate a fallback response
103
- return "Sorry, I couldn't find any fashion items matching your query."
 
 
1
  import pandas as pd
2
  import random
3
+ import os
4
  from transformers import DistilBertTokenizer, DistilBertModel
5
  from huggingface_hub import login
6
+ import torch
7
 
8
+ # Load environment variables
9
  HUGGINGFACE_TOKEN = os.getenv('HUGGINGFACE_TOKEN')
10
  if not HUGGINGFACE_TOKEN:
11
  raise EnvironmentError("HUGGINGFACE_TOKEN environment variable is not set.")
 
18
  tokenizer = DistilBertTokenizer.from_pretrained(model_name, use_auth_token=HUGGINGFACE_TOKEN)
19
  model = DistilBertModel.from_pretrained(model_name, use_auth_token=HUGGINGFACE_TOKEN)
20
 
21
+ def predict_class(query, system_message):
22
+ """Predict the class of the query based on the provided system message."""
23
+ inputs = tokenizer(system_message + " " + query, return_tensors='pt')
24
+ with torch.no_grad():
25
+ outputs = model(**inputs)
26
+ logits = outputs.last_hidden_state[:, 0, :] # Using [CLS] token's embedding
27
+ probabilities = torch.nn.functional.softmax(logits, dim=-1)
28
+ predicted_class = torch.argmax(probabilities).item()
29
+ return predicted_class
30
+
31
  def load_fashion_dataset():
32
  try:
33
  fashion_df = pd.read_csv('fashion.csv')
 
41
  except Exception as e:
42
  raise Exception(f"Error loading fashion dataset: {e}")
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
  def generate_fashion_response(row):
45
+ templates = [
46
+ lambda r: (f"The {r['ProductName']} is perfect for a stylish {r['Category']}. "
47
+ f"Available for ${r['Price']}, it's known for its {r['Description']}. "
48
+ f"Would you like more details on the {r['ProductName']}?"),
49
+ lambda r: (f"Discover the {r['ProductName']}! Priced at ${r['Price']} and famous for its {r['Description']}. "
50
+ f"What are your thoughts on this {r['Category']}?"),
51
+ ]
52
+ template = random.choice(templates)
53
  return template(row)
54
 
 
55
  def extract_fashion_filters(query):
56
  filters = {}
57
  query_lower = query.lower()
 
58
  if 'best' in query_lower and 'rating' in query_lower:
59
  filters['Rating'] = 'max'
 
60
  if 'dresses' in query_lower:
61
  filters['Category'] = 'dress'
62
  elif 'shoes' in query_lower:
63
  filters['Category'] = 'shoes'
 
64
  return filters
65
 
 
66
  def apply_fashion_filters(df, filters):
67
  for key, value in filters.items():
68
  if key == 'Rating' and value == 'max':
 
71
  df = df[df[key].str.contains(value, case=False, na=False)]
72
  return df
73
 
 
74
  def query_fashion(user_query, n_results=5):
75
  fashion_df = load_fashion_dataset()
76
  filtered_df = apply_fashion_filters(fashion_df, extract_fashion_filters(user_query))
 
 
77
  if 'Rating' in filtered_df.columns:
78
  sorted_df = filtered_df.sort_values(by='Rating', ascending=False)
79
  else:
80
+ sorted_df = filtered_df
 
 
81
  return sorted_df.head(n_results)
82
 
83
+ def fashion_agent_response(user_query):
84
+ system_message = """You are a knowledgeable fashion agent. Your responsibilities include:
85
+ 1. Handling all fashion-related queries.
86
+ 2. Providing information about clothing, accessories, and trends.
87
+ 3. Assisting with product details, styles, and prices.
88
+ 4. Offering fashion recommendations based on user preferences."""
89
+
90
+ predicted_class = predict_class(user_query, system_message)
91
+ responses = {
92
+ 0: "Information about fashion trends.",
93
+ 1: "Details about clothing styles and prices.",
94
+ 2: "Assistance with fashion recommendations."
95
+ }
96
 
97
+ response = responses.get(predicted_class, "I am not sure how to help with that.")
98
+ fashion_results = query_fashion(user_query)
99
  if not fashion_results.empty:
 
100
  for _, row in fashion_results.iterrows():
101
+ response += "\n" + generate_fashion_response(row)
102
+ else:
103
+ response += "\nSorry, I couldn't find any fashion items matching your query."
104
 
105
+ return response