kdevoe committed on
Commit
3f9b161
·
verified ·
1 Parent(s): 6848d2f

Fixing device loading issue

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -17,7 +17,7 @@ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
  loaded_models = {
18
  "DialoGPT-med-FT": AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
19
  }
20
- loaded_models["DialoGPT-med-FT"].load_state_dict(torch.load(model_names["DialoGPT-med-FT"]))
21
  loaded_models["DialoGPT-med-FT"].to(device)
22
 
23
  loaded_models["DialoGPT-medium"] = AutoModelForCausalLM.from_pretrained(model_names["DialoGPT-medium"]).to(device)
@@ -40,11 +40,11 @@ def respond(
40
  input_text += f"User: {message}\nAssistant:"
41
 
42
  # Tokenize the input text using the shared tokenizer
43
- inputs = tokenizer(input_text, return_tensors="pt", truncation=True).to(model.device)
44
 
45
  # Generate the response using the selected DialoGPT model
46
  output_tokens = model.generate(
47
- inputs["input_ids"],
48
  max_length=len(inputs["input_ids"][0]) + max_tokens,
49
  temperature=temperature,
50
  top_p=top_p,
 
17
  loaded_models = {
18
  "DialoGPT-med-FT": AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-medium")
19
  }
20
+ loaded_models["DialoGPT-med-FT"].load_state_dict(torch.load(model_names["DialoGPT-med-FT"], map_location=device))
21
  loaded_models["DialoGPT-med-FT"].to(device)
22
 
23
  loaded_models["DialoGPT-medium"] = AutoModelForCausalLM.from_pretrained(model_names["DialoGPT-medium"]).to(device)
 
40
  input_text += f"User: {message}\nAssistant:"
41
 
42
  # Tokenize the input text using the shared tokenizer
43
+ inputs = tokenizer(input_text, return_tensors="pt", truncation=True).to(device)
44
 
45
  # Generate the response using the selected DialoGPT model
46
  output_tokens = model.generate(
47
+ inputs["input_ids"].to(device),
48
  max_length=len(inputs["input_ids"][0]) + max_tokens,
49
  temperature=temperature,
50
  top_p=top_p,