salim4n commited on
Commit
2bbd95d
·
verified ·
1 Parent(s): 3ddb952

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +72 -0
app.py CHANGED
@@ -5,6 +5,11 @@ from smolagents import CodeAgent, OpenAIServerModel
5
  import os
6
  from dotenv import load_dotenv
7
  from sql_data import sql_query, get_schema, get_csv_as_dataframe
 
 
 
 
 
8
 
9
  # Load environment variables
10
  load_dotenv()
@@ -20,6 +25,73 @@ model = OpenAIServerModel(model_id=model_id, api_key=os.environ["OPENAI_API_KEY"
20
 
21
  agent = CodeAgent(tools=[sql_query, get_schema, get_csv_as_dataframe], model=model)
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  def run_agent(question: str) -> str:
24
  """
25
  Run the agent with the given question.
 
5
  import os
6
  from dotenv import load_dotenv
7
  from sql_data import sql_query, get_schema, get_csv_as_dataframe
8
+ import pandas as pd
9
+ from sqlalchemy import create_engine, text
10
+
11
+ # Create database engine
12
+ engine = create_engine("sqlite:///freights.db")
13
 
14
  # Load environment variables
15
  load_dotenv()
 
25
 
26
  agent = CodeAgent(tools=[sql_query, get_schema, get_csv_as_dataframe], model=model)
27
 
28
+ def sql_query(query: str) -> str:
29
+ """
30
+ Allows you to perform SQL queries on the freights table. Returns a string representation of the result.
31
+ The table is named 'freights'. Its description is as follows:
32
+ Columns:
33
+ - departure: DateTime (Date and time of departure)
34
+ - origin_port_locode: String (Origin port code)
35
+ - origin_port_name: String (Name of the origin port)
36
+ - destination_port: String (Destination port code)
37
+ - destination_port_name: String (Name of the destination port)
38
+ - dv20rate: Float (Rate for 20ft container in USD)
39
+ - dv40rate: Float (Rate for 40ft container in USD)
40
+ - currency: String (Currency of the rates)
41
+ - inserted_on: DateTime (Date when the rate was inserted)
42
+ Args:
43
+ query: The query to perform. This should be correct SQL.
44
+ Returns:
45
+ A string representation of the result of the query.
46
+ """
47
+ try:
48
+ with engine.connect() as con:
49
+ result = con.execute(text(query))
50
+ rows = [dict(row._mapping) for row in result]
51
+
52
+ if not rows:
53
+ return "Aucun résultat trouvé."
54
+
55
+ # Convert to markdown table
56
+ headers = list(rows[0].keys())
57
+ table = "| " + " | ".join(headers) + " |\n"
58
+ table += "| " + " | ".join(["---" for _ in headers]) + " |\n"
59
+
60
+ for row in rows:
61
+ table += "| " + " | ".join(str(row[h]) for h in headers) + " |\n"
62
+
63
+ return table
64
+
65
+ except Exception as e:
66
+ return f"Error executing query: {str(e)}"
67
+
68
+
69
+ def get_schema() -> str:
70
+ """
71
+ Returns the schema of the freights table.
72
+ """
73
+ return """
74
+ Table: freights
75
+ Columns:
76
+ - departure: DateTime (Date and time of departure)
77
+ - origin_port_locode: String (Origin port code)
78
+ - origin_port_name: String (Name of the origin port)
79
+ - destination_port: String (Destination port code)
80
+ - destination_port_name: String (Name of the destination port)
81
+ - dv20rate: Float (Rate for 20ft container in USD)
82
+ - dv40rate: Float (Rate for 40ft container in USD)
83
+ - currency: String (Currency of the rates)
84
+ - inserted_on: DateTime (Date when the rate was inserted)
85
+ """
86
+
87
+
88
+ def get_csv_as_dataframe() -> str:
89
+ """
90
+ Returns a string representation of the freights table as a CSV file.
91
+ """
92
+ df = pd.read_sql_table("freights", engine)
93
+ return df.to_csv(index=False)
94
+
95
  def run_agent(question: str) -> str:
96
  """
97
  Run the agent with the given question.