salim4n commited on
Commit
8cb747c
·
verified ·
1 Parent(s): c3243b2

Upload 2 files

Browse files
Files changed (2) hide show
  1. tools.py +42 -0
  2. utils.py +79 -0
tools.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from smolagents import CodeAgent, HfApiModel
2
+ from sql_data import sql_query, get_schema
3
+ from sqlalchemy import create_engine, inspect, text
4
+ import os
5
+ from dotenv import load_dotenv
6
+ from typing import Dict, List, Any
7
+ import json
8
+
9
+ # Load environment variables
10
+ load_dotenv()
11
+
12
+ # Example queries that the agent can handle
13
+ EXAMPLE_QUERIES = [
14
+ "Quels sont les tarifs moyens des conteneurs 20ft et 40ft entre tous les ports ?",
15
+ "Quels sont les ports d'origine les plus fréquents ?",
16
+ "Montre-moi les routes avec des tarifs élevés pour les conteneurs 40ft",
17
+ "Quelle est l'évolution des prix au fil du temps pour la route Surabaya vers Nansha ?",
18
+ "Quelles sont les destinations disponibles depuis Shanghai ?",
19
+ ]
20
+
21
+
22
+ class FreightAgent:
23
+ def __init__(self):
24
+ self.setup_agent()
25
+
26
+ def setup_agent(self) -> None:
27
+ """
28
+ Initialize the CodeAgent with SQL tools.
29
+ Create a CodeAgent with two tools: `sql_query` and `get_schema`.
30
+ `sql_query` allows to perform SQL queries on the freights table.
31
+ `get_schema` returns the schema of the freights table.
32
+ """
33
+ self.agent = CodeAgent(
34
+ tools=[sql_query, get_schema],
35
+ model=HfApiModel("meta-llama/Llama-3.1-8B-Instruct"),
36
+ )
37
+
38
+ def query(self, question: str) -> str:
39
+ """
40
+ Ask a question about the freight data in natural language
41
+ """
42
+ return self.agent.run(question)
utils.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import sqlite3
3
+ import requests
4
+ from typing import List, Dict, Any
5
+ import os
6
+ from sqlalchemy import create_engine, Column, Float, String, Integer, DateTime
7
+ from sqlalchemy.ext.declarative import declarative_base
8
+ from sqlalchemy.orm import sessionmaker
9
+
10
+ # Create base class for declarative models
11
+ Base = declarative_base()
12
+
13
+
14
+ class Freight(Base):
15
+ """SQLAlchemy model for freight data"""
16
+
17
+ __tablename__ = "freights"
18
+
19
+ id = Column(Integer, primary_key=True)
20
+ departure = Column(DateTime)
21
+ origin_port_locode = Column(String)
22
+ origin_port_name = Column(String)
23
+ destination_port = Column(String)
24
+ destination_port_name = Column(String)
25
+ dv20rate = Column(Float)
26
+ dv40rate = Column(Float)
27
+ currency = Column(String)
28
+ inserted_on = Column(DateTime)
29
+
30
+
31
+ def download_csv(url: str, local_path: str = "freights.csv") -> str:
32
+ """
33
+ Download CSV file from Hugging Face and save it locally
34
+ """
35
+ response = requests.get(url)
36
+ with open(local_path, "wb") as f:
37
+ f.write(response.content)
38
+ return local_path
39
+
40
+
41
+ def create_database(db_name: str = "freights.db") -> None:
42
+ """
43
+ Create SQLite database and necessary tables
44
+ """
45
+ engine = create_engine(f"sqlite:///{db_name}")
46
+ Base.metadata.create_all(engine)
47
+
48
+
49
+ def load_csv_to_db(csv_path: str, db_name: str = "freights.db") -> None:
50
+ """
51
+ Load CSV data into SQLite database
52
+ """
53
+ # Read CSV
54
+ df = pd.read_csv(csv_path, parse_dates=["departure", "inserted_on"])
55
+
56
+ # Connect to database
57
+ engine = create_engine(f"sqlite:///{db_name}")
58
+
59
+ # Save to database
60
+ df.to_sql("freights", engine, if_exists="replace", index=False)
61
+
62
+
63
+ def initialize_database(csv_url: str) -> None:
64
+ """
65
+ Initialize the database by downloading CSV and loading data.
66
+ Args:
67
+ csv_url: URL of the CSV file to download and load.
68
+ """
69
+ # Download CSV
70
+ csv_path = download_csv(csv_url)
71
+
72
+ # Create and load database
73
+ create_database()
74
+ load_csv_to_db(csv_path)
75
+ print("Database initialized.")
76
+
77
+ # Clean up CSV file
78
+ if os.path.exists(csv_path):
79
+ os.remove(csv_path)