Spaces:
Sleeping
Sleeping
| import os | |
| import uuid | |
| import logging | |
| from pathlib import Path | |
| from typing import List, Optional | |
| from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.responses import HTMLResponse, JSONResponse, FileResponse | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from pydantic import BaseModel | |
| from huggingface_hub import InferenceClient | |
| import fitz # PyMuPDF | |
| from PIL import Image | |
| import io | |
| import pandas as pd | |
| from docx import Document | |
| from pptx import Presentation | |
| import json | |
# Logging configuration: root logger at INFO, plus a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application instance.
app = FastAPI()

# CORS configuration.
# NOTE(review): wildcard origins together with allow_credentials=True is a
# very permissive combination — confirm it is intended before going public.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_methods": ["POST", "GET", "PUT", "DELETE", "OPTIONS"],
    "allow_headers": ["*"],
    "allow_credentials": True,
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

# File locations: uploads live in an "uploads" directory next to this module.
BASE_DIR = Path(__file__).parent
UPLOAD_FOLDER = BASE_DIR / "uploads"
UPLOAD_FOLDER.mkdir(parents=True, exist_ok=True)

# Hugging Face configuration: the token is read from the HF_TOKEN
# environment variable and handed to the shared inference client.
HF_TOKEN = os.environ.get("HF_TOKEN")
client = InferenceClient(token=HF_TOKEN)

# Model identifier used for each task.
MODELS = {
    "summary": "facebook/bart-large-cnn",
    "caption": "Salesforce/blip-image-captioning-large",
    "qa": "distilbert-base-cased-distilled-squad",  # lighter-weight model
}
# Pydantic models


# Description of one stored upload, as built by process_uploaded_file.
class FileInfo(BaseModel):
    # Identifier generated for the upload; also the stem of the stored file.
    file_id: str
    # Name the client supplied for the file.
    file_name: str
    # Lower-cased extension without the leading dot.
    file_type: str
    # Location of the stored copy on disk.
    file_path: str
    # Text pulled out of the document; None when nothing was extracted.
    extracted_text: Optional[str] = None
# Request payload consumed by summarize_document.
class SummaryRequest(BaseModel):
    # Identifier of a previously uploaded file.
    file_id: str
    # Forwarded to the summarization model as its "max_length" parameter.
    max_length: int = 150
# Request payload consumed by caption_image.
class CaptionRequest(BaseModel):
    # Identifier of a previously uploaded file.
    file_id: str
# Request payload consumed by answer_question.
class QARequest(BaseModel):
    # Identifier of the uploaded file to use as context; may be omitted.
    file_id: Optional[str] = None
    # Question to put to the question-answering model.
    question: str
# Utility functions
def extract_text_from_pdf(file_path: str) -> str:
    """Return the text of every page of a PDF, joined with newlines.

    Args:
        file_path: Location of the PDF on disk.

    Returns:
        The page texts in document order, separated by "\\n".

    Raises:
        HTTPException: 400 when the document cannot be opened or read.
    """
    try:
        # The context manager closes the document even if a page fails to
        # render; previously the handle was left open.
        with fitz.open(file_path) as doc:
            return "\n".join(page.get_text() for page in doc)
    except Exception as e:
        logger.error(f"PDF extraction error: {e}")
        raise HTTPException(400, "Erreur d'extraction PDF") from e
def extract_text_from_docx(file_path: str) -> str:
    """Return the text of each paragraph of a DOCX file, one per line.

    Any failure while opening or reading the document is logged and
    surfaced as an HTTP 400 error.
    """
    try:
        document = Document(file_path)
        return "\n".join(paragraph.text for paragraph in document.paragraphs)
    except Exception as err:
        logger.error(f"DOCX extraction error: {err}")
        raise HTTPException(400, "Erreur d'extraction DOCX")
