SamiKLN committed on
Commit
387f195
·
verified ·
1 Parent(s): 2dcbd0b

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +74 -33
main.py CHANGED
@@ -9,7 +9,6 @@ from fastapi.responses import HTMLResponse, JSONResponse, FileResponse
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from pydantic import BaseModel
11
  from huggingface_hub import InferenceClient
12
- from transformers import pipeline # Pour le pipeline QA
13
  import fitz # PyMuPDF
14
  from PIL import Image
15
  import io
@@ -21,22 +20,16 @@ from pptx import Presentation
21
  logging.basicConfig(level=logging.INFO)
22
  logger = logging.getLogger(__name__)
23
 
24
- # Configuration du cache pour Hugging Face
25
- TEMP_CACHE_DIR = "/tmp/huggingface_cache"
26
- os.environ["TRANSFORMERS_CACHE"] = TEMP_CACHE_DIR
27
- os.environ["HF_HOME"] = TEMP_CACHE_DIR
28
- Path(TEMP_CACHE_DIR).mkdir(parents=True, exist_ok=True)
29
-
30
  # Initialisation de l'application FastAPI
31
  app = FastAPI()
32
 
33
- # Configuration CORS avec méthodes explicites
34
  app.add_middleware(
35
  CORSMiddleware,
36
  allow_origins=["*"],
37
- allow_methods=["POST", "GET", "PUT", "DELETE", "OPTIONS"], # Méthodes explicites
38
  allow_headers=["*"],
39
- allow_credentials=True, # Permettre les credentials si nécessaire
40
  )
41
 
42
  # Chemins des fichiers
@@ -44,18 +37,15 @@ BASE_DIR = Path(__file__).parent
44
  UPLOAD_FOLDER = BASE_DIR / "uploads"
45
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
46
 
47
- # Modèles Hugging Face
48
  HF_TOKEN = os.getenv("HF_TOKEN")
49
  client = InferenceClient(token=HF_TOKEN)
50
  MODELS = {
51
  "summary": "facebook/bart-large-cnn",
52
  "caption": "Salesforce/blip-image-captioning-large",
53
- "qa": "distilbert-base-cased-distilled-squad" # Modèle QA léger
54
  }
55
 
56
- # Pipeline QA pour distilbert-base-cased-distilled-squad
57
- qa_pipeline = pipeline("question-answering", model=MODELS["qa"], tokenizer=MODELS["qa"], cache_dir=TEMP_CACHE_DIR)
58
-
59
  # Modèles Pydantic
60
  class FileInfo(BaseModel):
61
  file_id: str
@@ -75,7 +65,7 @@ class QARequest(BaseModel):
75
  file_id: Optional[str] = None
76
  question: str
77
 
78
- # Fonctions utilitaires (extraction de texte, etc.)
79
  def extract_text_from_pdf(file_path: str) -> str:
80
  try:
81
  doc = fitz.open(file_path)
@@ -122,11 +112,9 @@ async def process_uploaded_file(file: UploadFile) -> FileInfo:
122
  file_id = str(uuid.uuid4())
123
  file_path = str(UPLOAD_FOLDER / f"{file_id}{file_ext}")
124
 
125
- # Sauvegarde du fichier
126
  with open(file_path, "wb") as buffer:
127
  buffer.write(await file.read())
128
 
129
- # Extraction du texte selon le type de fichier
130
  text = ""
131
  if file_ext == ".pdf":
132
  text = extract_text_from_pdf(file_path)
@@ -145,12 +133,15 @@ async def process_uploaded_file(file: UploadFile) -> FileInfo:
145
  extracted_text=text if text else None
146
  )
147
 
148
- # Route de test pour vérifier l'API
149
  @app.get("/api/test")
150
  async def test_api():
151
  return {"status": "API working", "environment": "Hugging Face" if os.environ.get("HF_SPACE") else "Local"}
152
 
153
- # Endpoints (upload, summarize, caption, answer, etc.)
 
 
 
154
  @app.post("/api/upload")
155
  async def upload_files(files: List[UploadFile] = File(...)):
156
  logger.info(f"Upload request received with {len(files)} files")
@@ -165,6 +156,47 @@ async def upload_files(files: List[UploadFile] = File(...)):
165
  logger.error(f"Upload error: {e}")
