File size: 2,099 Bytes
3567e13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
from typing import Optional

import requests
from langchain.llms.base import LLM

from util.token_access import load_token

# Hugging Face API token, read once at import time via the project helper.
token = load_token()

# Default Hugging Face model repo id; GuiChat.model falls back to this value.
model = "meta-llama/Llama-3.2-3B-Instruct"

class GuiChat(LLM):
    """LangChain LLM wrapper for the hosted Hugging Face Inference API.

    Authenticates with a bearer token (``auth_token``) and posts generation
    requests to the model named in ``self.model``.
    """

    # NOTE(review): chatbot/conversation are never read or written in this
    # file — presumably placeholders for a stateful session; confirm usage.
    chatbot: Optional[object] = None
    auth_token: Optional[str] = None  # Hugging Face API token (Bearer auth)
    conversation: Optional[str] = None
    model: Optional[str] = model  # defaults to the module-level model id

    # Generation parameters forwarded verbatim in the request payload.
    temperature: Optional[float] = 0.9
    top_p: Optional[float] = 0.5
    repetition_penalty: Optional[float] = 1.2
    top_k: Optional[int] = 20
    truncate: Optional[int] = 512
    max_new_tokens: Optional[int] = 512
    stream_resp: Optional[bool] = True  # not used by _call (no streaming here)
    log: Optional[bool] = True
    avg_response_time: float = 0.0  # never updated in this file

    @property
    def _llm_type(self) -> str:
        """LLM type identifier reported to LangChain.

        Fixed: LangChain's ``LLM`` base declares ``_llm_type`` as a property;
        as a plain method, framework code reading ``self._llm_type`` would get
        a bound method instead of the string ``"huggingface"``.
        """
        return "huggingface"

    def _call(self, prompt: str, stop: Optional[list] = None, **kwargs) -> str:
        """Send *prompt* to the Hugging Face Inference API and return the text.

        Args:
            prompt: Text to send to the model.
            stop: Stop sequences. Accepted for LangChain compatibility (the
                framework always passes it); the payload below does not
                forward them to the API.
            **kwargs: Extra LangChain runtime arguments (e.g. ``run_manager``),
                accepted and ignored so ``llm.invoke(...)`` does not raise
                ``TypeError``.

        Returns:
            The generated text on HTTP 200; otherwise an error string with
            the status code and response body (original best-effort behavior
            intentionally kept — callers do not expect an exception here).
        """
        headers = {
            "Authorization": f"Bearer {self.auth_token}",
            "Content-Type": "application/json",
        }

        endpoint = f"https://api-inference.huggingface.co/models/{self.model}"

        payload = {
            "inputs": prompt,
            "parameters": {
                "temperature": self.temperature,
                "max_new_tokens": self.max_new_tokens,
                "top_p": self.top_p,
                "top_k": self.top_k,
                "repetition_penalty": self.repetition_penalty,
                "truncate": self.truncate,
            },
        }

        # timeout added so a stalled connection cannot hang the caller forever
        response = requests.post(
            endpoint, headers=headers, json=payload, timeout=60
        )

        if response.status_code == 200:
            return response.json()[0]["generated_text"]
        else:
            return f"Erro: {response.status_code}, {response.text}"

    def get_avg_response_time(self) -> float:
        """Return the average response time (always 0.0 — never updated here)."""
        return self.avg_response_time

# Module-level singleton client, authenticated with the token loaded above.
chatbot = GuiChat(auth_token=token)


#TEST-BOT
# Disabled manual REPL for smoke-testing the bot from a terminal
# (kept as a string literal so importing this module has no side effects).
"""
while True:
    ask = input("Digite aqui: ")
    resposta = chatbot._call(ask)
    print(f">>> {resposta}")
"""