Update app.py
Browse files
app.py
CHANGED
|
@@ -199,50 +199,32 @@ async def chat(request: Request):
|
|
| 199 |
data = await request.json()
|
| 200 |
user_message = data.get("message", "").strip()
|
| 201 |
if not user_message:
|
| 202 |
-
from fastapi import HTTPException
|
| 203 |
raise HTTPException(status_code=400, detail="Le champ 'message' est requis.")
|
| 204 |
|
| 205 |
try:
|
| 206 |
-
# Utiliser
|
| 207 |
response = hf_client.text_generation(
|
| 208 |
model="mistralai/Mistral-7B-Instruct-v0.3",
|
| 209 |
prompt=f"<s>[INST] Tu es un assistant médical spécialisé en schizophrénie. Réponds à cette question: {user_message} [/INST]",
|
| 210 |
max_new_tokens=512,
|
| 211 |
-
temperature=0.7
|
|
|
|
| 212 |
)
|
| 213 |
|
| 214 |
return {"response": response}
|
| 215 |
|
| 216 |
except Exception as e:
|
| 217 |
-
from fastapi import HTTPException
|
| 218 |
import traceback
|
| 219 |
-
print(f"Erreur détaillée: {traceback.format_exc()}")
|
| 220 |
raise HTTPException(status_code=502, detail=f"Erreur d'inférence HF : {str(e)}")
|
| 221 |
|
| 222 |
|
|
|
|
| 223 |
@app.get("/data")
|
| 224 |
async def get_data():
|
| 225 |
data = {"data": np.random.rand(100).tolist()}
|
| 226 |
return JSONResponse(data)
|
| 227 |
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|
| 231 |
-
|
| 232 |
-
if __name__ == "__main__":
|
| 233 |
-
import uvicorn
|
| 234 |
-
|
| 235 |
-
print(args)
|
| 236 |
-
uvicorn.run(
|
| 237 |
-
"app:app",
|
| 238 |
-
host=args.host,
|
| 239 |
-
port=args.port,
|
| 240 |
-
reload=args.reload,
|
| 241 |
-
ssl_certfile=args.ssl_certfile,
|
| 242 |
-
ssl_keyfile=args.ssl_keyfile,
|
| 243 |
-
)
|
| 244 |
-
|
| 245 |
-
|
| 246 |
# Endpoint pour récupérer toutes les conversations d'un utilisateur
|
| 247 |
@app.get("/api/conversations")
|
| 248 |
async def get_conversations(current_user: dict = Depends(get_current_user)):
|
|
@@ -376,4 +358,21 @@ async def delete_conversation(conversation_id: str, current_user: dict = Depends
|
|
| 376 |
|
| 377 |
return {"success": True}
|
| 378 |
except Exception as e:
|
| 379 |
-
raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 199 |
data = await request.json()
|
| 200 |
user_message = data.get("message", "").strip()
|
| 201 |
if not user_message:
|
|
|
|
| 202 |
raise HTTPException(status_code=400, detail="Le champ 'message' est requis.")
|
| 203 |
|
| 204 |
try:
|
| 205 |
+
# Utiliser le provider novita comme demandé
|
| 206 |
response = hf_client.text_generation(
|
| 207 |
model="mistralai/Mistral-7B-Instruct-v0.3",
|
| 208 |
prompt=f"<s>[INST] Tu es un assistant médical spécialisé en schizophrénie. Réponds à cette question: {user_message} [/INST]",
|
| 209 |
max_new_tokens=512,
|
| 210 |
+
temperature=0.7,
|
| 211 |
+
provider="novita" # Spécifier le provider novita ici
|
| 212 |
)
|
| 213 |
|
| 214 |
return {"response": response}
|
| 215 |
|
| 216 |
except Exception as e:
|
|
|
|
| 217 |
import traceback
|
| 218 |
+
print(f"Erreur détaillée: {traceback.format_exc()}")
|
| 219 |
raise HTTPException(status_code=502, detail=f"Erreur d'inférence HF : {str(e)}")
|
| 220 |
|
| 221 |
|
| 222 |
+
|
| 223 |
@app.get("/data")
|
| 224 |
async def get_data():
|
| 225 |
data = {"data": np.random.rand(100).tolist()}
|
| 226 |
return JSONResponse(data)
|
| 227 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
# Endpoint pour récupérer toutes les conversations d'un utilisateur
|
| 229 |
@app.get("/api/conversations")
|
| 230 |
async def get_conversations(current_user: dict = Depends(get_current_user)):
|
|
|
|
| 358 |
|
| 359 |
return {"success": True}
|
| 360 |
except Exception as e:
|
| 361 |
+
raise HTTPException(status_code=500, detail=f"Erreur serveur: {str(e)}")
|
| 362 |
+
|
| 363 |
+
app.mount("/", StaticFiles(directory="static", html=True), name="static")
|
| 364 |
+
|
| 365 |
+
if __name__ == "__main__":
|
| 366 |
+
import uvicorn
|
| 367 |
+
|
| 368 |
+
print(args)
|
| 369 |
+
uvicorn.run(
|
| 370 |
+
"app:app",
|
| 371 |
+
host=args.host,
|
| 372 |
+
port=args.port,
|
| 373 |
+
reload=args.reload,
|
| 374 |
+
ssl_certfile=args.ssl_certfile,
|
| 375 |
+
ssl_keyfile=args.ssl_keyfile,
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
|