166
  raise HTTPException(500, f"Erreur lors de l'upload: {str(e)}")
167
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
  @app.post("/api/answer")
169
  async def answer_question(request: QARequest):
170
  try:
@@ -182,22 +214,31 @@ async def answer_question(request: QARequest):
182
  else:
183
  with open(file_path, "r", encoding="utf-8") as f:
184
  context = f.read()
185
-
186
- # Utiliser le pipeline QA pour obtenir la réponse
187
- result = qa_pipeline(question=request.question, context=context)
188
 
189
- return {
190
- "answer": result["answer"],
191
- "confidence": result["score"]
192
- }
193
- except StopIteration:
194
- logger.error("File not found")
195
- raise HTTPException(404, "Fichier non trouvé")
 
 
 
196
  except Exception as e:
197
  logger.error(f"QA error: {e}")
198
  raise HTTPException(500, f"Erreur de réponse: {str(e)}")
199
 
200
- # Gestion des erreurs
 
 
 
 
 
 
 
 
 
201
  @app.exception_handler(HTTPException)
202
  async def http_exception_handler(request, exc):
203
  return JSONResponse(
@@ -213,9 +254,9 @@ async def generic_exception_handler(request, exc):
213
  content={"detail": "Une erreur interne est survenue"},
214
  )
215
 
216
- # Montage des fichiers statiques APRÈS la définition des routes API
217
  app.mount("/", StaticFiles(directory=BASE_DIR, html=True), name="static")
218
 
219
  if __name__ == "__main__":
220
  import uvicorn
221
- uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
 
9
  from fastapi.middleware.cors import CORSMiddleware
10
  from pydantic import BaseModel
11
  from huggingface_hub import InferenceClient
 
12
  import fitz # PyMuPDF
13
  from PIL import Image
14
  import io
 
20
  logging.basicConfig(level=logging.INFO)
21
  logger = logging.getLogger(__name__)
22
 
 
 
 
 
 
 
23
  # Initialisation de l'application FastAPI
24
  app = FastAPI()
25
 
26
# CORS configuration.
# NOTE(review): allow_origins=["*"] combined with allow_credentials=True conflicts
# with the CORS spec (browsers reject credentialed responses whose
# Access-Control-Allow-Origin is "*") — confirm whether credentials are actually
# needed, or list explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["POST", "GET", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["*"],
    allow_credentials=True,
)
34
 
35
  # Chemins des fichiers
 
37
  UPLOAD_FOLDER = BASE_DIR / "uploads"
38
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
39
 
40
# Hugging Face configuration: API token read from the environment (may be None
# for anonymous, rate-limited access) and one shared inference client.
HF_TOKEN = os.getenv("HF_TOKEN")
client = InferenceClient(token=HF_TOKEN)
43
# Hugging Face model ids used by the inference endpoints below.
MODELS = {
    "summary": "facebook/bart-large-cnn",
    "caption": "Salesforce/blip-image-captioning-large",
    # Fixed typo: was "istilbert-..." (missing leading "d"), which is not a
    # valid Hub repo id and would make every /api/answer call fail.
    "qa": "distilbert-base-cased-distilled-squad",  # lightweight extractive QA
}
48
 
 
 
 
49
  # Modèles Pydantic
50
  class FileInfo(BaseModel):
51
  file_id: str
 
65
  file_id: Optional[str] = None
66
  question: str
67
 
68
+ # Fonctions utilitaires
69
  def extract_text_from_pdf(file_path: str) -> str:
70
  try:
71
  doc = fitz.open(file_path)
 
112
  file_id = str(uuid.uuid4())
113
  file_path = str(UPLOAD_FOLDER / f"{file_id}{file_ext}")
114
 
 
115
  with open(file_path, "wb") as buffer:
116
  buffer.write(await file.read())
117
 
 
118
  text = ""
119
  if file_ext == ".pdf":
120
  text = extract_text_from_pdf(file_path)
 
133
  extracted_text=text if text else None
134
  )
135
 
136
# API routes
@app.get("/api/test")
async def test_api():
    """Health-check endpoint; reports whether we run on a HF Space or locally."""
    environment = "Hugging Face" if os.environ.get("HF_SPACE") else "Local"
    return {"status": "API working", "environment": environment}
