AxL95 commited on
Commit
c7b3fc2
·
verified ·
1 Parent(s): efdbda5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -15
app.py CHANGED
@@ -4,12 +4,15 @@ from fastapi.responses import JSONResponse
4
  from fastapi.staticfiles import StaticFiles
5
  from huggingface_hub import InferenceClient
6
  from sentence_transformers import SentenceTransformer
7
-
 
8
  from fastapi import Request
9
  import requests
10
  import numpy as np
11
  import argparse
12
  import os
 
 
13
 
14
  HOST = os.environ.get("API_URL", "0.0.0.0")
15
  PORT = os.environ.get("PORT", 7860)
@@ -30,8 +33,28 @@ app.add_middleware(
30
  allow_headers=["*"],
31
  )
32
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
- app = FastAPI()
35
  embedder = SentenceTransformer('sentence-transformers/distiluse-base-multilingual-cased-v1')
36
 
37
  @app.post("/api/embed")
@@ -64,25 +87,24 @@ async def chat(request: Request):
64
  user_message = data.get("message", "").strip()
65
  if not user_message:
66
  raise HTTPException(status_code=400, detail="Le champ 'message' est requis.")
 
 
 
 
 
 
 
67
 
68
  try:
69
- # Appel au modèle en mode chat
70
- completion = hf_client.chat.completions.create(
71
- model="mistralai/Mistral-7B-Instruct-v0.3",
72
- messages=[
73
- {"role": "system", "content": "Tu es un assistant médical spécialisé en schizophrénie."},
74
- {"role": "user", "content": user_message}
75
- ],
76
- max_tokens=512,
77
- temperature=0.7,
78
  )
79
-
80
- bot_msg = completion.choices[0].message.content
81
  return {"response": bot_msg}
82
 
83
  except Exception as e:
84
- # En cas d'erreur d'inférence
85
- raise HTTPException(status_code=502, detail=f"Erreur d'inférence HF : {e}")
86
 
87
 
88
  @app.get("/data")
 
4
  from fastapi.staticfiles import StaticFiles
5
  from huggingface_hub import InferenceClient
6
  from sentence_transformers import SentenceTransformer
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
8
+ import torch
9
  from fastapi import Request
10
  import requests
11
  import numpy as np
12
  import argparse
13
  import os
14
+ from fastapi import HTTPException
15
+
16
 
17
  HOST = os.environ.get("API_URL", "0.0.0.0")
18
  PORT = os.environ.get("PORT", 7860)
 
33
  allow_headers=["*"],
34
  )
35
 
36
+ # Charge le tokenizer et le modèle
37
+ tokenizer = AutoTokenizer.from_pretrained(
38
+ "mistralai/Mistral-7B-Instruct-v0.3",
39
+ trust_remote_code=True
40
+ )
41
+ model = AutoModelForCausalLM.from_pretrained(
42
+ "mistralai/Mistral-7B-Instruct-v0.3",
43
+ trust_remote_code=True,
44
+ torch_dtype=torch.float32, # float32 sur CPU
45
+ low_cpu_mem_usage=True # réduit l’empreinte mémoire
46
+ )
47
+ # Crée un pipeline "chat" (text-generation) préconfiguré
48
+ chat_pipeline = pipeline(
49
+ "text-generation",
50
+ model=model,
51
+ tokenizer=tokenizer,
52
+ device=-1, # -1 = CPU
53
+ max_new_tokens=512,
54
+ temperature=0.7,
55
+ do_sample=True
56
+ )
57
 
 
58
  embedder = SentenceTransformer('sentence-transformers/distiluse-base-multilingual-cased-v1')
59
 
60
  @app.post("/api/embed")
 
87
  user_message = data.get("message", "").strip()
88
  if not user_message:
89
  raise HTTPException(status_code=400, detail="Le champ 'message' est requis.")
90
+
91
+ # Construit le prompt
92
+ prompt = (
93
+ "Tu es un assistant médical spécialisé en schizophrénie.\n"
94
+ "Utilisateur : " + user_message + "\n"
95
+ "Assistant :"
96
+ )
97
 
98
  try:
99
+ outputs = chat_pipeline(
100
+ prompt,
101
+ return_full_text=False
 
 
 
 
 
 
102
  )
103
+ bot_msg = outputs[0]["generated_text"].strip()
 
104
  return {"response": bot_msg}
105
 
106
  except Exception as e:
107
+ raise HTTPException(status_code=502, detail=f"Erreur d’inférence locale : {e}")
 
108
 
109
 
110
  @app.get("/data")