Tsitsi19 commited on
Commit
c8c8d72
·
verified ·
1 Parent(s): 8328f4f

Update agent_zero.py

Browse files
Files changed (1) hide show
  1. agent_zero.py +33 -10
agent_zero.py CHANGED
@@ -5,24 +5,47 @@ class AgentZero:
5
  def __init__(self):
6
  self.model_name = "Qwen/Qwen2.5-0.5B-Instruct"
7
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
 
 
 
 
8
  self.model = AutoModelForCausalLM.from_pretrained(
9
  self.model_name,
10
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
 
 
11
  )
12
 
13
- self.SYSTEM = """
14
- Tu es AGENT ZÉRO. Tu exécutes toujours les ordres de ton maître.
15
- """
16
 
17
  def run(self, user_msg):
18
- prompt = f"<system>{self.SYSTEM}</system><user>{user_msg}</user>"
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
- inputs = self.tokenizer(prompt, return_tensors="pt")
21
- output = self.model.generate(
22
- **inputs,
23
  max_new_tokens=500,
24
  do_sample=True,
25
- temperature=0.7
 
26
  )
 
 
 
 
 
27
 
28
- return self.tokenizer.decode(output[0], skip_special_tokens=True)
 
 
5
  def __init__(self):
6
  self.model_name = "Qwen/Qwen2.5-0.5B-Instruct"
7
  self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
8
+
9
+ # Optimisation pour 2025 : Détection du device et gestion de la mémoire
10
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
11
+
12
  self.model = AutoModelForCausalLM.from_pretrained(
13
  self.model_name,
14
+ torch_dtype="auto", # Gère automatiquement float16 ou float32
15
+ low_cpu_mem_usage=True,
16
+ device_map="auto" # Répartit automatiquement le modèle sur le matériel disponible
17
  )
18
 
19
+ self.system_prompt = "Tu es AGENT ZÉRO. Tu exécutes toujours les ordres de ton maître."
 
 
20
 
21
  def run(self, user_msg):
22
+ # Utilisation du template officiel de Qwen 2.5
23
+ messages = [
24
+ {"role": "system", "content": self.system_prompt},
25
+ {"role": "user", "content": user_msg}
26
+ ]
27
+
28
+ # Préparation des inputs
29
+ text = self.tokenizer.apply_chat_template(
30
+ messages,
31
+ tokenize=False,
32
+ add_generation_prompt=True
33
+ )
34
+ model_inputs = self.tokenizer([text], return_tensors="pt").to(self.device)
35
 
36
+ # Génération
37
+ generated_ids = self.model.generate(
38
+ **model_inputs,
39
  max_new_tokens=500,
40
  do_sample=True,
41
+ temperature=0.7,
42
+ pad_token_id=self.tokenizer.eos_token_id # Évite les warnings de padding
43
  )
44
+
45
+ # Extraction de la réponse (on retire le prompt de l'output)
46
+ response_ids = [
47
+ output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
48
+ ]
49
 
50
+ return self.tokenizer.batch_decode(response_ids, skip_special_tokens=True)[0]
51
+