1dm commited on
Commit
0fc7fa2
·
verified ·
1 Parent(s): add108b

resolve bug

Browse files
Files changed (1) hide show
  1. app.py +11 -7
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # Fichier: app.py
2
  from fastapi import FastAPI
3
  from pydantic import BaseModel
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -14,10 +14,13 @@ device = torch.device("cpu")
14
  # Charger le Tokenizer et le Modèle
15
  try:
16
  tokenizer = AutoTokenizer.from_pretrained(model_id)
17
- # Chargement spécifique pour CPU, sans besoin de types de données GPU (float16)
 
 
 
18
  model = AutoModelForCausalLM.from_pretrained(
19
  model_id,
20
- device_map=device,
21
  trust_remote_code=True # Nécessaire pour Phi-3
22
  ).to(device)
23
  model.eval()
@@ -48,7 +51,6 @@ def generate_text_from_model(system_prompt: str, user_prompt: str, max_tokens: i
48
  # Appliquer le template de chat du tokenizer
49
  text_to_generate = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
50
 
51
- # Remplacer le template de chat par l'instruction de base pour l'inférence
52
  # Exemple de format: "<|user|>\nInstruction\n<|end|>\n<|assistant|>"
53
 
54
  inputs = tokenizer(text_to_generate, return_tensors="pt").to(device)
@@ -59,13 +61,14 @@ def generate_text_from_model(system_prompt: str, user_prompt: str, max_tokens: i
59
  max_new_tokens=max_tokens,
60
  do_sample=True,
61
  temperature=temperature,
62
- pad_token_id=tokenizer.eos_token_id
 
 
63
  )
64
 
65
  generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
66
 
67
  # Nettoyage : retirer l'instruction initiale pour ne garder que la réponse
68
- # Le nettoyage doit être adapté au format de sortie de Phi-3
69
  response_start_tag = "<|assistant|>"
70
  if response_start_tag in generated_text:
71
  return generated_text.split(response_start_tag, 1)[1].strip()
@@ -73,7 +76,7 @@ def generate_text_from_model(system_prompt: str, user_prompt: str, max_tokens: i
73
  return generated_text.strip()
74
 
75
 
76
- # --- Endpoints (Identiques au plan initial) ---
77
 
78
  @app.post("/generate")
79
  async def generate(request: PromptRequest):
@@ -88,6 +91,7 @@ async def generate(request: PromptRequest):
88
  )
89
  return {"result": result}
90
  except Exception as e:
 
91
  return {"error": str(e)}
92
 
93
  @app.post("/summarize")
 
1
+ # Fichier: app.py (VERSION CORRIGÉE FINALE)
2
  from fastapi import FastAPI
3
  from pydantic import BaseModel
4
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
14
  # Charger le Tokenizer et le Modèle
15
  try:
16
  tokenizer = AutoTokenizer.from_pretrained(model_id)
17
+
18
+ # CORRECTION CRITIQUE 1: Stabilité du chargement sur CPU
19
+ # 1. Ajout de torch_dtype=torch.float32 pour assurer la compatibilité CPU.
20
+ # 2. Suppression de device_map=device (le .to(device) final est suffisant).
21
  model = AutoModelForCausalLM.from_pretrained(
22
  model_id,
23
+ torch_dtype=torch.float32,
24
  trust_remote_code=True # Nécessaire pour Phi-3
25
  ).to(device)
26
  model.eval()
 
51
  # Appliquer le template de chat du tokenizer
52
  text_to_generate = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
53
 
 
54
  # Exemple de format: "<|user|>\nInstruction\n<|end|>\n<|assistant|>"
55
 
56
  inputs = tokenizer(text_to_generate, return_tensors="pt").to(device)
 
61
  max_new_tokens=max_tokens,
62
  do_sample=True,
63
  temperature=temperature,
64
+ pad_token_id=tokenizer.eos_token_id,
65
+ # CORRECTION CRITIQUE 2: Désactiver le cache pour contourner le bug 'DynamicCache'
66
+ use_cache=False
67
  )
68
 
69
  generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
70
 
71
  # Nettoyage : retirer l'instruction initiale pour ne garder que la réponse
 
72
  response_start_tag = "<|assistant|>"
73
  if response_start_tag in generated_text:
74
  return generated_text.split(response_start_tag, 1)[1].strip()
 
76
  return generated_text.strip()
77
 
78
 
79
+ # --- Endpoints (Identiques) ---
80
 
81
  @app.post("/generate")
82
  async def generate(request: PromptRequest):
 
91
  )
92
  return {"result": result}
93
  except Exception as e:
94
+ # Retourne l'erreur Python pour le diagnostic (comme vous l'avez fait)
95
  return {"error": str(e)}
96
 
97
  @app.post("/summarize")