import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# Checkpoint location handed to ``from_pretrained`` for both tokenizer and model.
MODEL_PATH = "SmallDront-20m/"

# Sampling temperature used by the interactive chat session started below.
TEMPERATURE = 0.5

# Use the GPU whenever PyTorch reports CUDA support; otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
def load_model_and_tokenizer(model_path):
    """Load a causal language model and its tokenizer from ``model_path``.

    The weights are requested in float16 when ``DEVICE`` is ``"cuda"`` and in
    float32 otherwise. Device placement is delegated to ``device_map="auto"``,
    and remote code execution is disabled for both loads. If the tokenizer has
    no padding token, its EOS token is reused as the pad token.

    Args:
        model_path: Path or identifier accepted by ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    if DEVICE == "cuda":
        weight_dtype = torch.float16
    else:
        weight_dtype = torch.float32

    tok = AutoTokenizer.from_pretrained(model_path, trust_remote_code=False)
    lm = AutoModelForCausalLM.from_pretrained(
        model_path,
        torch_dtype=weight_dtype,
        device_map="auto",
        trust_remote_code=False,
    )

    # Some tokenizers ship without a pad token; reuse EOS so callers always
    # have a usable ``pad_token_id``.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    return lm, tok
def generate_response(model, tokenizer, prompt, temperature=0.4, max_new_tokens=64,
                      max_prompt_tokens=2048):
    """Generate a continuation of ``prompt`` and return only the new text.

    Args:
        model: Causal language model exposing ``device`` and ``generate``.
        tokenizer: Tokenizer matching ``model``.
        prompt: Full prompt string, including any chat markup.
        temperature: Sampling temperature. Values ``<= 0`` switch to greedy
            decoding, because sampling requires a strictly positive
            temperature.
        max_new_tokens: Maximum number of tokens to generate.
        max_prompt_tokens: Maximum prompt length in tokens. Longer prompts keep
            their *last* ``max_prompt_tokens`` tokens, so the trailing turn
            marker the model continues from is never cut off.

    Returns:
        The decoded continuation, with special tokens skipped and surrounding
        whitespace stripped.
    """
    # Truncate from the left: the end of a chat prompt carries the cue the
    # model must answer. Restore the tokenizer's own setting afterwards so the
    # shared tokenizer object is left unchanged.
    previous_side = getattr(tokenizer, "truncation_side", None)
    if previous_side is not None:
        tokenizer.truncation_side = "left"
    try:
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max_prompt_tokens,
        )
    finally:
        if previous_side is not None:
            tokenizer.truncation_side = previous_side
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    gen_kwargs = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": 1.1,
        "pad_token_id": tokenizer.pad_token_id,
        "eos_token_id": tokenizer.eos_token_id,
    }
    if temperature is not None and temperature <= 0:
        # Greedy decoding; sampling-only knobs are omitted so generate() does
        # not reject or warn about them.
        gen_kwargs["do_sample"] = False
    else:
        gen_kwargs.update(do_sample=True, temperature=temperature, top_p=0.95)

    with torch.no_grad():
        output_ids = model.generate(**inputs, **gen_kwargs)

    # generate() returns prompt + continuation; slice off the prompt tokens.
    input_length = inputs["input_ids"].shape[1]
    new_tokens = output_ids[0][input_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
def interactive_chat(model, tokenizer, temperature):
    """Run a line-based chat loop on stdin/stdout until the user quits.

    Each non-empty input line is wrapped as ``<|user|>...<|assistant|>`` and
    passed to ``generate_response``, and the reply is printed. Blank lines are
    ignored. The session ends on ``exit``/``quit`` (case-insensitive), on
    end-of-input (Ctrl-D or a closed pipe), or on Ctrl-C at the prompt.

    Args:
        model: Model forwarded to ``generate_response``.
        tokenizer: Tokenizer forwarded to ``generate_response``.
        temperature: Sampling temperature forwarded to ``generate_response``.
    """
    print(f"Chat with model (temp={temperature}) - type 'exit' or 'quit' to stop")
    while True:
        try:
            user_input = input("\nYou: ").strip()
        except (EOFError, KeyboardInterrupt):
            # input() raises EOFError on end-of-input and KeyboardInterrupt on
            # Ctrl-C; end the session cleanly instead of dumping a traceback.
            print("\nGoodbye!")
            break
        if user_input.lower() in ("exit", "quit"):
            print("Goodbye!")
            break
        if not user_input:
            continue
        response = generate_response(
            model,
            tokenizer,
            f"<|user|>{user_input}<|assistant|>",
            temperature=temperature,
        )
        print(f"Assistant: {response}")
if __name__ == "__main__":
    # Script entry point: load the checkpoint once, then start the chat loop.
    # Distinct global names avoid shadowing the functions' parameter names.
    chat_model, chat_tokenizer = load_model_and_tokenizer(MODEL_PATH)
    interactive_chat(chat_model, chat_tokenizer, TEMPERATURE)