VeuReu commited on
Commit
77da823
·
verified ·
1 Parent(s): ba54b37

Upload moe_tools.py

Browse files
Files changed (1) hide show
  1. moe_tools.py +35 -0
moe_tools.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModelForCausalLM
3
+
4
+ class SalamandraClient:
5
+ def __init__(self, model_id="BSC-LT/salamandra-7b-instruct"):
6
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id)
7
+ self.model = AutoModelForCausalLM.from_pretrained(
8
+ model_id,
9
+ device_map="auto",
10
+ torch_dtype=torch.bfloat16
11
+ )
12
+
13
+ def chat(self, prompt) -> str:
14
+ encodings = self.tokenizer(
15
+ prompt,
16
+ return_tensors="pt",
17
+ padding=True,
18
+ )
19
+
20
+ inputs = encodings["input_ids"].to(self.model.device)
21
+ attention_mask = encodings["attention_mask"].to(self.model.device)
22
+
23
+ outputs = self.model.generate(
24
+ input_ids=inputs,
25
+ attention_mask=attention_mask,
26
+ pad_token_id=self.tokenizer.pad_token_id,
27
+ max_new_tokens=300, # más grande si el texto es largo
28
+ temperature=0.01, # control de creatividad
29
+ top_k=50, # tokens más probables
30
+ top_p=0.9
31
+ )
32
+
33
+ generated_tokens = outputs[0][inputs.shape[1]:]
34
+
35
+ return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)