Fix response extraction and increase token generation
utils.py CHANGED
@@ -286,7 +286,11 @@ class SuggestionGenerator:
            test_prompt = "I am Will. My son Billy asked about football. I respond:"
            print(f"Testing model with prompt: {test_prompt}")
            response = self.generator(test_prompt, max_new_tokens=30, do_sample=True)
-
+            full_text = response[0]["generated_text"]
+            if len(test_prompt) < len(full_text):
+                result = full_text[len(test_prompt) :]
+            else:
+                result = "No additional text generated"
            print(f"Test response: {result}")
            return f"Model test successful: {result}"
        except Exception as e:

@@ -468,7 +472,7 @@ My conversation starter to {name}:"""
            # Use max_new_tokens instead of max_length to avoid the error
            response = self.generator(
                prompt,
-                max_new_tokens=
+                max_new_tokens=100,  # Generate more tokens to ensure we get a response
                temperature=temperature,
                do_sample=True,
                top_p=0.92,

@@ -477,9 +481,19 @@ My conversation starter to {name}:"""
                truncation=False,
            )
            # Extract only the generated part, not the prompt
-
-            print(f"
-
+            full_text = response[0]["generated_text"]
+            print(f"Full generated text length: {len(full_text)}")
+            print(f"Prompt length: {len(prompt)}")
+
+            # Make sure we're not trying to slice beyond the text length
+            if len(prompt) < len(full_text):
+                result = full_text[len(prompt) :]
+                print(f"Generated response: {result}")
+                return result.strip()
+            else:
+                # If the model didn't generate anything beyond the prompt
+                print("Model didn't generate text beyond prompt")
+                return "I'm thinking about what to say..."
        except Exception as e:
            print(f"Error generating suggestion: {e}")
            return "Could not generate a suggestion. Please try again."
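For reference, the extraction added here relies on the Hugging Face text-generation pipeline returning the prompt and the continuation together in the "generated_text" field, so the new code slices the prompt off the front before using the result. A minimal standalone sketch of that pattern follows; the model name (distilgpt2) and the prompt are placeholders for illustration, not necessarily what this Space loads.

# Minimal sketch of the prompt-stripping extraction used in this commit.
# Assumes a transformers text-generation pipeline; "distilgpt2" is only a
# placeholder model.
from transformers import pipeline

generator = pipeline("text-generation", model="distilgpt2")

prompt = "I am Will. My son Billy asked about football. I respond:"
response = generator(prompt, max_new_tokens=100, do_sample=True)

# "generated_text" contains prompt + continuation; keep only the new part.
full_text = response[0]["generated_text"]
if len(prompt) < len(full_text):
    result = full_text[len(prompt):].strip()
else:
    result = "No additional text generated"

print(result)

An alternative worth noting: the text-generation pipeline accepts return_full_text=False, which makes it return only the newly generated tokens and would avoid the manual slicing entirely.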