|
|
import torch |
|
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
|
|
class SalamandraClient:
    """Thin wrapper around the BSC Salamandra instruct model for one-shot
    prompt -> completion generation on whatever device `device_map="auto"`
    selects."""

    def __init__(self, model_id="BSC-LT/salamandra-7b-instruct"):
        """Load the tokenizer and the causal-LM weights.

        Args:
            model_id: Hugging Face Hub identifier of the model to load.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Llama-style tokenizers often ship without a pad token; the
        # `padding=True` call in chat() raises ValueError without one, and
        # `pad_token_id` passed to generate() would be None. Fall back to EOS.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )

    def chat(self, prompt) -> str:
        """Generate a completion for `prompt` and return only the newly
        generated text (the echoed prompt tokens are stripped).

        Args:
            prompt: The input text to complete.

        Returns:
            The decoded continuation, with special tokens removed.
        """
        encodings = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
        )

        input_ids = encodings["input_ids"].to(self.model.device)
        attention_mask = encodings["attention_mask"].to(self.model.device)

        # do_sample=True is required for temperature/top_k/top_p to take
        # effect; without it, generate() uses greedy decoding and silently
        # ignores these parameters (emitting a warning). temperature=0.01
        # keeps output near-deterministic while honoring the sampling config.
        outputs = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pad_token_id=self.tokenizer.pad_token_id,
            max_new_tokens=300,
            do_sample=True,
            temperature=0.01,
            top_k=50,
            top_p=0.9,
        )

        # outputs[0] contains prompt + continuation; slice off the prompt.
        generated_tokens = outputs[0][input_ids.shape[1]:]
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)