# moe_tools.py — chat client wrapper for the BSC-LT Salamandra instruct model.
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
class SalamandraClient:
    """Thin wrapper around the BSC-LT Salamandra instruct model for text generation."""

    def __init__(self, model_id: str = "BSC-LT/salamandra-7b-instruct"):
        """Load tokenizer and model weights.

        Args:
            model_id: Hugging Face model repository id.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Llama-style tokenizers often ship without a pad token, in which case
        # pad_token_id is None and both padding=True and generate(pad_token_id=...)
        # misbehave. Fall back to EOS as the pad token (standard workaround).
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",  # place/shard layers automatically on available devices
            torch_dtype=torch.bfloat16,
        )

    def chat(self, prompt) -> str:
        """Generate a completion for *prompt* and return only the newly generated text.

        Note: no chat template is applied here — *prompt* is tokenized as-is.
        TODO(review): consider tokenizer.apply_chat_template for instruct models.

        Args:
            prompt: Raw prompt string (or list of strings; padding is enabled).

        Returns:
            The decoded continuation, without the prompt or special tokens.
        """
        encodings = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
        )
        input_ids = encodings["input_ids"].to(self.model.device)
        attention_mask = encodings["attention_mask"].to(self.model.device)
        outputs = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pad_token_id=self.tokenizer.pad_token_id,
            max_new_tokens=300,
            # Bug fix: without do_sample=True, transformers runs greedy decoding
            # and silently ignores temperature/top_k/top_p (with warnings).
            # Enable sampling so these parameters take effect; temperature=0.01
            # keeps the output near-deterministic, as the original intended.
            do_sample=True,
            temperature=0.01,
            top_k=50,
            top_p=0.9,
        )
        # Slice off the prompt tokens; decode only the generated suffix.
        generated_tokens = outputs[0][input_ids.shape[1]:]
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)