def extract_text_from_pptx(file_path: str) -> str:
    """Return the text of every text-bearing shape in a PPTX file.

    Slides and shapes are visited in order; shapes that expose no ``text``
    attribute are skipped. Any failure is logged and surfaced as an
    HTTP 400 error.
    """
    try:
        presentation = Presentation(file_path)
        fragments = [
            shape.text
            for slide in presentation.slides
            for shape in slide.shapes
            if hasattr(shape, "text")
        ]
        return "\n".join(fragments)
    except Exception as err:
        logger.error(f"PPTX extraction error: {err}")
        raise HTTPException(400, "Erreur d'extraction PPTX")
def extract_text_from_excel(file_path: str) -> str:
    """Return a text rendering of every sheet of an Excel workbook.

    Each sheet contributes a "Feuille: <name>" header followed by the
    sheet rendered with ``DataFrame.to_string()``; sheets are separated by
    a blank line.

    Args:
        file_path: Location of the workbook on disk.

    Raises:
        HTTPException: 400 when the workbook cannot be opened or parsed.
    """
    try:
        # Open the workbook once, parse every sheet from that handle, and
        # close it on exit. Previously the file was re-opened per sheet and
        # the ExcelFile handle was never released.
        with pd.ExcelFile(file_path) as workbook:
            sections = []
            for sheet_name in workbook.sheet_names:
                frame = workbook.parse(sheet_name)
                sections.append(f"Feuille: {sheet_name}\n{frame.to_string()}")
        return "\n\n".join(sections)
    except Exception as e:
        logger.error(f"Excel extraction error: {e}")
        raise HTTPException(400, "Erreur d'extraction Excel") from e
async def process_uploaded_file(file: UploadFile) -> FileInfo:
    """Store an uploaded file and extract its text when the type is known.

    The file is saved under UPLOAD_FOLDER as "<uuid4><extension>". Text is
    extracted for .pdf, .docx, .pptx, .xlsx and .xls; every other extension
    is stored without extraction.

    Args:
        file: The upload to store.

    Returns:
        A FileInfo describing the stored file; ``extracted_text`` is None
        when no text was extracted.

    Raises:
        HTTPException: 400 when the upload carries no filename, or when the
            extractor for its type fails.
    """
    # Path(None) would raise TypeError, so reject nameless uploads up front.
    if not file.filename:
        raise HTTPException(400, "Nom de fichier manquant")

    file_ext = Path(file.filename).suffix.lower()
    file_id = str(uuid.uuid4())
    stored_path = UPLOAD_FOLDER / f"{file_id}{file_ext}"
    stored_path.write_bytes(await file.read())

    extractors = {
        ".pdf": extract_text_from_pdf,
        ".docx": extract_text_from_docx,
        ".pptx": extract_text_from_pptx,
        ".xlsx": extract_text_from_excel,
        ".xls": extract_text_from_excel,
    }
    extractor = extractors.get(file_ext)
    try:
        text = extractor(str(stored_path)) if extractor else ""
    except Exception:
        # The caller never learns this file_id, so do not leave the stored
        # copy behind when extraction fails.
        try:
            stored_path.unlink()
        except OSError as cleanup_error:
            logger.warning(f"Could not remove {stored_path}: {cleanup_error}")
        raise

    return FileInfo(
        file_id=file_id,
        file_name=file.filename,
        file_type=file_ext[1:],
        file_path=str(stored_path),
        extracted_text=text or None,
    )
# API routes
async def test_api():
    """Report that the API responds and name the runtime environment.

    The environment is "Hugging Face" when the HF_SPACE environment
    variable holds a non-empty value, and "Local" otherwise.
    """
    if os.environ.get("HF_SPACE"):
        environment = "Hugging Face"
    else:
        environment = "Local"
    return {"status": "API working", "environment": environment}
async def api_root():
    """Return a fixed payload stating that the API is running."""
    payload = {"status": "API is running"}
    return payload
async def upload_files(files: List[UploadFile] = File(...)):
    """Store every uploaded file and return their descriptions.

    Args:
        files: The uploads to process, handled in the order received.

    Returns:
        A list with one FileInfo per upload.

    Raises:
        HTTPException: propagated unchanged when processing a file raises
            one (for example a 400 from a text extractor); 500 for any
            other failure.
    """
    logger.info(f"Upload request received with {len(files)} files")
    try:
        processed_files = [await process_uploaded_file(file) for file in files]
    except HTTPException:
        # Keep the status chosen by the lower layer; do not rewrite it as 500.
        raise
    except Exception as e:
        logger.error(f"Upload error: {e}")
        raise HTTPException(500, f"Erreur lors de l'upload: {str(e)}") from e
    logger.info(f"Files processed successfully: {len(processed_files)}")
    return processed_files
