stools / salamandra_tools.py
VeuReu's picture
Rename moe_tools.py to salamandra_tools.py
f09fbf3 verified
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
class SalamandraClient:
    """Thin wrapper around a causal language model loaded with `transformers`.

    Loads the tokenizer and model once at construction time and exposes
    `chat` to generate a completion for a single text prompt.
    """

    def __init__(self, model_id: str = "BSC-LT/salamandra-7b-instruct"):
        """Load the tokenizer and model identified by *model_id*.

        Args:
            model_id: Identifier passed to `from_pretrained` for both the
                tokenizer and the model.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # `chat` tokenizes with padding=True and forwards pad_token_id to
        # generate(); make sure a pad token exists for tokenizers that ship
        # without one by reusing the end-of-sequence token.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )

    def chat(
        self,
        prompt: str,
        *,
        max_new_tokens: int = 300,
        temperature: float = 0.01,
        top_k: int = 50,
        top_p: float = 0.9,
    ) -> str:
        """Generate a completion for *prompt* and return only the new text.

        The prompt is tokenized as-is; no chat template is applied here.
        NOTE(review): if the model expects a chat-formatted prompt, the caller
        is presumably responsible for building it -- confirm against callers.

        Args:
            prompt: A single text prompt. Only the first sequence of the
                generation output is decoded and returned.
            max_new_tokens: Upper bound on generated tokens; raise it for
                longer answers.
            temperature: Sampling temperature (creativity control). A value
                of ``None`` or ``<= 0`` disables sampling and decodes greedily.
            top_k: Sample only among the k most probable tokens.
            top_p: Nucleus-sampling probability mass.

        Returns:
            The decoded continuation, without the prompt and without special
            tokens.
        """
        device = self.model.device
        encodings = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
        )
        input_ids = encodings["input_ids"].to(device)
        attention_mask = encodings["attention_mask"].to(device)

        generation_kwargs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "pad_token_id": self.tokenizer.pad_token_id,
            "max_new_tokens": max_new_tokens,
        }
        # temperature / top_k / top_p only take effect when sampling is
        # enabled; without do_sample=True they are ignored by generate().
        # Pass them only when sampling so greedy decoding stays warning-free.
        do_sample = temperature is not None and temperature > 0
        generation_kwargs["do_sample"] = do_sample
        if do_sample:
            generation_kwargs["temperature"] = temperature
            generation_kwargs["top_k"] = top_k
            generation_kwargs["top_p"] = top_p

        outputs = self.model.generate(**generation_kwargs)

        # generate() returns prompt + continuation; drop the prompt tokens.
        prompt_length = input_ids.shape[1]
        generated_tokens = outputs[0][prompt_length:]
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)