SamiKLN commited on
Commit
7d4c600
·
verified ·
1 Parent(s): 756c809

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +9 -58
main.py CHANGED
@@ -38,17 +38,22 @@ BASE_DIR = Path(__file__).parent
38
  UPLOAD_FOLDER = BASE_DIR / "uploads"
39
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
40
 
 
 
 
 
 
41
  # Modèles Hugging Face
42
  HF_TOKEN = os.getenv("HF_TOKEN")
43
  client = InferenceClient(token=HF_TOKEN)
44
  MODELS = {
45
  "summary": "facebook/bart-large-cnn",
46
  "caption": "Salesforce/blip-image-captioning-large",
47
- "qa": "distilbert-base-cased-distilled-squad" # Nouveau modèle QA plus léger
48
  }
49
 
50
  # Pipeline QA pour distilbert-base-cased-distilled-squad
51
- qa_pipeline = pipeline("question-answering", model=MODELS["qa"], tokenizer=MODELS["qa"])
52
 
53
  # Modèles Pydantic
54
  class FileInfo(BaseModel):
@@ -69,7 +74,7 @@ class QARequest(BaseModel):
69
  file_id: Optional[str] = None
70
  question: str
71
 
72
- # Fonctions utilitaires
73
  def extract_text_from_pdf(file_path: str) -> str:
74
  try:
75
  doc = fitz.open(file_path)
@@ -144,11 +149,7 @@ async def process_uploaded_file(file: UploadFile) -> FileInfo:
144
  async def test_api():
145
  return {"status": "API working", "environment": "Hugging Face" if os.environ.get("HF_SPACE") else "Local"}
146
 
147
- # Endpoints
148
- @app.get("/api")
149
- async def api_root():
150
- return {"status": "API is running"}
151
-
152
  @app.post("/api/upload")
153
  async def upload_files(files: List[UploadFile] = File(...)):
154
  logger.info(f"Upload request received with {len(files)} files")
@@ -163,47 +164,6 @@ async def upload_files(files: List[UploadFile] = File(...)):
163
  logger.error(f"Upload error: {e}")
164
  raise HTTPException(500, f"Erreur lors de l'upload: {str(e)}")
165
 
166
- @app.post("/api/summarize")
167
- async def summarize_document(request: SummaryRequest):
168
- try:
169
- file_path = next(f for f in UPLOAD_FOLDER.glob(f"{request.file_id}*"))
170
- text = ""
171
-
172
- if file_path.suffix == ".pdf":
173
- text = extract_text_from_pdf(str(file_path))
174
- else:
175
- with open(file_path, "r", encoding="utf-8") as f:
176
- text = f.read()
177
-
178
- summary = client.summarization(
179
- text=text,
180
- model=MODELS["summary"],
181
- parameters={"max_length": request.max_length}
182
- )
183
-
184
- return {"summary": summary}
185
- except Exception as e:
186
- logger.error(f"Summarization error: {e}")
187
- raise HTTPException(500, f"Erreur de résumé: {str(e)}")
188
-
189
- @app.post("/api/caption")
190
- async def caption_image(request: CaptionRequest):
191
- try:
192
- file_path = next(f for f in UPLOAD_FOLDER.glob(f"{request.file_id}*"))
193
-
194
- with open(file_path, "rb") as image_file:
195
- image_data = image_file.read()
196
-
197
- caption = client.image_to_text(
198
- image=image_data,
199
- model=MODELS["caption"]
200
- )
201
-
202
- return {"caption": caption}
203
- except Exception as e:
204
- logger.error(f"Captioning error: {e}")
205
- raise HTTPException(500, f"Erreur de description: {str(e)}")
206
-
207
  @app.post("/api/answer")
208
  async def answer_question(request: QARequest):
209
  try:
@@ -236,15 +196,6 @@ async def answer_question(request: QARequest):
236
  logger.error(f"QA error: {e}")
237
  raise HTTPException(500, f"Erreur de réponse: {str(e)}")
238
 
239
- @app.get("/api/file/{file_id}")
240
- async def get_file(file_id: str):
241
- try:
242
- file_path = next(f for f in UPLOAD_FOLDER.glob(f"{file_id}*"))
243
- return FileResponse(file_path)
244
- except Exception as e:
245
- logger.error(f"File retrieval error: {e}")
246
- raise HTTPException(404, "Fichier non trouvé")
247
-
248
  # Gestion des erreurs
249
  @app.exception_handler(HTTPException)
250
  async def http_exception_handler(request, exc):
 
38
  UPLOAD_FOLDER = BASE_DIR / "uploads"
39
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
40
 
41
+ # Configuration du cache pour Hugging Face
42
+ TEMP_CACHE_DIR = "/tmp/huggingface_cache"
43
+ os.environ["TRANSFORMERS_CACHE"] = TEMP_CACHE_DIR
44
+ os.environ["HF_HOME"] = TEMP_CACHE_DIR
45
+
46
  # Modèles Hugging Face
47
  HF_TOKEN = os.getenv("HF_TOKEN")
48
  client = InferenceClient(token=HF_TOKEN)
49
  MODELS = {
50
  "summary": "facebook/bart-large-cnn",
51
  "caption": "Salesforce/blip-image-captioning-large",
52
+ "qa": "distilbert-base-cased-distilled-squad" # Modèle QA léger
53
  }
54
 
55
  # Pipeline QA pour distilbert-base-cased-distilled-squad
56
+ qa_pipeline = pipeline("question-answering", model=MODELS["qa"], tokenizer=MODELS["qa"], cache_dir=TEMP_CACHE_DIR)
57
 
58
  # Modèles Pydantic
59
  class FileInfo(BaseModel):
 
74
  file_id: Optional[str] = None
75
  question: str
76
 
77
+ # Fonctions utilitaires (extraction de texte, etc.)
78
  def extract_text_from_pdf(file_path: str) -> str:
79
  try:
80
  doc = fitz.open(file_path)
 
149
  async def test_api():
150
  return {"status": "API working", "environment": "Hugging Face" if os.environ.get("HF_SPACE") else "Local"}
151
 
152
+ # Endpoints (upload, summarize, caption, answer, etc.)
 
 
 
 
153
  @app.post("/api/upload")
154
  async def upload_files(files: List[UploadFile] = File(...)):
155
  logger.info(f"Upload request received with {len(files)} files")
 
164
  logger.error(f"Upload error: {e}")
165
  raise HTTPException(500, f"Erreur lors de l'upload: {str(e)}")
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  @app.post("/api/answer")
168
  async def answer_question(request: QARequest):
169
  try:
 
196
  logger.error(f"QA error: {e}")
197
  raise HTTPException(500, f"Erreur de réponse: {str(e)}")
198
 
 
 
 
 
 
 
 
 
 
199
  # Gestion des erreurs
200
  @app.exception_handler(HTTPException)
201
  async def http_exception_handler(request, exc):