import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
class SalamandraClient:
    """Thin client around a Hugging Face causal LM for single-turn generation.

    Loads the tokenizer and model once at construction; `chat()` runs one
    prompt through `generate()` and returns only the newly generated text.
    """

    def __init__(self, model_id: str = "BSC-LT/salamandra-7b-instruct"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Many causal-LM tokenizers ship without a pad token; `padding=True`
        # in chat() would then fail. Falling back to EOS is the standard fix.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",          # let accelerate place layers on available devices
            torch_dtype=torch.bfloat16,
        )

    def chat(self, prompt: str, max_new_tokens: int = 300) -> str:
        """Generate a completion for *prompt*.

        Args:
            prompt: Raw text fed to the model.
            max_new_tokens: Cap on tokens generated beyond the prompt.

        Returns:
            The generated continuation only (prompt echo stripped), with
            special tokens removed.
        """
        encodings = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
        )
        input_ids = encodings["input_ids"].to(self.model.device)
        attention_mask = encodings["attention_mask"].to(self.model.device)
        outputs = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pad_token_id=self.tokenizer.pad_token_id,
            max_new_tokens=max_new_tokens,
            # BUG FIX: temperature/top_k/top_p only take effect with
            # do_sample=True; the default (False) silently ignores them
            # and decodes greedily, emitting a transformers warning.
            do_sample=True,
            temperature=0.01,
            top_k=50,
            top_p=0.9,
        )
        # Slice off the echoed prompt tokens; keep only the continuation.
        generated_tokens = outputs[0][input_ids.shape[1]:]
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)