import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

from config import PRIMARY_LLM, FALLBACK_LLM


def load_model():
    """Load the primary model; fall back to FALLBACK_LLM if loading fails."""
    try:
        tokenizer = AutoTokenizer.from_pretrained(PRIMARY_LLM)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        # Load the primary model in 8-bit (requires bitsandbytes), sharding
        # weights across available devices via device_map="auto".
        model = AutoModelForCausalLM.from_pretrained(
            PRIMARY_LLM,
            device_map="auto",
            load_in_8bit=True,
        )
        print(f"Loaded primary model: {PRIMARY_LLM}")
    except Exception as e:
        print(f"Primary model failed: {e}")
        print(f"Loading fallback: {FALLBACK_LLM}")
        # The fallback loads in fp16 rather than 8-bit; trust_remote_code
        # permits checkpoints that ship custom modeling code.
        tokenizer = AutoTokenizer.from_pretrained(FALLBACK_LLM, trust_remote_code=True)
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        model = AutoModelForCausalLM.from_pretrained(
            FALLBACK_LLM,
            trust_remote_code=True,
            torch_dtype=torch.float16,
            device_map="auto",
        )
    return tokenizer, model


tokenizer, model = load_model()
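
# Optional sanity check (illustrative addition, an assumption on my part):
# when device_map="auto" is used, accelerate records each module's placement
# in model.hf_device_map, which helps when debugging GPU/CPU offload.
print(f"Device map: {getattr(model, 'hf_device_map', 'n/a')}")
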
def generate(prompt, max_tokens=400):
    """Sample a completion for `prompt`, returning only the generated text."""
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=2048)
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=True,
            temperature=0.7,
            top_p=0.9,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Decode only the tokens produced after the prompt; string-replacing the
    # prompt in the decoded output is unreliable because tokenization does not
    # always round-trip the input text exactly.
    new_tokens = output[0][inputs["input_ids"].shape[1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
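

# Minimal usage sketch (the prompt below is hypothetical, for illustration
# only): run the module directly to smoke-test whichever model was loaded.
if __name__ == "__main__":
    print(generate("Explain what a causal language model is in one sentence.",
                   max_tokens=64))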