140
 
141
@app.get("/api")
async def api_root():
    """Root API endpoint; simple liveness confirmation."""
    payload = {"status": "API is running"}
    return payload
144
+
145
  @app.post("/api/upload")
146
  async def upload_files(files: List[UploadFile] = File(...)):
147
  logger.info(f"Upload request received with {len(files)} files")
 
156
  logger.error(f"Upload error: {e}")
157
  raise HTTPException(500, f"Erreur lors de l'upload: {str(e)}")
158
 
159
@app.post("/api/summarize")
async def summarize_document(request: SummaryRequest):
    """Summarize an uploaded document through the HF Inference API.

    Raises:
        HTTPException 404: the file_id does not match any stored file.
        HTTPException 500: extraction or inference failed.
    """
    # Resolve the stored file before the try-block: a missing id must surface
    # as 404, not be swallowed by the generic handler into a 500 (the bare
    # next(...) previously raised StopIteration into `except Exception`).
    file_path = next(UPLOAD_FOLDER.glob(f"{request.file_id}*"), None)
    if file_path is None:
        raise HTTPException(404, "Fichier non trouvé")
    try:
        if file_path.suffix == ".pdf":
            text = extract_text_from_pdf(str(file_path))
        else:
            with open(file_path, "r", encoding="utf-8") as f:
                text = f.read()

        # Truncate very long documents to keep the request within model limits.
        summary = client.summarization(
            text=text[:5000],
            model=MODELS["summary"],
            parameters={"max_length": request.max_length},
        )
        return {"summary": summary}
    except Exception as e:
        logger.error(f"Summarization error: {e}")
        raise HTTPException(500, f"Erreur de résumé: {str(e)}")
181
+
182
@app.post("/api/caption")
async def caption_image(request: CaptionRequest):
    """Generate a caption for an uploaded image through the HF Inference API.

    Raises:
        HTTPException 404: the file_id does not match any stored file.
        HTTPException 500: reading the image or inference failed.
    """
    # Resolve the stored file before the try-block so a missing id surfaces as
    # 404 instead of StopIteration being caught as a generic 500.
    file_path = next(UPLOAD_FOLDER.glob(f"{request.file_id}*"), None)
    if file_path is None:
        raise HTTPException(404, "Fichier non trouvé")
    try:
        image_data = file_path.read_bytes()
        caption = client.image_to_text(
            image=image_data,
            model=MODELS["caption"],
        )
        return {"caption": caption}
    except Exception as e:
        logger.error(f"Captioning error: {e}")
        raise HTTPException(500, f"Erreur de description: {str(e)}")
199
+
200
  @app.post("/api/answer")
201
  async def answer_question(request: QARequest):
202
  try:
 
214
  else:
215
  with open(file_path, "r", encoding="utf-8") as f:
216
  context = f.read()
 
 
 
217
 
218
+ if not context:
219
+ raise HTTPException(400, "Aucun contexte trouvé pour répondre à la question.")
220
+
221
+ response = client.question_answering(
222
+ question=request.question,
223
+ context=context,
224
+ model=MODELS["qa"]
225
+ )
226
+
227
+ return {"answer": response}
228
  except Exception as e:
229
  logger.error(f"QA error: {e}")
230
  raise HTTPException(500, f"Erreur de réponse: {str(e)}")
231
 
232
@app.get("/api/file/{file_id}")
async def get_file(file_id: str):
    """Serve a previously uploaded file by its id (404 if not found)."""
    try:
        # The stored name is "<file_id><original extension>"; glob matches it.
        stored = next(f for f in UPLOAD_FOLDER.glob(f"{file_id}*"))
        return FileResponse(stored)
    except Exception as e:
        logger.error(f"File retrieval error: {e}")
        raise HTTPException(404, "Fichier non trouvé")
240
+
241
+ # Gestion des erreurs globales
242
  @app.exception_handler(HTTPException)
243
  async def http_exception_handler(request, exc):
244
  return JSONResponse(
 
254
  content={"detail": "Une erreur interne est survenue"},
255
  )
256
 
257
# Mount static files AFTER the API routes so the catch-all "/" mount does not
# shadow the /api/* endpoints defined above.
app.mount("/", StaticFiles(directory=BASE_DIR, html=True), name="static")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)