Spaces:
Sleeping
Sleeping
File size: 5,811 Bytes
fe53425 e17b15a fe53425 e17b15a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
from fastapi import FastAPI, Query
from typing import List, Dict, Any
from app import database
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
app = FastAPI(
title="Articles API",
description="API pour interroger la base articles",
version="1.0"
)
# CORS pour permettre l'accès depuis le navigateur
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # autorise toutes les origines
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
@app.get("/", response_class=HTMLResponse)
def home():
return """
<html>
<head><title>Page d'accueil</title></head>
<body>
<h1>Welcome on the API search articles !</h1>
</body>
</html>
"""
@app.get("/get_tags")
def get_tags():
"""
Récupère la liste de tous les tags disponibles via l'API.
Returns:
Dict: Un dictionnaire contenant soit la liste des tags, soit les informations d'erreur.
- Si succès :
{
"status": "ok",
"tags": List[str] # Liste des noms de tags triés par ordre alphabétique
}
- En cas d'erreur :
{
"status": "error",
"code": str, # Nom de l'exception
"message": str # Message de l'exception
}
Notes:
- L'appel de cet endpoint effectue un accès à la base de données via la fonction `fetch_tags`.
- En cas de problème avec la base de données, un message d'erreur détaillé est retourné.
"""
try:
dict_result = database.fetch_tags()
if dict_result["status"] == "ok":
return {"status": "ok", "tags": dict_result["result"]}
else:
return dict_result
except Exception as e:
return {"status": "error", "code": type(e).__name__, "message": str(e)}
@app.get("/get_articles_with_tags")
def get_articles_with_tags(tags: List[str] = Query(..., description="Liste des tags à filtrer")):
"""
Récupère les articles associés à une ou plusieurs tags spécifiés.
Args:
tags (List[str]): Liste des noms de tags utilisés pour filtrer les articles.
Doit contenir au moins un tag.
Returns:
Dict: Un dictionnaire contenant soit les articles correspondants, soit les informations d'erreur.
- Si succès :
{
"status": "ok",
"tags": List[str], # Tags utilisés pour filtrer
"articles": List[Dict] # Liste des articles correspondants
}
Chaque article est un dictionnaire contenant :
- 'article_id': int, ID de l'article
- 'article_title': str, Titre de l'article
- 'article_url': str, URL de l'article
- En cas d'erreur :
{
"status": "error",
"code": str, # Code d'erreur ou nom de l'exception
"message": str # Message d'erreur
}
Notes:
- Si la liste `tags` est vide, la fonction retourne une erreur avec le code "no_tags".
- L'appel de cet endpoint utilise la fonction `fetch_articles_by_tags` pour récupérer les articles.
"""
try:
dict_result = database.fetch_articles_by_tags(tags)
if dict_result["status"] == "ok":
return {"status": "ok",
"tags": tags,
"articles": dict_result["result"]}
else:
return dict_result
except Exception as e:
return {"status": "error", "code": type(e).__name__, "message": str(e)}
@app.get("/get_query_results")
def get_query_results(query: str = Query(..., description="Requête de recherche textuelle"),
k_model: int = Query(10, description="Nombre de candidats retournés par FAISS"),
k_cross: int = Query(5, description="Nombre de résultats conservés après reranking"),
use_rerank: bool = Query(True, description="Indique si le reranking avec cross-encoder doit être utilisé")
) -> Dict[str, Any]:
"""
Récupère les résultats d'une requête en utilisant deux modèles de recherche.
Args:
query (str): La requête utilisateur pour laquelle récupérer les résultats.
k_model (int, optionel): Nombre de résultats à retourner pour le modèle principal. Par défaut à 10.
k_cross (int, optionel): Nombre de résultats à retourner pour le modèle croisé. Par défaut à 5.
use_rerank (bool, optionnel): Indique si le reranking avec cross-encoder doit être utilisé. Par défaut à True.
Si False, on désactive complètement le cross-encoder et le rerank.
Returns:
Dict[str, Any]: Un dictionnaire contenant soit les résultats de la requête, soit les informations d'erreur.
Notes:
- L'appel de cet endpoint utilise la fonction `fetch_query_result` pour obtenir les résultats.
- En cas de problème lors du traitement de la requête, un message d'erreur détaillé est retourné.
"""
try:
dict_result = database.fetch_query_results(query, k_model, k_cross, use_rerank)
if dict_result["status"] == "ok":
return {"status": "ok",
"results": dict_result["result"]}
else:
return dict_result
except Exception as e:
return {"status": "error", "code": type(e).__name__, "message": str(e)}
# |