ds
Browse files
app.py
CHANGED
|
@@ -42,15 +42,16 @@ def validate_parameters(max_tokens, temperature, top_p):
|
|
| 42 |
return False, "Error: 'Top-p' must be between 0.1 and 1.0."
|
| 43 |
return True, ""
|
| 44 |
|
|
|
|
|
|
|
|
|
|
| 45 |
# Load the model and tokenizer
|
| 46 |
model_name = "gpt2" # Use GPT-2 model
|
| 47 |
|
| 48 |
try:
|
| 49 |
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
|
| 50 |
-
model = transformers.AutoModelForCausalLM.from_pretrained(
|
| 51 |
-
model_name,
|
| 52 |
-
device_map="auto",
|
| 53 |
-
)
|
| 54 |
model.eval()
|
| 55 |
except Exception as e:
|
| 56 |
logging.error(f"Failed to load model {model_name}: {e}")
|
|
@@ -76,7 +77,7 @@ def respond(message, history, persona_choice, custom_persona, max_tokens, temper
|
|
| 76 |
logging.info(f"Received message: {safe_message}")
|
| 77 |
|
| 78 |
try:
|
| 79 |
-
input_ids = tokenizer.encode(conversation, return_tensors="pt").to(
|
| 80 |
|
| 81 |
output_ids = model.generate(
|
| 82 |
input_ids,
|
|
|
|
| 42 |
return False, "Error: 'Top-p' must be between 0.1 and 1.0."
|
| 43 |
return True, ""
|
| 44 |
|
| 45 |
+
# Determine the device
|
| 46 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 47 |
+
|
| 48 |
# Load the model and tokenizer
|
| 49 |
model_name = "gpt2" # Use GPT-2 model
|
| 50 |
|
| 51 |
try:
|
| 52 |
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
|
| 53 |
+
model = transformers.AutoModelForCausalLM.from_pretrained(model_name)
|
| 54 |
+
model.to(device)
|
|
|
|
|
|
|
| 55 |
model.eval()
|
| 56 |
except Exception as e:
|
| 57 |
logging.error(f"Failed to load model {model_name}: {e}")
|
|
|
|
| 77 |
logging.info(f"Received message: {safe_message}")
|
| 78 |
|
| 79 |
try:
|
| 80 |
+
input_ids = tokenizer.encode(conversation, return_tensors="pt").to(device)
|
| 81 |
|
| 82 |
output_ids = model.generate(
|
| 83 |
input_ids,
|