async def summarize_document(request: SummaryRequest):
    """Summarize the text of a previously uploaded document.

    PDF, DOCX, PPTX and Excel files go through their dedicated extractors;
    any other file is read as UTF-8 text. Only the first 5000 characters
    are sent to the summarization model.

    Args:
        request: Carries the file identifier and the ``max_length``
            parameter forwarded to the model.

    Returns:
        ``{"summary": <model output>}``.

    Raises:
        HTTPException: 404 when the identifier is not a valid UUID or no
            stored file matches it; whatever an extractor raises; 500 for
            any other failure.
    """
    try:
        # Round-tripping through uuid.UUID keeps glob metacharacters and
        # path separators out of the pattern given to Path.glob.
        try:
            safe_id = str(uuid.UUID(request.file_id))
        except ValueError:
            raise HTTPException(404, "Fichier non trouvé") from None
        file_path = next(UPLOAD_FOLDER.glob(f"{safe_id}*"), None)
        if file_path is None:
            raise HTTPException(404, "Fichier non trouvé")

        # Office formats are binary containers; reading them as UTF-8 text
        # fails, so dispatch to the matching extractor.
        extractors = {
            ".pdf": extract_text_from_pdf,
            ".docx": extract_text_from_docx,
            ".pptx": extract_text_from_pptx,
            ".xlsx": extract_text_from_excel,
            ".xls": extract_text_from_excel,
        }
        extractor = extractors.get(file_path.suffix)
        if extractor is not None:
            text = extractor(str(file_path))
        else:
            text = file_path.read_text(encoding="utf-8")

        summary = client.summarization(
            text=text[:5000],  # cap the input when the document is too long
            model=MODELS["summary"],
            parameters={"max_length": request.max_length},
        )
        return {"summary": summary}
    except HTTPException:
        # Preserve deliberate 4xx responses; do not rewrite them as 500.
        raise
    except Exception as e:
        logger.error(f"Summarization error: {e}")
        raise HTTPException(500, f"Erreur de résumé: {str(e)}") from e
async def caption_image(request: CaptionRequest):
    """Generate a caption for a previously uploaded image.

    Args:
        request: Carries the identifier of the stored file.

    Returns:
        ``{"caption": <model output>}``.

    Raises:
        HTTPException: 404 when the identifier is not a valid UUID or no
            stored file matches it; 500 for any other failure.
    """
    try:
        # Round-tripping through uuid.UUID keeps glob metacharacters and
        # path separators out of the pattern given to Path.glob.
        try:
            safe_id = str(uuid.UUID(request.file_id))
        except ValueError:
            raise HTTPException(404, "Fichier non trouvé") from None
        file_path = next(UPLOAD_FOLDER.glob(f"{safe_id}*"), None)
        if file_path is None:
            raise HTTPException(404, "Fichier non trouvé")

        caption = client.image_to_text(
            image=file_path.read_bytes(),
            model=MODELS["caption"],
        )
        return {"caption": caption}
    except HTTPException:
        # Preserve deliberate 4xx responses; do not rewrite them as 500.
        raise
    except Exception as e:
        logger.error(f"Captioning error: {e}")
        raise HTTPException(500, f"Erreur de description: {str(e)}") from e
