File size: 1,235 Bytes
4abe767
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

class SalamandraClient:
    """Thin client around the BSC-LT Salamandra instruct model for text generation."""

    def __init__(self, model_id="BSC-LT/salamandra-7b-instruct"):
        """Load tokenizer and model weights.

        Args:
            model_id: Hugging Face model identifier to load
                (default: the 7B Salamandra instruct checkpoint).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
        # Many causal-LM tokenizers ship without a pad token; fall back to EOS
        # so that padding=True in chat() cannot raise at tokenization time.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            device_map="auto",           # place layers on available device(s)
            torch_dtype=torch.bfloat16,  # halve memory vs. fp32
        )

    def chat(self, prompt, max_new_tokens=300) -> str:
        """Generate a completion for *prompt* and return only the new text.

        Args:
            prompt: Raw prompt string fed directly to the model.
            max_new_tokens: Upper bound on generated tokens
                (default 300; raise it for longer outputs).

        Returns:
            The decoded continuation with special tokens stripped;
            the prompt itself is not included.
        """
        encodings = self.tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
        )

        input_ids = encodings["input_ids"].to(self.model.device)
        attention_mask = encodings["attention_mask"].to(self.model.device)

        outputs = self.model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pad_token_id=self.tokenizer.pad_token_id,
            max_new_tokens=max_new_tokens,
            # do_sample=True is required for temperature/top_k/top_p to take
            # effect; without it generate() decodes greedily and silently
            # ignores them (and transformers emits a warning).
            do_sample=True,
            temperature=0.01,  # near-greedy: minimal randomness
            top_k=50,          # restrict to the 50 most probable tokens
            top_p=0.9,         # nucleus-sampling probability mass
        )

        # Slice off the prompt tokens so only the generated tail is decoded.
        generated_tokens = outputs[0][input_ids.shape[1]:]

        return self.tokenizer.decode(generated_tokens, skip_special_tokens=True)