SHAMIL SHAHBAZ AWAN commited on
Commit
559d037
·
verified ·
1 Parent(s): 9dd7961

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -0
app.py CHANGED
@@ -4,6 +4,10 @@ import matplotlib.pyplot as plt
4
  import seaborn as sns
5
  import os
6
  from io import StringIO
 
 
 
 
7
 
8
  # Function to load the uploaded file (CSV or Excel)
9
  def load_file(uploaded_file):
@@ -29,6 +33,14 @@ def infer_column(data, synonyms):
29
  return column
30
  return None
31
 
 
 
 
 
 
 
 
 
32
  # Function to generate graph based on user query
33
  def generate_graph(data, query):
34
  """Generate a graph based on user query."""
@@ -40,6 +52,8 @@ def generate_graph(data, query):
40
  sales_col = infer_column(data, {"gross_sales", "sales", "revenue"})
41
  date_col = infer_column(data, {"date", "time"})
42
 
 
 
43
  if "bar" in query.lower() and country_col and sales_col:
44
  # Bar chart for countries and gross sales
45
  country_data = data[[country_col, sales_col]].groupby(country_col).sum().reset_index()
 
4
  import seaborn as sns
5
  import os
6
  from io import StringIO
7
+ from transformers import pipeline
8
+
9
+ # Load a lightweight NLP model for query understanding
10
+ nlp = pipeline("text-classification", model="distilbert-base-uncased", tokenizer="distilbert-base-uncased")
11
 
12
  # Function to load the uploaded file (CSV or Excel)
13
  def load_file(uploaded_file):
 
33
  return column
34
  return None
35
 
36
+ # Function to classify the user query
37
+ def classify_query(query):
38
+ """Classify the user query into graph types."""
39
+ results = nlp(query)
40
+ if results:
41
+ return results[0]['label']
42
+ return None
43
+
44
  # Function to generate graph based on user query
45
  def generate_graph(data, query):
46
  """Generate a graph based on user query."""
 
52
  sales_col = infer_column(data, {"gross_sales", "sales", "revenue"})
53
  date_col = infer_column(data, {"date", "time"})
54
 
55
+ query_type = classify_query(query)
56
+
57
  if "bar" in query.lower() and country_col and sales_col:
58
  # Bar chart for countries and gross sales
59
  country_data = data[[country_col, sales_col]].groupby(country_col).sum().reset_index()