Rajan Sharma committed
Commit 8aebe10 · verified · 1 Parent(s): f027b47

Update local_llm.py

Files changed (1):
local_llm.py (+12 -22)
local_llm.py CHANGED
@@ -1,4 +1,4 @@
-from typing import Optional, List
+from typing import Optional
 import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 from settings import OPEN_LLM_CANDIDATES, LOCAL_MAX_NEW_TOKENS
@@ -6,7 +6,6 @@ from settings import OPEN_LLM_CANDIDATES, LOCAL_MAX_NEW_TOKENS
 class LocalLLM:
     def __init__(self):
         self.pipe = None
-        self.model_id = None
         self._load_any()
 
     def _load_any(self):
@@ -14,31 +13,22 @@ class LocalLLM:
             try:
                 tok = AutoTokenizer.from_pretrained(mid, trust_remote_code=True)
                 mdl = AutoModelForCausalLM.from_pretrained(
-                    mid, device_map="auto", torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
+                    mid, device_map="auto",
+                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                     trust_remote_code=True
                 )
                 self.pipe = pipeline("text-generation", model=mdl, tokenizer=tok)
-                self.model_id = mid
                 return
             except Exception:
                 continue
-        self.pipe = None
 
     def chat(self, prompt: str) -> Optional[str]:
-        if not self.pipe:
-            return None
-        try:
-            out = self.pipe(
-                prompt,
-                max_new_tokens=LOCAL_MAX_NEW_TOKENS,
-                do_sample=True,
-                temperature=0.3,
-                top_p=0.9,
-                repetition_penalty=1.12,
-                eos_token_id=self.pipe.tokenizer.eos_token_id
-            )
-            text = out[0]["generated_text"]
-            # Return only the continuation if prompt is included
-            return text[len(prompt):].strip() if text.startswith(prompt) else text.strip()
-        except Exception:
-            return None
+        if not self.pipe: return None
+        out = self.pipe(
+            prompt, max_new_tokens=LOCAL_MAX_NEW_TOKENS,
+            do_sample=True, temperature=0.3, top_p=0.9, repetition_penalty=1.12,
+            eos_token_id=self.pipe.tokenizer.eos_token_id
+        )
+        text = out[0]["generated_text"]
+        return text[len(prompt):].strip() if text.startswith(prompt) else text.strip()
+
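
For context, a minimal usage sketch of LocalLLM as it stands after this commit. It assumes settings.py supplies OPEN_LLM_CANDIDATES (a list of Hugging Face model IDs tried in order) and LOCAL_MAX_NEW_TOKENS (an int), as the imports suggest; the prompt string is illustrative.

# Usage sketch (assumes settings.py defines OPEN_LLM_CANDIDATES and
# LOCAL_MAX_NEW_TOKENS; the prompt below is a hypothetical example).
from local_llm import LocalLLM

llm = LocalLLM()  # _load_any() tries each candidate model until one loads
reply = llm.chat("Summarize this repository in one sentence.")
if reply is None:
    # chat() returns None only when no candidate model could be loaded
    print("no local model available")
else:
    print(reply)

Note that with the try/except removed from chat(), generation errors (e.g. running out of GPU memory mid-generation) now propagate to the caller instead of surfacing as None, so callers that relied on the old swallow-everything behavior may want their own error handling.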