issoufzousko07 commited on
Commit
17e2475
·
verified ·
1 Parent(s): 69652de

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -49
app.py CHANGED
@@ -1,32 +1,29 @@
1
- from fastapi import FastAPI, HTTPException
 
 
2
  from pydantic import BaseModel
3
  import torch
4
- from transformers import AutoModelForCausalLM, AutoTokenizer
5
  import uvicorn
6
- import os
7
-
8
- import traceback
9
-
10
- from fastapi.middleware.cors import CORSMiddleware
11
 
12
  app = FastAPI()
13
 
14
  app.add_middleware(
15
  CORSMiddleware,
16
  allow_origins=["*"],
17
- allow_credentials=False, # Plus sûr avec wildcards
18
  allow_methods=["*"],
19
  allow_headers=["*"],
20
  )
21
 
22
  # --- CONFIGURATION ---
23
- VERSION = "2.3 (Final Access Fix)"
24
  MODEL_ID = "issoufzousko07/BABA-IA-2B"
25
 
26
- print("="*50)
27
  print(f"🚀 Démarrage BABA API v{VERSION}")
28
  print(f"Chargement de {MODEL_ID}...")
29
- print("="*50)
30
 
31
  # Détection automatique
32
  if torch.cuda.is_available():
@@ -39,8 +36,7 @@ if torch.cuda.is_available():
39
  )
40
  else:
41
  device = "cpu"
42
- print("🐢 Mode CPU activé (Gemma 2B tient dans la RAM)")
43
- # Sur CPU, on évite device_map="auto" pour éviter les bugs d'offloading accelerate
44
  model = AutoModelForCausalLM.from_pretrained(
45
  MODEL_ID,
46
  torch_dtype=torch.float32
@@ -55,45 +51,41 @@ class ChatRequest(BaseModel):
55
 
56
  @app.post("/chat")
57
  async def chat(request: ChatRequest):
58
- try:
59
- print(f"📩 Reçu : {request.message}")
60
-
61
- # 1. MESSAGE + TEMPLATE
62
- messages = [{"role": "user", "content": request.message}]
63
-
64
- # 1. TEMPLATE -> TEXTE (Plus sûr)
65
- # On récupère le prompt complet sous forme de string
66
- text_prompt = tokenizer.apply_chat_template(
67
- messages,
68
- tokenize=False,
69
- add_generation_prompt=True
70
- )
71
-
72
- # 2. TEXTE -> TENSEURS
73
- # On tokenise explicitement pour avoir input_ids ET attention_mask
74
- inputs = tokenizer(text_prompt, return_tensors="pt").to(model.device)
75
 
76
- # 3. GÉNÉRATION
77
- # On passe **inputs pour envoyer input_ids + attention_mask correctement
78
- outputs = model.generate(
79
- **inputs,
80
- max_new_tokens=300,
81
- do_sample=True,
82
- temperature=0.7,
83
- top_p=0.9,
84
- )
 
 
85
 
86
- # 4. DÉCODAGE
87
- # inputs.input_ids.shape[-1] donne la longueur du prompt
88
- response = tokenizer.decode(outputs[0][inputs.input_ids.shape[-1]:], skip_special_tokens=True)
89
- print(f"📤 Réponse : {response.strip()[:50]}...")
90
-
91
- return {"response": response.strip()}
 
 
 
 
 
92
 
93
- except Exception as e:
94
- print("❌ ERREUR CRITIQUE :")
95
- traceback.print_exc() # Affiche toute l'erreur dans les logs
96
- return {"response": f"Erreur technique : {str(e)}"}
97
 
98
  if __name__ == "__main__":
99
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ from fastapi import FastAPI
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import StreamingResponse
4
  from pydantic import BaseModel
5
  import torch
6
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
7
  import uvicorn
8
+ from threading import Thread
9
+ import json
 
 
 
10
 
11
  app = FastAPI()
12
 
13
  app.add_middleware(
14
  CORSMiddleware,
15
  allow_origins=["*"],
16
+ allow_credentials=False,
17
  allow_methods=["*"],
18
  allow_headers=["*"],
19
  )
20
 
21
  # --- CONFIGURATION ---
22
+ VERSION = "3.0 (Streaming)"
23
  MODEL_ID = "issoufzousko07/BABA-IA-2B"
24
 
 
25
  print(f"🚀 Démarrage BABA API v{VERSION}")
26
  print(f"Chargement de {MODEL_ID}...")
 
27
 
28
  # Détection automatique
29
  if torch.cuda.is_available():
 
36
  )
37
  else:
38
  device = "cpu"
39
+ print("🐢 Mode CPU activé (Streaming activé pour compenser la lenteur)")
 
40
  model = AutoModelForCausalLM.from_pretrained(
41
  MODEL_ID,
42
  torch_dtype=torch.float32
 
51
 
52
  @app.post("/chat")
53
  async def chat(request: ChatRequest):
54
+ print(f"📩 Reçu (Stream) : {request.message}")
55
+
56
+ # 1. MESSAGE + TEMPLATE
57
+ messages = [{"role": "user", "content": request.message}]
58
+ text_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
59
+ inputs = tokenizer(text_prompt, return_tensors="pt").to(model.device)
 
 
 
 
 
 
 
 
 
 
 
60
 
61
+ # 2. CONFIG STREAMER
62
+ streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
63
+
64
+ generation_kwargs = dict(
65
+ **inputs,
66
+ streamer=streamer,
67
+ max_new_tokens=300,
68
+ do_sample=True,
69
+ temperature=0.7,
70
+ top_p=0.9,
71
+ )
72
 
73
+ # 3. GÉNÉRATION DANS UN THREAD (Non-bloquant)
74
+ thread = Thread(target=model.generate, kwargs=generation_kwargs)
75
+ thread.start()
76
+
77
+ # 4. GÉNÉRATEUR DE RÉPONSE SSE
78
+ def stream_response():
79
+ full_text = ""
80
+ for token in streamer:
81
+ full_text += token
82
+ # Format SSE: data: {"token": "..."}\n\n
83
+ yield f"data: {json.dumps({'token': token, 'text': full_text})}\n\n"
84
 
85
+ # Signal de fin
86
+ yield "data: [DONE]\n\n"
87
+
88
+ return StreamingResponse(stream_response(), media_type="text/event-stream")
89
 
90
  if __name__ == "__main__":
91
  uvicorn.run(app, host="0.0.0.0", port=7860)