Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI, Query | |
| from typing import List, Dict, Any | |
| from app import database | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import HTMLResponse | |
| from pydantic import BaseModel | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| import torch | |
| app = FastAPI( | |
| title="Articles API", | |
| description="API pour interroger la base articles", | |
| version="1.0" | |
| ) | |
| # CORS pour permettre l'accès depuis le navigateur | |
| app.add_middleware( | |
| CORSMiddleware, | |
| allow_origins=["*"], # autorise toutes les origines | |
| allow_credentials=True, | |
| allow_methods=["*"], | |
| allow_headers=["*"], | |
| ) | |
| def home(): | |
| return """ | |
| <html> | |
| <head><title>Page d'accueil</title></head> | |
| <body> | |
| <h1>Welcome on the API search articles !</h1> | |
| </body> | |
| </html> | |
| """ | |
| def get_tags(): | |
| """ | |
| Récupère la liste de tous les tags disponibles via l'API. | |
| Returns: | |
| Dict: Un dictionnaire contenant soit la liste des tags, soit les informations d'erreur. | |
| - Si succès : | |
| { | |
| "status": "ok", | |
| "tags": List[str] # Liste des noms de tags triés par ordre alphabétique | |
| } | |
| - En cas d'erreur : | |
| { | |
| "status": "error", | |
| "code": str, # Nom de l'exception | |
| "message": str # Message de l'exception | |
| } | |
| Notes: | |
| - L'appel de cet endpoint effectue un accès à la base de données via la fonction `fetch_tags`. | |
| - En cas de problème avec la base de données, un message d'erreur détaillé est retourné. | |
| """ | |
| try: | |
| dict_result = database.fetch_tags() | |
| if dict_result["status"] == "ok": | |
| return {"status": "ok", "tags": dict_result["result"]} | |
| else: | |
| return dict_result | |
| except Exception as e: | |
| return {"status": "error", "code": type(e).__name__, "message": str(e)} | |
| def get_articles_with_tags(tags: List[str] = Query(..., description="Liste des tags à filtrer")): | |
| """ | |
| Récupère les articles associés à une ou plusieurs tags spécifiés. | |
| Args: | |
| tags (List[str]): Liste des noms de tags utilisés pour filtrer les articles. | |
| Doit contenir au moins un tag. | |
| Returns: | |
| Dict: Un dictionnaire contenant soit les articles correspondants, soit les informations d'erreur. | |
| - Si succès : | |
| { | |
| "status": "ok", | |
| "tags": List[str], # Tags utilisés pour filtrer | |
| "articles": List[Dict] # Liste des articles correspondants | |
| } | |
| Chaque article est un dictionnaire contenant : | |
| - 'article_id': int, ID de l'article | |
| - 'article_title': str, Titre de l'article | |
| - 'article_url': str, URL de l'article | |
| - En cas d'erreur : | |
| { | |
| "status": "error", | |
| "code": str, # Code d'erreur ou nom de l'exception | |
| "message": str # Message d'erreur | |
| } | |
| Notes: | |
| - Si la liste `tags` est vide, la fonction retourne une erreur avec le code "no_tags". | |
| - L'appel de cet endpoint utilise la fonction `fetch_articles_by_tags` pour récupérer les articles. | |
| """ | |
| try: | |
| dict_result = database.fetch_articles_by_tags(tags) | |
| if dict_result["status"] == "ok": | |
| return {"status": "ok", | |
| "tags": tags, | |
| "articles": dict_result["result"]} | |
| else: | |
| return dict_result | |
| except Exception as e: | |
| return {"status": "error", "code": type(e).__name__, "message": str(e)} | |
| def get_query_results(query: str = Query(..., description="Requête de recherche textuelle"), | |
| k_model: int = Query(10, description="Nombre de candidats retournés par FAISS"), | |
| k_cross: int = Query(5, description="Nombre de résultats conservés après reranking"), | |
| use_rerank: bool = Query(True, description="Indique si le reranking avec cross-encoder doit être utilisé") | |
| ) -> Dict[str, Any]: | |
| """ | |
| Récupère les résultats d'une requête en utilisant deux modèles de recherche. | |
| Args: | |
| query (str): La requête utilisateur pour laquelle récupérer les résultats. | |
| k_model (int, optionel): Nombre de résultats à retourner pour le modèle principal. Par défaut à 10. | |
| k_cross (int, optionel): Nombre de résultats à retourner pour le modèle croisé. Par défaut à 5. | |
| use_rerank (bool, optionnel): Indique si le reranking avec cross-encoder doit être utilisé. Par défaut à True. | |
| Si False, on désactive complètement le cross-encoder et le rerank. | |
| Returns: | |
| Dict[str, Any]: Un dictionnaire contenant soit les résultats de la requête, soit les informations d'erreur. | |
| Notes: | |
| - L'appel de cet endpoint utilise la fonction `fetch_query_result` pour obtenir les résultats. | |
| - En cas de problème lors du traitement de la requête, un message d'erreur détaillé est retourné. | |
| """ | |
| try: | |
| dict_result = database.fetch_query_results(query, k_model, k_cross, use_rerank) | |
| if dict_result["status"] == "ok": | |
| return {"status": "ok", | |
| "results": dict_result["result"]} | |
| else: | |
| return dict_result | |
| except Exception as e: | |
| return {"status": "error", "code": type(e).__name__, "message": str(e)} | |
| # |