# Medica_DecisionSupportAI / local_llm.py
# Author: Rajan Sharma
import logging
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from settings import OPEN_LLM_CANDIDATES, LOCAL_MAX_NEW_TOKENS
class LocalLLM:
def __init__(self):
self.pipe = None
self._load_any()
def _load_any(self):
for mid in OPEN_LLM_CANDIDATES:
try:
tok = AutoTokenizer.from_pretrained(mid, trust_remote_code=True)
mdl = AutoModelForCausalLM.from_pretrained(
mid, device_map="auto",
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
trust_remote_code=True
)
self.pipe = pipeline("text-generation", model=mdl, tokenizer=tok)
return
except Exception:
continue
def chat(self, prompt: str) -> Optional[str]:
if not self.pipe: return None
out = self.pipe(
prompt, max_new_tokens=LOCAL_MAX_NEW_TOKENS,
do_sample=True, temperature=0.3, top_p=0.9, repetition_penalty=1.12,
eos_token_id=self.pipe.tokenizer.eos_token_id
)
text = out[0]["generated_text"]
return text[len(prompt):].strip() if text.startswith(prompt) else text.strip()