Rajan Sharma committed
Commit ab8df1d · verified · 1 Parent(s): 7935397

Create local_llm.py

Files changed (1)
  local_llm.py +44 -0
local_llm.py ADDED
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from settings import OPEN_LLM_CANDIDATES, LOCAL_MAX_NEW_TOKENS


class LocalLLM:
    """Thin wrapper around a local Hugging Face text-generation pipeline.

    Tries each model ID in OPEN_LLM_CANDIDATES in order and keeps the first
    one that loads. If none load, self.pipe stays None and chat() returns
    None, so callers can fall back gracefully.
    """

    def __init__(self):
        self.pipe = None
        self.model_id = None
        self._load_any()

    def _load_any(self):
        for mid in OPEN_LLM_CANDIDATES:
            try:
                tok = AutoTokenizer.from_pretrained(mid, trust_remote_code=True)
                mdl = AutoModelForCausalLM.from_pretrained(
                    mid,
                    device_map="auto",
                    # Half precision on GPU, full precision on CPU.
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    trust_remote_code=True,
                )
                self.pipe = pipeline("text-generation", model=mdl, tokenizer=tok)
                self.model_id = mid
                return
            except Exception:
                # Model unavailable (gated, OOM, download failure, ...): try the next candidate.
                continue
        self.pipe = None

    def chat(self, prompt: str) -> Optional[str]:
        if not self.pipe:
            return None
        try:
            out = self.pipe(
                prompt,
                max_new_tokens=LOCAL_MAX_NEW_TOKENS,
                do_sample=True,
                temperature=0.3,
                top_p=0.9,
                repetition_penalty=1.12,
                eos_token_id=self.pipe.tokenizer.eos_token_id,
            )
            text = out[0]["generated_text"]
            # The pipeline echoes the prompt by default; return only the continuation.
            return text[len(prompt):].strip() if text.startswith(prompt) else text.strip()
        except Exception:
            return None
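For reference, a minimal sketch of how local_llm.py might be wired up. The settings.py values below are placeholders, not part of this commit: OPEN_LLM_CANDIDATES is assumed to be an ordered list of Hugging Face model IDs (preferred model first) and LOCAL_MAX_NEW_TOKENS a generation cap.

# settings.py -- hypothetical example values; the real settings.py is not in this commit.
OPEN_LLM_CANDIDATES = [
    "Qwen/Qwen2.5-1.5B-Instruct",          # tried first
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",  # fallback if the first fails to load
]
LOCAL_MAX_NEW_TOKENS = 256

# Hypothetical caller: LocalLLM degrades to returning None, so the caller
# decides what to do when no local model is available.
from local_llm import LocalLLM

llm = LocalLLM()
if llm.pipe is None:
    print("no local model loaded")
else:
    print(f"loaded {llm.model_id}")
    reply = llm.chat("Summarize beam search in one sentence.")
    print(reply if reply is not None else "generation failed")

Returning None for both load failure and generation failure keeps the caller's fallback logic to a single check, at the cost of swallowing the underlying exceptions; logging them inside the except blocks would be the obvious refinement.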