# Medica_DecisionSupportAI / local_llm.py
from typing import Optional

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from settings import OPEN_LLM_CANDIDATES, LOCAL_MAX_NEW_TOKENS
class LocalLLM:
    """Text-generation wrapper around the first loadable model in OPEN_LLM_CANDIDATES."""

    def __init__(self):
        self.pipe = None
        self.model_id = None
        self._load_any()

    def _load_any(self):
        # Try each candidate model in order and keep the first one that loads.
        for mid in OPEN_LLM_CANDIDATES:
            try:
                tok = AutoTokenizer.from_pretrained(mid, trust_remote_code=True)
                mdl = AutoModelForCausalLM.from_pretrained(
                    mid,
                    device_map="auto",
                    # Half precision on GPU to save memory; full precision on CPU.
                    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
                    trust_remote_code=True,
                )
                self.pipe = pipeline("text-generation", model=mdl, tokenizer=tok)
                self.model_id = mid
                return
            except Exception:
                # Loading failed (missing weights, OOM, gated repo, ...); try the next candidate.
                continue
        self.pipe = None  # No candidate loaded; chat() will return None.
    def chat(self, prompt: str) -> Optional[str]:
        if not self.pipe:
            return None
        try:
            out = self.pipe(
                prompt,
                max_new_tokens=LOCAL_MAX_NEW_TOKENS,
                do_sample=True,
                temperature=0.3,
                top_p=0.9,
                repetition_penalty=1.12,
                eos_token_id=self.pipe.tokenizer.eos_token_id,
            )
            text = out[0]["generated_text"]
            # The pipeline echoes the prompt by default; return only the continuation.
            return text[len(prompt):].strip() if text.startswith(prompt) else text.strip()
        except Exception:
            return None
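

# --- Example usage ------------------------------------------------------------
# A minimal sketch of the settings this module imports and of how the class is
# used. The model IDs and token limit below are illustrative assumptions, not
# the repository's actual settings.py, which is not shown here:
#
#     OPEN_LLM_CANDIDATES = [
#         "Qwen/Qwen2.5-1.5B-Instruct",  # hypothetical first choice
#         "microsoft/phi-2",             # hypothetical fallback
#     ]
#     LOCAL_MAX_NEW_TOKENS = 256
if __name__ == "__main__":
    llm = LocalLLM()
    if llm.model_id:
        print(f"Loaded model: {llm.model_id}")
        print(llm.chat("List three red-flag symptoms that warrant urgent referral:"))
    else:
        print("No candidate model could be loaded locally.")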