1dm commited on
Commit
018c8f2
·
verified ·
1 Parent(s): 0fc7fa2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -27
app.py CHANGED
@@ -1,33 +1,58 @@
1
- # Fichier: app.py (VERSION CORRIGÉE FINALE)
2
  from fastapi import FastAPI
3
  from pydantic import BaseModel
4
- from transformers import AutoModelForCausalLM, AutoTokenizer
5
  import torch
 
6
 
7
  # --- Configuration du Modèle ---
8
- # Nouveau Modèle : Phi-3 Mini (3.8B) - Optimisé pour les systèmes légers/CPU
9
  model_id = "microsoft/Phi-3-mini-4k-instruct"
10
-
11
- # Détecter le périphérique : forcer le CPU
12
  device = torch.device("cpu")
13
 
 
 
 
 
 
 
 
 
 
 
14
  # Charger le Tokenizer et le Modèle
15
  try:
16
  tokenizer = AutoTokenizer.from_pretrained(model_id)
17
 
18
- # CORRECTION CRITIQUE 1: Stabilité du chargement sur CPU
19
- # 1. Ajout de torch_dtype=torch.float32 pour assurer la compatibilité CPU.
20
- # 2. Suppression de device_map=device (le .to(device) final est suffisant).
21
  model = AutoModelForCausalLM.from_pretrained(
22
  model_id,
23
- torch_dtype=torch.float32,
24
- trust_remote_code=True # Nécessaire pour Phi-3
25
- ).to(device)
26
- model.eval()
27
- print(f"Modèle {model_id} chargé sur CPU.")
28
- except Exception as e:
29
- print(f"Erreur lors du chargement du modèle : {e}")
30
- raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  app = FastAPI(
33
  title="NLP Space - Phi-3 Mini API (CPU)",
@@ -48,12 +73,12 @@ def generate_text_from_model(system_prompt: str, user_prompt: str, max_tokens: i
48
  {"role": "user", "content": user_prompt}
49
  ]
50
 
51
- # Appliquer le template de chat du tokenizer
52
  text_to_generate = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
53
 
54
- # Exemple de format: "<|user|>\nInstruction\n<|end|>\n<|assistant|>"
55
-
56
- inputs = tokenizer(text_to_generate, return_tensors="pt").to(device)
 
57
 
58
  with torch.no_grad():
59
  output = model.generate(
@@ -62,13 +87,12 @@ def generate_text_from_model(system_prompt: str, user_prompt: str, max_tokens: i
62
  do_sample=True,
63
  temperature=temperature,
64
  pad_token_id=tokenizer.eos_token_id,
65
- # CORRECTION CRITIQUE 2: Désactiver le cache pour contourner le bug 'DynamicCache'
66
- use_cache=False
67
  )
68
 
69
  generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
70
 
71
- # Nettoyage : retirer l'instruction initiale pour ne garder que la réponse
72
  response_start_tag = "<|assistant|>"
73
  if response_start_tag in generated_text:
74
  return generated_text.split(response_start_tag, 1)[1].strip()
@@ -76,7 +100,7 @@ def generate_text_from_model(system_prompt: str, user_prompt: str, max_tokens: i
76
  return generated_text.strip()
77
 
78
 
79
- # --- Endpoints (Identiques) ---
80
 
81
  @app.post("/generate")
82
  async def generate(request: PromptRequest):
@@ -91,12 +115,11 @@ async def generate(request: PromptRequest):
91
  )
92
  return {"result": result}
93
  except Exception as e:
94
- # Retourne l'erreur Python pour le diagnostic (comme vous l'avez fait)
95
  return {"error": str(e)}
96
 
97
  @app.post("/summarize")
98
  async def summarize(request: PromptRequest):
99
- """Résumé d'un long texte."""
100
  system_prompt = "Tu es un expert en résumé concis et précis. Ton objectif est de résumer le texte fourni de manière à en conserver l'idée principale."
101
  user_prompt = f"Résume le texte suivant de manière concise et factuelle:\n\n---\n\n{request.prompt}"
102
  try:
@@ -112,7 +135,7 @@ async def summarize(request: PromptRequest):
112
 
113
  @app.post("/classify")
114
  async def classify(request: PromptRequest):
115
- """Classification du sentiment, du thème ou de la catégorie d'un texte."""
116
  system_prompt = "Tu es un expert en classification. Réponds uniquement avec l'étiquette de classification demandée sans phrases supplémentaires."
117
  user_prompt = request.prompt
118
  try:
 
1
+ # Fichier: app.py (VERSION CORRIGÉE FINALE - OPTIMISÉE POUR LA MÉMOIRE DU SPACE)
2
  from fastapi import FastAPI
3
  from pydantic import BaseModel
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
5
  import torch
6
+ import os
7
 
8
  # --- Configuration du Modèle ---
 
9
  model_id = "microsoft/Phi-3-mini-4k-instruct"
 
 
10
  device = torch.device("cpu")
11
 
12
+ # --- Stratégie de chargement pour économiser la mémoire (Quantisation) ---
13
+ # Si le Space a un GPU/CUDA, la quantisation sera utilisée, réduisant la RAM par 8.
14
+ # Si le Space est CPU seulement, cette tentative échouera, et nous utiliserons le fallback float32.
15
+ quantization_config = BitsAndBytesConfig(
16
+ load_in_4bit=True,
17
+ bnb_4bit_quant_type="nf4",
18
+ bnb_4bit_compute_dtype=torch.bfloat16,
19
+ bnb_4bit_use_double_quant=True,
20
+ )
21
+
22
  # Charger le Tokenizer et le Modèle
23
  try:
24
  tokenizer = AutoTokenizer.from_pretrained(model_id)
25
 
26
+ # TENTATIVE 1 : Chargement avec Quantisation 4-bit (Méthode recommandée)
27
+ print("Tentative de chargement avec quantisation 4-bit...")
28
+ # Le chargement en 4-bit nécessite device_map="auto"
29
  model = AutoModelForCausalLM.from_pretrained(
30
  model_id,
31
+ quantization_config=quantization_config,
32
+ device_map="auto",
33
+ trust_remote_code=True
34
+ )
35
+ print(f"Modèle {model_id} chargé et quantifié.")
36
+
37
+ except Exception as e_quant:
38
+ # Si la quantisation échoue (souvent sans GPU), on revient à la version CPU
39
+ print(f"Échec de la quantisation : {e_quant}. Tentative de chargement float32 CPU (Attention: peut causer OOM).")
40
+
41
+ # TENTATIVE 2 : Fallback sur le chargement float32 CPU (Votre code initial, mais avec fix du bug)
42
+ try:
43
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
44
+ model = AutoModelForCausalLM.from_pretrained(
45
+ model_id,
46
+ torch_dtype=torch.float32,
47
+ trust_remote_code=True
48
+ ).to(device)
49
+ print(f"Modèle {model_id} chargé sur CPU (Float32).")
50
+ except Exception as e_cpu:
51
+ print(f"Échec critique du chargement CPU : {e_cpu}")
52
+ # Si même float32 échoue, vous avez BESOIN de plus de RAM pour votre Space.
53
+ raise e_cpu
54
+
55
+ model.eval()
56
 
57
  app = FastAPI(
58
  title="NLP Space - Phi-3 Mini API (CPU)",
 
73
  {"role": "user", "content": user_prompt}
74
  ]
75
 
 
76
  text_to_generate = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=False)
77
 
78
+ # Trouver le device réel du modèle pour y placer les inputs (nécessaire après le chargement device_map)
79
+ # Assurez-vous que le modèle est correctement placé, en le forçant sur CPU si nécessaire.
80
+ real_device = model.device if model.device.type != 'meta' else torch.device("cpu")
81
+ inputs = tokenizer(text_to_generate, return_tensors="pt").to(real_device)
82
 
83
  with torch.no_grad():
84
  output = model.generate(
 
87
  do_sample=True,
88
  temperature=temperature,
89
  pad_token_id=tokenizer.eos_token_id,
90
+ use_cache=False # CORRECTION CRITIQUE 2: Fixe le bug DynamicCache
 
91
  )
92
 
93
  generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
94
 
95
+ # Nettoyage
96
  response_start_tag = "<|assistant|>"
97
  if response_start_tag in generated_text:
98
  return generated_text.split(response_start_tag, 1)[1].strip()
 
100
  return generated_text.strip()
101
 
102
 
103
+ # --- Endpoints (Inchangés) ---
104
 
105
  @app.post("/generate")
106
  async def generate(request: PromptRequest):
 
115
  )
116
  return {"result": result}
117
  except Exception as e:
 
118
  return {"error": str(e)}
119
 
120
  @app.post("/summarize")
121
  async def summarize(request: PromptRequest):
122
+ # ... (code inchangé) ...
123
  system_prompt = "Tu es un expert en résumé concis et précis. Ton objectif est de résumer le texte fourni de manière à en conserver l'idée principale."
124
  user_prompt = f"Résume le texte suivant de manière concise et factuelle:\n\n---\n\n{request.prompt}"
125
  try:
 
135
 
136
  @app.post("/classify")
137
  async def classify(request: PromptRequest):
138
+ # ... (code inchangé) ...
139
  system_prompt = "Tu es un expert en classification. Réponds uniquement avec l'étiquette de classification demandée sans phrases supplémentaires."
140
  user_prompt = request.prompt
141
  try: