xtreme86 committed on
Commit
15c5b99
·
1 Parent(s): 6df87c2
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -42,15 +42,16 @@ def validate_parameters(max_tokens, temperature, top_p):
42
  return False, "Error: 'Top-p' must be between 0.1 and 1.0."
43
  return True, ""
44
 
 
 
 
45
  # Load the model and tokenizer
46
  model_name = "gpt2" # Use GPT-2 model
47
 
48
  try:
49
  tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
50
- model = transformers.AutoModelForCausalLM.from_pretrained(
51
- model_name,
52
- device_map="auto",
53
- )
54
  model.eval()
55
  except Exception as e:
56
  logging.error(f"Failed to load model {model_name}: {e}")
@@ -76,7 +77,7 @@ def respond(message, history, persona_choice, custom_persona, max_tokens, temper
76
  logging.info(f"Received message: {safe_message}")
77
 
78
  try:
79
- input_ids = tokenizer.encode(conversation, return_tensors="pt").to(model.device)
80
 
81
  output_ids = model.generate(
82
  input_ids,
 
42
  return False, "Error: 'Top-p' must be between 0.1 and 1.0."
43
  return True, ""
44
 
45
+ # Determine the device
46
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47
+
48
  # Load the model and tokenizer
49
  model_name = "gpt2" # Use GPT-2 model
50
 
51
  try:
52
  tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
53
+ model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
54
+ model.to(device)
 
 
55
  model.eval()
56
  except Exception as e:
57
  logging.error(f"Failed to load model {model_name}: {e}")
 
77
  logging.info(f"Received message: {safe_message}")
78
 
79
  try:
80
+ input_ids = tokenizer.encode(conversation, return_tensors="pt").to(device)
81
 
82
  output_ids = model.generate(
83
  input_ids,