import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from config import PRIMARY_LLM, FALLBACK_LLM


def load_model():
    """Load the primary model; fall back to the secondary model if it fails."""
    try:
        tokenizer = AutoTokenizer.from_pretrained(PRIMARY_LLM)
        # Some tokenizers ship without a pad token; reuse EOS so padding works.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            PRIMARY_LLM,
            device_map="auto",   # requires the accelerate package
            load_in_8bit=True,   # 8-bit weights; requires the bitsandbytes package
        )
        print(f"Loaded primary model: {PRIMARY_LLM}")
    except Exception as e:
        print(f"Primary model failed: {e}")
        print(f"Loading fallback: {FALLBACK_LLM}")
        tokenizer = AutoTokenizer.from_pretrained(FALLBACK_LLM, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            FALLBACK_LLM,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            device_map="auto",
        )
    return tokenizer, model


# Load once at import time so generate() reuses the same tokenizer and model.
tokenizer, model = load_model()


def generate(prompt, max_tokens=400):
    """Generate a completion for `prompt`, returning only the newly generated text."""
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id,  # avoids the missing-pad-token warning
        )
    # Decode only the tokens produced after the prompt; string-replacing the prompt
    # in the decoded output can fail when tokenization does not round-trip exactly.
    new_tokens = output[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
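

# A minimal usage sketch, assuming this file is run directly and that PRIMARY_LLM /
# FALLBACK_LLM in config.py name causal LMs that are actually available locally or
# on the Hugging Face Hub; the prompt below is only an illustrative placeholder.
if __name__ == "__main__":
    sample_prompt = "Explain the difference between greedy decoding and sampling."
    print(generate(sample_prompt, max_tokens=200))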