Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -700,36 +700,33 @@ async def generate(
|
|
| 700 |
):
|
| 701 |
"""
|
| 702 |
Generate a text response based on the provided context and chat history.
|
| 703 |
-
|
| 704 |
-
The generation process can be customized using various parameters in the config:
|
| 705 |
-
- temperature: Controls randomness (0.0 to 2.0)
|
| 706 |
-
- max_new_tokens: Maximum length of generated text
|
| 707 |
-
- top_p: Nucleus sampling parameter
|
| 708 |
-
- top_k: Top-k sampling parameter
|
| 709 |
-
- strategy: Generation strategy to use
|
| 710 |
-
- num_samples: Number of samples for applicable strategies
|
| 711 |
-
|
| 712 |
-
Generation Strategies:
|
| 713 |
-
- default: Standard generation
|
| 714 |
-
- majority_voting: Generates multiple responses and uses the most common one
|
| 715 |
-
- best_of_n: Generates multiple responses and picks the best
|
| 716 |
-
- beam_search: Uses beam search for coherent generation
|
| 717 |
-
- dvts: Dynamic vocabulary tree search
|
| 718 |
"""
|
| 719 |
try:
|
| 720 |
chat_history = [(msg.role, msg.content) for msg in request.messages[:-1]]
|
| 721 |
user_input = request.messages[-1].content
|
| 722 |
-
|
|
|
|
| 723 |
config = request.config or GenerationConfig()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 724 |
|
|
|
|
| 725 |
response = await asyncio.to_thread(
|
| 726 |
generator.generate_with_context,
|
| 727 |
context=request.context or "",
|
| 728 |
user_input=user_input,
|
| 729 |
chat_history=chat_history,
|
| 730 |
-
model_kwargs=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 731 |
)
|
| 732 |
-
|
| 733 |
return GenerationResponse(
|
| 734 |
id=str(uuid.uuid4()),
|
| 735 |
content=response
|
|
|
|
| 700 |
):
|
| 701 |
"""
|
| 702 |
Generate a text response based on the provided context and chat history.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 703 |
"""
|
| 704 |
try:
|
| 705 |
chat_history = [(msg.role, msg.content) for msg in request.messages[:-1]]
|
| 706 |
user_input = request.messages[-1].content
|
| 707 |
+
|
| 708 |
+
# Extract or set defaults for additional arguments
|
| 709 |
config = request.config or GenerationConfig()
|
| 710 |
+
model_kwargs = {
|
| 711 |
+
"temperature": config.temperature if hasattr(config, "temperature") else 0.7,
|
| 712 |
+
"max_new_tokens": config.max_new_tokens if hasattr(config, "max_new_tokens") else 100,
|
| 713 |
+
# Add other model kwargs as needed
|
| 714 |
+
}
|
| 715 |
|
| 716 |
+
# Explicitly pass additional required arguments
|
| 717 |
response = await asyncio.to_thread(
|
| 718 |
generator.generate_with_context,
|
| 719 |
context=request.context or "",
|
| 720 |
user_input=user_input,
|
| 721 |
chat_history=chat_history,
|
| 722 |
+
model_kwargs=model_kwargs,
|
| 723 |
+
max_history_turns=config.max_history_turns if hasattr(config, "max_history_turns") else 3,
|
| 724 |
+
strategy=config.strategy if hasattr(config, "strategy") else "default",
|
| 725 |
+
num_samples=config.num_samples if hasattr(config, "num_samples") else 5,
|
| 726 |
+
depth=config.depth if hasattr(config, "depth") else 3,
|
| 727 |
+
breadth=config.breadth if hasattr(config, "breadth") else 2,
|
| 728 |
)
|
| 729 |
+
|
| 730 |
return GenerationResponse(
|
| 731 |
id=str(uuid.uuid4()),
|
| 732 |
content=response
|