Spaces: Running on Zero
| import spaces | |
| import torch | |
| from transformers import AutoTokenizer, BitsAndBytesConfig | |
| from peft import AutoPeftModelForCausalLM | |
# Hugging Face Hub repo id of the fine-tuned (PEFT/LoRA) aLLoyM model.
MODEL_ID = "Playingyoyo/aLLoyM"

# Lazily-initialized module globals; populated by load_model() on first use.
model = None
tokenizer = None
def load_model():
    """Load the 4-bit quantized PEFT model and its tokenizer.

    Populates the module-level ``model`` and ``tokenizer`` globals.
    Best-effort: on any failure the error is printed and BOTH globals are
    reset to ``None`` so callers can detect failure via ``model is None``.
    """
    global model, tokenizer
    print(f"Loading Model: {MODEL_ID}...")

    # NF4 double-quantized 4-bit weights, computing in bfloat16.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )
    try:
        model = AutoPeftModelForCausalLM.from_pretrained(
            MODEL_ID,
            quantization_config=quantization_config,
            device_map="auto",
        )
        tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        # generate() needs a pad token; fall back to EOS when none is set.
        if tokenizer.pad_token_id is None:
            tokenizer.pad_token_id = tokenizer.eos_token_id
        model.eval()
        print("Model loaded successfully!")
    except Exception as e:
        # Reset BOTH globals: a partial load (model loaded, tokenizer
        # failed) would otherwise pass the caller's `model is None` check
        # and then crash later at tokenization time.
        model = None
        tokenizer = None
        print(f"Error loading model: {e}")
def run_inference(user_input, system_instruction, history):
    """Run one chat turn through the model and append it to the history.

    Builds an Alpaca-format prompt from ``system_instruction`` and
    ``user_input``, generates greedily, and returns ``history`` extended
    with the user/assistant turn in messages format
    (``{"role": ..., "content": ...}`` dicts).
    """
    global model, tokenizer
    # Gradio may pass None on the first turn.
    if history is None:
        history = []

    # Lazy-load the model on first call.
    if model is None:
        load_model()
    # Guard tokenizer too: a partially failed load must not crash below.
    if model is None or tokenizer is None:
        # Report the failure in the same messages format as a normal turn.
        history.append({"role": "user", "content": user_input})
        history.append({"role": "assistant", "content": "Error: Model failed to load."})
        return history

    # Alpaca-style prompt.
    prompt = f"""### Instruction:
{system_instruction}
### Input:
{user_input}
### Output:
"""
    # Use the model's own device rather than hard-coded "cuda":
    # device_map="auto" may have placed the weights elsewhere (e.g. CPU).
    inputs = tokenizer(
        [prompt],
        return_tensors='pt',
        truncation=True,
        max_length=2048,
    ).to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=512,
            use_cache=True,
            do_sample=False,  # greedy decoding for reproducibility
            pad_token_id=tokenizer.eos_token_id,
        )

    full_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    # Keep only the text after the "### Output:" marker; fall back to
    # stripping the prompt if the marker was truncated away.
    if "### Output:" in full_output:
        generated_response = full_output.split("### Output:")[1].strip()
    else:
        generated_response = full_output.replace(prompt, "").strip()

    # Return in messages (dict) format, as required by the Gradio Chatbot.
    history.append({"role": "user", "content": user_input})
    history.append({"role": "assistant", "content": generated_response})
    return history