# HuggingFace Spaces listing header (scrape artifact) — Space status: Sleeping
import logging
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from settings import OPEN_LLM_CANDIDATES, LOCAL_MAX_NEW_TOKENS
class LocalLLM:
    """Thin wrapper around a locally hosted causal-LM text-generation pipeline.

    Tries each model id in ``OPEN_LLM_CANDIDATES`` in order and keeps the
    first one that loads successfully. If none load, ``self.pipe`` stays
    ``None`` and :meth:`chat` returns ``None``.
    """

    def __init__(self) -> None:
        # Populated by _load_any(); remains None when no candidate loads.
        self.pipe = None
        self._load_any()

    def _load_any(self) -> None:
        """Load the first candidate model that initializes without error.

        Loading is best-effort by design: a failure for one candidate is
        logged (not silently swallowed) and the next candidate is tried.
        No exception escapes this method.
        """
        log = logging.getLogger(__name__)
        for mid in OPEN_LLM_CANDIDATES:
            try:
                tok = AutoTokenizer.from_pretrained(mid, trust_remote_code=True)
                mdl = AutoModelForCausalLM.from_pretrained(
                    mid,
                    device_map="auto",
                    # fp16 only when a GPU is available; fp32 keeps CPU inference stable.
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    trust_remote_code=True,
                )
                self.pipe = pipeline("text-generation", model=mdl, tokenizer=tok)
                return
            except Exception as exc:
                # Best-effort fallback: record why this candidate failed,
                # then move on to the next one.
                log.warning("Failed to load model %s: %s", mid, exc)

    def chat(self, prompt: str) -> Optional[str]:
        """Generate a completion for *prompt*.

        Returns ``None`` when no model could be loaded. The pipeline echoes
        the prompt inside ``generated_text``, so the echoed prefix is
        stripped and only the new continuation is returned.
        """
        if not self.pipe:
            return None
        out = self.pipe(
            prompt,
            max_new_tokens=LOCAL_MAX_NEW_TOKENS,
            do_sample=True,
            temperature=0.3,
            top_p=0.9,
            repetition_penalty=1.12,
            eos_token_id=self.pipe.tokenizer.eos_token_id,
            # Avoid the "pad_token_id not set" warning for tokenizers
            # that define no pad token; harmless for unbatched generation.
            pad_token_id=self.pipe.tokenizer.eos_token_id,
        )
        text = out[0]["generated_text"]
        # Strip the echoed prompt when present; otherwise return the full text.
        return text[len(prompt):].strip() if text.startswith(prompt) else text.strip()