Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -247,6 +247,13 @@ async def load_models_startup():
|
|
| 247 |
print("StoppingCriteria initialized.")
|
| 248 |
|
| 249 |
print("✅ Modelle geladen und bereit!", flush=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 250 |
|
| 251 |
@app.get("/")
|
| 252 |
def hello():
|
|
@@ -294,15 +301,20 @@ async def tts(ws: WebSocket):
|
|
| 294 |
|
| 295 |
print("Starting generation in background thread...")
|
| 296 |
await asyncio.to_thread(
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
)
|
| 307 |
print("Generation thread finished.")
|
| 308 |
|
|
|
|
| 247 |
print("StoppingCriteria initialized.")
|
| 248 |
|
| 249 |
print("✅ Modelle geladen und bereit!", flush=True)
|
| 250 |
+
print(f"Tokenizer EOS ID: {tok.eos_token_id}")
|
| 251 |
+
print(f"Model Config EOS ID: {model.config.eos_token_id}")
|
| 252 |
+
print(f"Constant EOS_TOKEN: {EOS_TOKEN}")
|
| 253 |
+
if tok.eos_token_id != EOS_TOKEN or model.config.eos_token_id != EOS_TOKEN:
|
| 254 |
+
print("⚠️ WARNING: EOS_TOKEN constant might not match model/tokenizer configuration!")
|
| 255 |
+
# Consider updating EOS_TOKEN if they differ, e.g.:
|
| 256 |
+
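# EOS_TOKEN = model.config.eos_token_id  # NOTE(review): if this runs inside the startup function, a `global EOS_TOKEN` declaration is needed, otherwise only a local name is bound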
|
| 257 |
|
| 258 |
@app.get("/")
|
| 259 |
def hello():
|
|
|
|
| 301 |
|
| 302 |
print("Starting generation in background thread...")
|
| 303 |
await asyncio.to_thread(
|
| 304 |
+
model.generate,
|
| 305 |
+
input_ids=ids,
|
| 306 |
+
attention_mask=attn,
|
| 307 |
+
max_new_tokens=2500, # Keep or increase later if needed
|
| 308 |
+
logits_processor=[masker],
|
| 309 |
+
stopping_criteria=stopping_criteria,
|
| 310 |
+
# --- Changes ---
|
| 311 |
+
do_sample=True, # Enable sampling
|
| 312 |
+
temperature=0.6, # Introduce some randomness (adjust as needed)
|
| 313 |
+
top_p=0.9, # Focus sampling on more likely tokens (adjust as needed)
|
| 314 |
+
repetition_penalty=1.15, # Penalize tokens already present in the sequence, prompt included (1.0 = no penalty; values > 1.0 discourage repetition)
|
| 315 |
+
# --- End Changes ---
|
| 316 |
+
use_cache=True,
|
| 317 |
+
streamer=streamer
|
| 318 |
)
|
| 319 |
print("Generation thread finished.")
|
| 320 |
|