import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from config import PRIMARY_LLM, FALLBACK_LLM


def load_model():
    """Load the primary model, falling back to FALLBACK_LLM if loading fails."""
    try:
        tokenizer = AutoTokenizer.from_pretrained(PRIMARY_LLM)
        if tokenizer.pad_token is None:
            # Many causal LMs ship without a pad token; reuse EOS so padding works.
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            PRIMARY_LLM,
            device_map="auto",
            load_in_8bit=True,  # requires the bitsandbytes package
        )
        print(f"Loaded primary model: {PRIMARY_LLM}")
    except Exception as e:
        print(f"Primary model failed: {e}")
        print(f"Loading fallback: {FALLBACK_LLM}")
        tokenizer = AutoTokenizer.from_pretrained(FALLBACK_LLM, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            FALLBACK_LLM,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            device_map="auto",
        )
    return tokenizer, model


tokenizer, model = load_model()


def generate(prompt, max_tokens=400):
    """Generate a completion for `prompt`, returning only the newly generated text."""
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Decode only the tokens produced after the prompt, instead of string-replacing
    # the prompt out of the full decode (which silently fails if decoding changes
    # whitespace or special characters).
    prompt_length = inputs["input_ids"].shape[1]
    generated_text = tokenizer.decode(output[0][prompt_length:], skip_special_tokens=True)
    return generated_text.strip()
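

# Minimal usage sketch: how generate() might be called when this module is run
# directly. This block is illustrative; the prompt text below is an assumption
# and not part of the original module.
if __name__ == "__main__":
    sample_prompt = "Explain what a fallback model is in one sentence."
    print(generate(sample_prompt, max_tokens=100))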