SamiKLN commited on
Commit
45364e2
·
verified ·
1 Parent(s): d59b7f5

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +13 -10
main.py CHANGED
@@ -43,7 +43,7 @@ client = InferenceClient(token=HF_TOKEN)
43
  MODELS = {
44
  "summary": "facebook/bart-large-cnn",
45
  "caption": "Salesforce/blip-image-captioning-large",
46
- "qa": "deepseek-ai/DeepSeek-V2-Chat"
47
  }
48
 
49
  # Modèles Pydantic
@@ -227,21 +227,24 @@ async def answer_question(request: QARequest):
227
  with open(file_path, "r", encoding="utf-8") as f:
228
  context = f.read()
229
 
230
- prompt = f"""
231
- Vous êtes un assistant IA qui répond à des questions en français.
232
  Répondez de manière précise et concise.
 
233
  Contexte: {context[:3000]}
234
- Question: {request.question}
235
- Réponse:
236
- """
237
 
238
- response = client.chat_completion(
 
 
 
 
239
  model=MODELS["qa"],
240
- messages=[{"role": "user", "content": prompt}],
241
- max_tokens=500
 
242
  )
243
 
244
- return {"answer": response.choices[0].message.content}
245
  except Exception as e:
246
  logger.error(f"QA error: {e}")
247
  raise HTTPException(500, f"Erreur de réponse: {str(e)}")
 
43
  MODELS = {
44
  "summary": "facebook/bart-large-cnn",
45
  "caption": "Salesforce/blip-image-captioning-large",
46
+ "qa": "meta-llama/Llama-2-70b-chat-hf"
47
  }
48
 
49
  # Modèles Pydantic
 
227
  with open(file_path, "r", encoding="utf-8") as f:
228
  context = f.read()
229
 
230
+ # Format du prompt adapté pour Llama 2
231
+ prompt = f"""<s>[INST] Vous êtes un assistant IA qui répond à des questions en français.
232
  Répondez de manière précise et concise.
233
+
234
  Contexte: {context[:3000]}
 
 
 
235
 
236
+ Question: {request.question} [/INST]"""
237
+
238
+ # Utilisation de text_generation au lieu de chat_completion
239
+ response = client.text_generation(
240
+ prompt=prompt,
241
  model=MODELS["qa"],
242
+ max_new_tokens=500,
243
+ temperature=0.7,
244
+ top_p=0.9
245
  )
246
 
247
+ return {"answer": response}
248
  except Exception as e:
249
  logger.error(f"QA error: {e}")
250
  raise HTTPException(500, f"Erreur de réponse: {str(e)}")