import torch
from transformers import AutoTokenizer, AutoModelForCausalLM


class SalamandraClient:
    def __init__(self, model_id="BSC-LT/salamandra-7b-instruct"):
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",
            torch_dtype=torch.bfloat16,
        )
        # Some instruct checkpoints ship without a pad token; fall back to EOS
        # so that padding=True in the tokenizer call below works as expected.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def chat(self, prompt: str) -> str:
        encodings = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
        )
        inputs = encodings["input_ids"].to(self.model.device)
        attention_mask = encodings["attention_mask"].to(self.model.device)

        outputs = self.model.generate(
            input_ids=inputs,
            attention_mask=attention_mask,
            pad_token_id=self.tokenizer.pad_token_id,
            max_new_tokens=300,  # increase for longer outputs
            do_sample=True,      # required for temperature/top_k/top_p to take effect
            temperature=0.01,    # creativity control; near 0 is close to greedy decoding
            top_k=50,            # sample only from the most likely tokens
            top_p=0.9,           # nucleus sampling threshold
        )

        # Decode only the newly generated tokens, skipping the prompt.
        generated_tokens = outputs[0][inputs.shape[1]:]
        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)
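
# A minimal usage sketch, not part of the original snippet: the entry-point
# block and the example prompt below are assumptions for illustration.
# Instantiate the client once and reuse it across prompts, since loading the
# model weights is the expensive step.
if __name__ == "__main__":
    client = SalamandraClient()
    answer = client.chat("Briefly explain what the Salamandra model is.")
    print(answer)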