Loren commited on
Commit
8288368
·
verified ·
1 Parent(s): e7896b3

Update app/database.py

Browse files
Files changed (1) hide show
  1. app/database.py +60 -56
app/database.py CHANGED
@@ -1,56 +1,60 @@
1
- import sqlite3
2
- from typing import List, Dict
3
- import os
4
- from huggingface_hub import hf_hub_download
5
-
6
- # Télécharger le fichier SQLite depuis le dataset
7
- REPO_ID = "Loren/articles_db" # dataset HF
8
- DB_NAME = 'articles.db'
9
- hf_token = os.environ["API_HF_TOKEN"]
10
- sqlite_path = hf_hub_download(
11
- repo_id=REPO_ID,
12
- filename=DB_NAME,
13
- repo_type="dataset",
14
- token=hf_token
15
- )
16
-
17
- def get_connection(sqlite_path):
18
- conn = sqlite3.connect(sqlite_path)
19
- conn.row_factory = sqlite3.Row
20
- return conn
21
-
22
- def fetch_tags() -> List[str]:
23
- """Retourne tous les tags"""
24
- conn = get_connection()
25
- cur = conn.cursor()
26
- cur.execute("SELECT tag_name FROM tags ORDER BY tag_name")
27
- tags = [row["tag_name"] for row in cur.fetchall()]
28
- conn.close()
29
- return tags
30
-
31
-
32
- def fetch_articles_by_tags(tags: List[str]) -> List[Dict]:
33
- """
34
- Retourne les articles correspondant aux tags.
35
- """
36
- if not tags:
37
- return []
38
-
39
- conn = get_connection()
40
- conn.row_factory = sqlite3.Row
41
- cur = conn.cursor()
42
-
43
- # Créer la liste de placeholders "?" dynamiquement
44
- placeholders = ",".join(["?"] * len(tags))
45
-
46
- query = ("""SELECT a.article_id, a.article_title, a.article_url
47
- FROM tags t, articles a, tag_article ta
48
- WHERE ta.tag_id = t.tag_id
49
- AND ta.article_id = a.article_id
50
- AND t.tag_name IN (""" + placeholders + """)"""
51
- )
52
-
53
- cur.execute(query, tags)
54
- results = [dict(row) for row in cur.fetchall()]
55
- conn.close()
56
- return results
 
 
 
 
 
1
+ import sqlite3
2
+ from typing import List, Dict
3
+ import os
4
+ from huggingface_hub import hf_hub_download
5
+
6
+ # Télécharger le fichier SQLite depuis le dataset
7
+ # Créer un dossier temporaire pour le cache
8
+ cache_dir = "/tmp/hf_cache"
9
+ os.makedirs(cache_dir, exist_ok=True)
10
+ REPO_ID = "Loren/articles_db" # dataset HF
11
+ DB_NAME = 'articles.db'
12
+ hf_token = os.environ["API_HF_TOKEN"]
13
+ sqlite_path = hf_hub_download(
14
+ repo_id=REPO_ID,
15
+ filename=DB_NAME,
16
+ repo_type="dataset",
17
+ token=hf_token,
18
+ cache_dir=cache_dir
19
+ )
20
+
21
+ def get_connection(sqlite_path):
22
+ conn = sqlite3.connect(sqlite_path)
23
+ conn.row_factory = sqlite3.Row
24
+ return conn
25
+
26
+ def fetch_tags() -> List[str]:
27
+ """Retourne tous les tags"""
28
+ conn = get_connection()
29
+ cur = conn.cursor()
30
+ cur.execute("SELECT tag_name FROM tags ORDER BY tag_name")
31
+ tags = [row["tag_name"] for row in cur.fetchall()]
32
+ conn.close()
33
+ return tags
34
+
35
+
36
+ def fetch_articles_by_tags(tags: List[str]) -> List[Dict]:
37
+ """
38
+ Retourne les articles correspondant aux tags.
39
+ """
40
+ if not tags:
41
+ return []
42
+
43
+ conn = get_connection()
44
+ conn.row_factory = sqlite3.Row
45
+ cur = conn.cursor()
46
+
47
+ # Créer la liste de placeholders "?" dynamiquement
48
+ placeholders = ",".join(["?"] * len(tags))
49
+
50
+ query = ("""SELECT a.article_id, a.article_title, a.article_url
51
+ FROM tags t, articles a, tag_article ta
52
+ WHERE ta.tag_id = t.tag_id
53
+ AND ta.article_id = a.article_id
54
+ AND t.tag_name IN (""" + placeholders + """)"""
55
+ )
56
+
57
+ cur.execute(query, tags)
58
+ results = [dict(row) for row in cur.fetchall()]
59
+ conn.close()
60
+ return results