File size: 1,298 Bytes
8aebe10
ab8df1d
 
 
 
 
 
 
 
 
 
 
 
 
 
8aebe10
 
ab8df1d
 
 
 
 
 
 
 
8aebe10
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import logging
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from settings import OPEN_LLM_CANDIDATES, LOCAL_MAX_NEW_TOKENS

class LocalLLM:
    """Best-effort wrapper around the first loadable open-weights LLM.

    Tries each model id in ``OPEN_LLM_CANDIDATES`` in order at construction
    time. If none can be loaded, the wrapper stays usable: ``chat`` simply
    returns ``None`` instead of raising.
    """

    # Module-scoped logger bound once at class-definition time.
    _logger = logging.getLogger(__name__)

    def __init__(self):
        # Text-generation pipeline, or None when no candidate loaded.
        self.pipe = None
        self._load_any()

    def _load_any(self):
        """Load the first candidate model that initializes successfully.

        Failures are logged (with traceback) rather than raised, so the
        wrapper degrades gracefully instead of silently discarding the
        reason every candidate was skipped.
        """
        for mid in OPEN_LLM_CANDIDATES:
            try:
                tok = AutoTokenizer.from_pretrained(mid, trust_remote_code=True)
                mdl = AutoModelForCausalLM.from_pretrained(
                    mid,
                    device_map="auto",
                    # fp16 only pays off on GPU; CPU inference stays in fp32.
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    trust_remote_code=True,
                )
                self.pipe = pipeline("text-generation", model=mdl, tokenizer=tok)
                return
            except Exception:
                # Best-effort fallback: record why this candidate failed,
                # then try the next one instead of swallowing the traceback.
                self._logger.exception("Failed to load model candidate %s", mid)
        self._logger.warning(
            "No local LLM candidate could be loaded; chat() will return None."
        )

    def chat(self, prompt: str) -> Optional[str]:
        """Generate a completion for *prompt*.

        Returns the generated continuation with the echoed prompt prefix
        stripped (HF text-generation pipelines return prompt + completion
        by default), or ``None`` when no model was loaded.
        """
        if not self.pipe:
            return None
        out = self.pipe(
            prompt,
            max_new_tokens=LOCAL_MAX_NEW_TOKENS,
            do_sample=True,
            temperature=0.3,
            top_p=0.9,
            repetition_penalty=1.12,
            eos_token_id=self.pipe.tokenizer.eos_token_id,
        )
        text = out[0]["generated_text"]
        # Keep only the newly generated text when the pipeline echoes the prompt.
        return text[len(prompt):].strip() if text.startswith(prompt) else text.strip()