async def answer_question(request: QARequest):
    """Answer a question using an uploaded file as context.

    For .jpg/.jpeg/.png files the context is the output of the captioning
    model. PDF, DOCX, PPTX and Excel files go through their extractors, and
    any other file is read as UTF-8 text.

    Args:
        request: Carries the question and, optionally, the identifier of
            the file that provides the context.

    Returns:
        ``{"answer": <value of the "answer" key in the model response>}``.

    Raises:
        HTTPException: 404 when the identifier is not a valid UUID or no
            stored file matches it; 400 when no context is available;
            whatever an extractor raises; 500 for any other failure.
    """
    try:
        context = ""
        if request.file_id:
            # Round-tripping through uuid.UUID keeps glob metacharacters
            # and path separators out of the pattern given to Path.glob.
            try:
                safe_id = str(uuid.UUID(request.file_id))
            except ValueError:
                raise HTTPException(404, "Fichier non trouvé") from None
            file_path = next(UPLOAD_FOLDER.glob(f"{safe_id}*"), None)
            if file_path is None:
                raise HTTPException(404, "Fichier non trouvé")

            if file_path.suffix in (".jpg", ".jpeg", ".png"):
                # NOTE(review): assumes image_to_text returns something
                # JSON-serialisable as context — confirm for the installed
                # huggingface_hub version.
                context = client.image_to_text(
                    image=file_path.read_bytes(),
                    model=MODELS["caption"],
                )
            else:
                # Office formats are binary containers; reading them as
                # UTF-8 text fails, so dispatch to the matching extractor.
                extractors = {
                    ".pdf": extract_text_from_pdf,
                    ".docx": extract_text_from_docx,
                    ".pptx": extract_text_from_pptx,
                    ".xlsx": extract_text_from_excel,
                    ".xls": extract_text_from_excel,
                }
                extractor = extractors.get(file_path.suffix)
                if extractor is not None:
                    context = extractor(str(file_path))
                else:
                    context = file_path.read_text(encoding="utf-8")

        if not context:
            raise HTTPException(400, "Aucun contexte trouvé pour répondre à la question.")

        raw_response = client.post(
            model=MODELS["qa"],
            json={
                "inputs": {
                    "question": request.question,
                    "context": context,
                }
            },
        )
        # Decode the raw JSON payload returned by the inference endpoint.
        response = json.loads(raw_response)
        return {"answer": response["answer"]}
    except HTTPException:
        # Without this clause the 400/404 above were reported as 500.
        raise
    except Exception as e:
        logger.error(f"QA error: {e}")
        raise HTTPException(500, f"Erreur de réponse: {str(e)}") from e
async def get_file(file_id: str):
    """Return a previously uploaded file.

    Args:
        file_id: Identifier generated when the file was uploaded.

    Returns:
        A FileResponse for the first stored file whose name starts with
        the identifier.

    Raises:
        HTTPException: 404 when the identifier is not a valid UUID or no
            stored file matches it.
    """
    # file_id is interpolated into a glob pattern. Without validation a
    # value such as "*" would match, and serve, an arbitrary upload.
    try:
        safe_id = str(uuid.UUID(file_id))
    except ValueError as e:
        logger.error(f"File retrieval error: {e}")
        raise HTTPException(404, "Fichier non trouvé") from e

    file_path = next(UPLOAD_FOLDER.glob(f"{safe_id}*"), None)
    if file_path is None:
        logger.error(f"File retrieval error: no upload matches {safe_id}")
        raise HTTPException(404, "Fichier non trouvé")
    return FileResponse(file_path)
# Global error handling
async def http_exception_handler(request, exc):
    """Render an exception's ``status_code`` and ``detail`` as a JSON response."""
    body = {"detail": exc.detail}
    return JSONResponse(status_code=exc.status_code, content=body)
async def generic_exception_handler(request, exc):
    """Log the exception and answer with a generic HTTP 500 JSON body.

    The exception text goes only to the log; the response carries a fixed
    message.
    """
    logger.error(f"Unhandled exception: {exc}")
    body = {"detail": "Une erreur interne est survenue"}
    return JSONResponse(status_code=500, content=body)
# Static files mount.
# NOTE(review): the served directory is BASE_DIR itself, which also holds
# this module and the "uploads" folder — confirm that exposure is intended.
_static_files = StaticFiles(directory=BASE_DIR, html=True)
app.mount("/", _static_files, name="static")

if __name__ == "__main__":
    # Imported here so uvicorn is only needed when run as a script.
    import uvicorn

    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)