Create app.py
app.py
ADDED
@@ -0,0 +1,1491 @@
import os
import gc
import json
import random
import torch
import asyncio
import logging
import time
from typing import List, Dict, Any, Optional, Union, AsyncGenerator, Tuple
from fastapi import FastAPI, HTTPException, Query, Request, Depends, status
from fastapi.responses import StreamingResponse, PlainTextResponse, HTMLResponse, JSONResponse
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, Field, ValidationError, validator
from transformers import (
    AutoConfig, AutoModelForCausalLM, AutoTokenizer,
    GenerationConfig, LogitsProcessorList,
    MinLengthLogitsProcessor, MaxLengthCriteria,
    StoppingCriteriaList, StoppingCriteria
)
import uvicorn
from concurrent.futures import ThreadPoolExecutor
import math
import torch.nn.functional as F
import copy

app = FastAPI(title="Chatbot Profesional API", version="1.0.0")

class StopSequenceCriteria(StoppingCriteria):
    """Stops generation when any configured stop sequence appears at the end of the output."""

    def __init__(self, stop_sequences: List[str], tokenizer: AutoTokenizer):
        self.tokenizer = tokenizer
        self.stop_sequences_text = []
        self.stop_sequence_ids = []
        for seq in stop_sequences:
            if seq:
                encoded_ids = tokenizer.encode(seq, add_special_tokens=False)
                decoded_text = tokenizer.decode(encoded_ids, skip_special_tokens=True)
                if decoded_text:
                    self.stop_sequences_text.append(decoded_text)
                self.stop_sequence_ids.append(encoded_ids)

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        if not self.stop_sequence_ids:
            return False

        input_ids_list = input_ids[0].tolist()

        # Exact match: the trailing tokens coincide with an encoded stop sequence.
        for stop_seq_ids in self.stop_sequence_ids:
            stop_len = len(stop_seq_ids)
            if len(input_ids_list) >= stop_len:
                if input_ids_list[-stop_len:] == stop_seq_ids:
                    return True

        # Fallback: decode a tail window and look for the stop text, which catches
        # sequences that tokenize differently once surrounded by other text.
        check_tail_len = 50
        if self.stop_sequence_ids:
            max_stop_seq_token_len = max((len(seq) for seq in self.stop_sequence_ids), default=0)
            check_tail_len = max(check_tail_len, max_stop_seq_token_len + 10)

        tail_ids = input_ids_list[-min(check_tail_len, len(input_ids_list)):]
        tail_text = self.tokenizer.decode(tail_ids, skip_special_tokens=True)

        for stop_seq_text in self.stop_sequences_text:
            if stop_seq_text and stop_seq_text in tail_text:
                return True

        return False

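# Illustrative usage of the criteria above (a sketch, not part of the original file;
# the tokenizer name is an assumption for the example only):
#
#   tok = AutoTokenizer.from_pretrained("gpt2")
#   stop = StopSequenceCriteria(["\nUsuario:"], tok)
#   ids = tok("Bot: hola\nUsuario:", return_tensors="pt").input_ids
#   stop(ids, None)  # -> True: the decoded tail contains the stop string
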
# Silence server and library logging; the API is intended to run quietly.
logging.getLogger("uvicorn").handlers.clear()
logging.getLogger("uvicorn.error").handlers.clear()
logging.getLogger("uvicorn.access").handlers.clear()
logging.getLogger("uvicorn").propagate = False
logging.getLogger("uvicorn.error").propagate = False
logging.getLogger("uvicorn.access").propagate = False
logging.getLogger("uvicorn").setLevel(logging.CRITICAL)
logging.getLogger("uvicorn.error").setLevel(logging.CRITICAL)
logging.getLogger("uvicorn.access").setLevel(logging.CRITICAL)
logging.getLogger("fastapi").setLevel(logging.CRITICAL)
logging.getLogger("transformers").setLevel(logging.CRITICAL)
logging.getLogger().handlers.clear()
logging.getLogger().addHandler(logging.NullHandler())

DEFAULT_MODEL_NAME = "hghghgkskdmskdms/xddd"
MODEL_NAME = os.environ.get("MODEL_NAME", DEFAULT_MODEL_NAME)
SYSTEM_PROMPT = os.environ.get("SYSTEM_PROMPT", "Eres un asistente profesional y servicial.")

try:
    MAX_CONTEXT_TOKENS = int(os.environ.get("MAX_CONTEXT_TOKENS", 1024))
    if MAX_CONTEXT_TOKENS <= 0:
        raise ValueError("MAX_CONTEXT_TOKENS must be positive.")
except (ValueError, TypeError) as e:
    logging.error(f"Invalid MAX_CONTEXT_TOKENS environment variable: {os.environ.get('MAX_CONTEXT_TOKENS')}. Using default 1024. Error: {e}")
    MAX_CONTEXT_TOKENS = 1024

try:
    MAX_GENERATION_TOKENS = int(os.environ.get("MAX_GENERATION_TOKENS", 512))
    if MAX_GENERATION_TOKENS <= 0:
        raise ValueError("MAX_GENERATION_TOKENS must be positive.")
except (ValueError, TypeError) as e:
    logging.error(f"Invalid MAX_GENERATION_TOKENS environment variable: {os.environ.get('MAX_GENERATION_TOKENS')}. Using default 512. Error: {e}")
    MAX_GENERATION_TOKENS = 512

try:
    MAX_CONCURRENT_GENERATIONS = int(os.environ.get("MAX_CONCURRENT_GENERATIONS", 4))
    if MAX_CONCURRENT_GENERATIONS <= 0:
        raise ValueError("MAX_CONCURRENT_GENERATIONS must be positive.")
except (ValueError, TypeError) as e:
    logging.error(f"Invalid MAX_CONCURRENT_GENERATIONS environment variable: {os.environ.get('MAX_CONCURRENT_GENERATIONS')}. Using default 4. Error: {e}")
    MAX_CONCURRENT_GENERATIONS = 4

TRUST_REMOTE_CODE_ENV = os.environ.get("TRUST_REMOTE_CODE", "false").lower() == "true"
TRUST_REMOTE_CODE = TRUST_REMOTE_CODE_ENV or (MODEL_NAME == DEFAULT_MODEL_NAME)
ENABLE_FLASH_ATTENTION_2 = os.environ.get("ENABLE_FLASH_ATTENTION_2", "false").lower() == "true"
TORCH_DTYPE_STR = os.environ.get("TORCH_DTYPE", "float32")
TORCH_DTYPE = getattr(torch, TORCH_DTYPE_STR.lower(), torch.float32)
if TORCH_DTYPE != torch.float32:
    logging.warning(f"Requested dtype {TORCH_DTYPE_STR} might not be fully performant on CPU. Using float32.")
    TORCH_DTYPE = torch.float32

API_KEY = os.environ.get("API_KEY")

global_model = None
global_tokenizer = None
global_tokens: Dict[str, Optional[int]] = {}
executor = ThreadPoolExecutor(max_workers=MAX_CONCURRENT_GENERATIONS)
generation_semaphore = asyncio.Semaphore(MAX_CONCURRENT_GENERATIONS)

api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)

async def get_api_key(api_key: str = Depends(api_key_header)):
    if API_KEY is None:
        return
    if api_key is None or api_key != API_KEY:
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid or missing API Key")
    return api_key

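# Client-side sketch (not part of the original file): when the API_KEY environment variable
# is set, requests must carry it in the X-API-Key header. The URL and port are assumptions.
#
#   import requests
#   r = requests.post(
#       "http://localhost:8000/generate",
#       headers={"X-API-Key": "<your-key>", "Content-Type": "application/json"},
#       json={"input_text": "Hola", "stream": False},
#   )
#   print(r.json())
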
class GenerateRequest(BaseModel):
    input_text: str = Field(..., description="The input text from the user.", examples=["Hola, ¿cómo estás?"])
    history: Optional[List[Dict[str, str]]] = Field(None, description="Conversation history.", examples=[[{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is the capital of France?"}, {"role": "assistant", "content": "The capital of France is Paris."}]])
    stream: bool = Field(True, description="Whether to stream the response.")
    temperature: float = Field(1.0, ge=0.0, le=2.0, description="Controls the randomness.")
    top_k: int = Field(50, ge=0, description="Top-k filtering.")
    top_p: float = Field(1.0, ge=0.0, le=1.0, description="Top-p (nucleus) filtering.")
    repetition_penalty: float = Field(1.0, ge=0.0, description="Repetition penalty.")
    frequency_penalty: float = Field(0.0, ge=0.0, description="Frequency penalty.")
    presence_penalty: float = Field(0.0, ge=0.0, description="Presence penalty.")
    num_beams: int = Field(1, ge=1, description="Number of beams for beam search.")
    length_penalty: float = Field(1.0, ge=0.0, description="Length penalty.")
    no_repeat_ngram_size: int = Field(0, ge=0, description="No repeat ngram size.")
    early_stopping: bool = Field(False, description="Early stopping for beam search.")
    do_sample: bool = Field(True, description="Whether to use sampling.")
    use_mirostat: bool = Field(False, description="Whether to use Mirostat sampling.")
    mirostat_tau: float = Field(5.0, ge=0.0, description="Mirostat tau.")
    mirostat_eta: float = Field(0.1, ge=0.0, description="Mirostat eta.")
    max_new_tokens: int = Field(MAX_GENERATION_TOKENS, ge=1, description="Max new tokens.")
    system_prompt: Optional[str] = Field(None, description="Override the default system prompt.")
    seed: Optional[int] = Field(None, description="Random seed.")
    stop_sequences: Optional[List[str]] = Field(None, description="List of stop strings.", examples=[[".", "\nUsuario:"]])
    tokenize_only: bool = Field(False, description="If true, only tokenize input.")
    strip_trailing_whitespace: bool = Field(False, description="Strip trailing whitespace.")
    remove_incomplete_sentences: bool = Field(False, description="Remove incomplete last sentence.")
    num_return_sequences: int = Field(1, ge=1, le=5, description="Number of sequences to return (non-streaming).")
    bad_words_ids: Optional[List[List[int]]] = Field(None, description="List of bad word token ids.", examples=[[[32000], [32001]]])
    forced_bos_token_id: Optional[int] = Field(None, description="Forced BOS token id.")
    forced_eos_token_id: Optional[int] = Field(None, description="Forced EOS token id.")
    renormalize_logits: Optional[bool] = Field(None, description="Renormalize logits.")
    suppress_tokens: Optional[List[int]] = Field(None, description="Tokens to suppress.")
    begin_suppress_tokens: Optional[List[int]] = Field(None, description="Tokens to suppress at beginning.")
    end_suppress_tokens: Optional[List[int]] = Field(None, description="Tokens to suppress at end.")
    encoder_no_repeat_ngram_size: int = Field(0, ge=0, description="Encoder no repeat ngram size.")
    min_length: int = Field(0, ge=0, description="Minimum total length.")
    max_length: Optional[int] = Field(None, description="Maximum total length.")
    exponential_decay_length_penalty: Optional[Tuple[float, int, float]] = Field(None, description="Exponential decay length penalty.")
    use_cache: bool = Field(True, description="Use cache.")
    typical_p: float = Field(1.0, ge=0.0, le=1.0, description="Typical P sampling.")
    epsilon_cutoff: float = Field(0.0, ge=0.0, description="Epsilon cutoff for LTS.")
    eta_cutoff: float = Field(0.0, ge=0.0, description="Eta cutoff for LTS.")
    temperature_cutoff: Optional[float] = Field(None, ge=0.0, description="Temperature cutoff.")
    encoder_repetition_penalty: float = Field(1.0, ge=0.0, description="Encoder repetition penalty.")
    max_time: Optional[float] = Field(None, ge=0.0, description="Maximum time in seconds.")
    output_watermark: bool = Field(False, description="Output watermark.")
    remove_input_from_output: bool = Field(False, description="Remove input from output.")
    eos_token_id_override: Optional[int] = Field(None, description="Override EOS token id.")
    pad_token_id_override: Optional[int] = Field(None, description="Override PAD token id.")
    bos_token_id_override: Optional[int] = Field(None, description="Override BOS token id.")
    repetition_penalty_range: Optional[int] = Field(None, ge=0, description="Repetition penalty range.")
    diversity_penalty: float = Field(0.0, ge=0.0, description="Diversity penalty for diverse beam search.")
    num_beam_groups: int = Field(1, ge=1, description="Number of beam groups for diverse beam search.")
    return_dict_in_generate: bool = Field(False, description="Return dictionary from generate.")
    output_attentions: bool = Field(False, description="Output attentions.")
    output_hidden_states: bool = Field(False, description="Output hidden states.")
    output_scores: bool = Field(False, description="Output scores.")
    return_token_logprobs: bool = Field(False, description="Return token logprobs in stream.")
    return_text_from_sequence: bool = Field(True, description="Decode generated sequence to text.")
    length_normalization_factor: Optional[float] = Field(None, description="Length normalization factor for beam search.")
    min_new_tokens: int = Field(0, ge=0, description="Minimum number of new tokens.")
    do_normalize_logits: bool = Field(False, description="Normalize logits.")
    return_generation_inputs: bool = Field(False, description="Return generation inputs.")
    return_unused_generate_parameters: bool = Field(False, description="Return unused generate parameters.")
    use_fast_tokenizer: bool = Field(True, description="Use fast tokenizer if available.")
    model_kwargs: Optional[Dict[str, Any]] = Field(None, description="Additional model kwargs for generate.")
    tokenizer_kwargs: Optional[Dict[str, Any]] = Field(None, description="Additional tokenizer kwargs for encode.")
    return_only_text: bool = Field(False, description="If true, only return the generated text.")

    @validator('stop_sequences')
    def validate_stop_sequences(cls, v):
        if v is not None:
            if not all(isinstance(seq, str) for seq in v):
                raise ValueError('Each stop sequence must be a string')
        return v

    @validator('bad_words_ids')
    def validate_bad_words_ids(cls, v):
        if v is not None:
            if not all(isinstance(word_id_list, list) and all(isinstance(token_id, int) for token_id in word_id_list) for word_id_list in v):
                raise ValueError('bad_words_ids must be a list of lists of integers')
        return v

    @validator('exponential_decay_length_penalty')
    def validate_exponential_decay_length_penalty(cls, v):
        if v is not None:
            if not (isinstance(v, (list, tuple)) and len(v) == 3 and
                    isinstance(v[0], (int, float)) and v[0] > 0 and
                    isinstance(v[1], int) and v[1] >= 0 and
                    isinstance(v[2], (int, float))):
                raise ValueError('exponential_decay_length_penalty must be a tuple/list of 3 numbers (decay_factor, start_index, threshold)')
        return v

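# Example request body for the schema above (values are illustrative):
#
#   {
#     "input_text": "Hola, ¿cómo estás?",
#     "history": [{"role": "user", "content": "Hola"}],
#     "stream": true,
#     "temperature": 0.8,
#     "top_p": 0.95,
#     "max_new_tokens": 128,
#     "stop_sequences": ["\nUsuario:"]
#   }
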
class TokenizeRequest(BaseModel):
    text: Union[str, List[str]] = Field(..., description="Text or list of texts to tokenize.")
    add_special_tokens: bool = Field(True, description="Whether to add special tokens.")
    is_split_into_words: bool = Field(False, description="Whether the input text is pre-tokenized.")
    return_token_type_ids: bool = Field(False, description="Whether to return token type IDs.")
    padding: Union[bool, str] = Field(False, description="Enable padding.")
    truncation: Union[bool, str] = Field(False, description="Enable truncation.")
    max_length: Optional[int] = Field(None, ge=1, description="Maximum length for padding and truncation.")
    return_tensors: Optional[str] = Field(None, description="The type of tensors to return.")
    return_attention_mask: Optional[bool] = Field(None, description="Whether to return the attention mask.")
    return_offsets_mapping: Optional[bool] = Field(None, description="Whether to return offsets mapping.")
    return_length: Optional[bool] = Field(None, description="Whether to return the length.")
    verbose: bool = Field(False, description="Verbose tokenizer output.")
    tokenizer_kwargs: Optional[Dict[str, Any]] = Field(None, description="Additional tokenizer kwargs.")

class DecodeRequest(BaseModel):
    token_ids: List[int] = Field(..., description="List of token IDs to decode.", examples=[[1, 2, 3]])
    skip_special_tokens: bool = Field(True, description="Skip special tokens.")
    clean_up_tokenization_spaces: bool = Field(True, description="Clean up spaces.")
    decode_kwargs: Optional[Dict[str, Any]] = Field(None, description="Additional decode kwargs.")

class SystemPromptUpdateRequest(BaseModel):
    system_prompt: str = Field(..., description="The new global system prompt.")

class ModelReloadRequest(BaseModel):
    model_name: Optional[str] = Field(None, description="New model name.")
    trust_remote_code: Optional[bool] = Field(None, description="Override trust_remote_code.")
    enable_flash_attention_2: Optional[bool] = Field(None, description="Override enable_flash_attention_2.")
    torch_dtype: Optional[str] = Field(None, description="Override torch_dtype.")
    model_kwargs: Optional[Dict[str, Any]] = Field(None, description="Additional model kwargs for from_pretrained().")
    tokenizer_kwargs: Optional[Dict[str, Any]] = Field(None, description="Additional tokenizer kwargs for from_pretrained().")

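# Example payloads for the auxiliary schemas above (illustrative; they are presumably
# consumed by routes later in the file, which are not shown in this excerpt):
#
#   TokenizeRequest: {"text": "Hola mundo", "add_special_tokens": true}
#   DecodeRequest:   {"token_ids": [1, 2, 3], "skip_special_tokens": true}
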
def format_conversation(input_text: str, history: Optional[List[Dict[str, str]]], system_prompt: Optional[str]) -> str:
    full_history: List[Dict[str, str]] = []
    used_system_prompt = system_prompt if system_prompt is not None else SYSTEM_PROMPT
    if not history or history[0].get("role") != "system" or history[0].get("content") != used_system_prompt:
        full_history.append({"role": "system", "content": used_system_prompt})
    if history:
        full_history.extend(history)
    if not full_history or full_history[-1].get("role") != "user" or full_history[-1].get("content") != input_text:
        full_history.append({"role": "user", "content": input_text})

    if global_tokenizer and hasattr(global_tokenizer, 'apply_chat_template') and global_tokenizer.chat_template:
        try:
            return global_tokenizer.apply_chat_template(full_history, tokenize=False, add_generation_prompt=True)
        except Exception as e:
            logging.error(f"Failed to apply chat template: {e}. Falling back to manual formatting.")
            pass
    formatted_text = ""
    for i, message in enumerate(full_history):
        if i == 0 and message["role"] == "system" and len(full_history) > 1 and full_history[1].get("role") == "system":
            continue
        if message["role"] == "system":
            formatted_text += f"{message['content'].strip()}\n\n"
        elif message["role"] == "user":
            formatted_text += f"Usuario: {message['content'].strip()}\n"
        elif message["role"] == "assistant":
            formatted_text += f"Bot: {message['content'].strip()}\n"
    if not formatted_text.endswith("Bot:"):
        formatted_text += "Bot:"
    return formatted_text.strip()

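# Example of the manual fallback prompt produced above when no chat template is available
# (the conversation itself is illustrative):
#
#   Eres un asistente profesional y servicial.
#
#   Usuario: Hola
#   Bot: ¡Hola! ¿En qué puedo ayudarte?
#   Usuario: ¿Qué hora es?
#   Bot:
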
def truncate_encoded_ids(input_ids: torch.Tensor, max_length: int) -> torch.Tensor:
    if input_ids.shape[-1] > max_length:
        return input_ids[:, -max_length:]
    return input_ids

def apply_seed(seed: Optional[int]):
    if seed is not None:
        torch.manual_seed(seed)
        random.seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)

def get_stopping_criteria(req: GenerateRequest, initial_ids: torch.Tensor, tokenizer: AutoTokenizer) -> StoppingCriteriaList:
    criteria = StoppingCriteriaList()
    max_len_from_req = None
    if req.max_length is not None and req.max_length > 0:
        max_len_from_req = req.max_length
    elif req.max_new_tokens is not None and req.max_new_tokens > 0:
        max_len_from_req = initial_ids.shape[-1] + req.max_new_tokens
    else:
        max_len_from_req = initial_ids.shape[-1] + MAX_GENERATION_TOKENS
    if max_len_from_req is not None and max_len_from_req > 0:
        criteria.append(MaxLengthCriteria(max_len_from_req))
    if req.min_length is not None and req.min_length > 0:
        eos_token_id = req.eos_token_id_override if req.eos_token_id_override is not None else global_tokens.get("eos_token_id", -1)
        # Note: MinLengthLogitsProcessor is a logits processor, not a StoppingCriteria;
        # appending it here mirrors the original code, but it properly belongs in a
        # LogitsProcessorList and will not behave as a stopping criterion.
        criteria.append(MinLengthLogitsProcessor(initial_ids.shape[-1] + req.min_length, eos_token_id))
    if req.stop_sequences:
        criteria.append(StopSequenceCriteria(req.stop_sequences, tokenizer))
    return criteria

def generate_next_token_sync(
    input_ids,
    past_key_values,
    gen_cfg: GenerationConfig,
    device: str
) -> Tuple[torch.Tensor, Any, Optional[float], Optional[torch.Tensor], Any, Any]:
    with torch.no_grad():
        outputs = global_model(
            input_ids, past_key_values=past_key_values,
            use_cache=gen_cfg.use_cache, return_dict=True,
            output_attentions=gen_cfg.output_attentions,
            output_hidden_states=gen_cfg.output_hidden_states,
            output_scores=gen_cfg.output_scores,
        )
        logits = outputs.logits[:, -1, :]
        past = outputs.past_key_values
        scores = outputs.scores if gen_cfg.output_scores else None
        attentions = outputs.attentions if gen_cfg.output_attentions else None
        hidden_states = outputs.hidden_states if gen_cfg.output_hidden_states else None
        step_logits_for_criteria = logits.clone()
        if gen_cfg.do_normalize_logits:
            logits = F.log_softmax(logits, dim=-1)
        if gen_cfg.do_sample:
            if gen_cfg.use_mirostat_mode == 1 and hasattr(global_model, 'mirostat_sample_logits'):
                token = global_model.mirostat_sample_logits(
                    logits=logits,
                    temperature=gen_cfg.temperature,
                    mirostat_tau=gen_cfg.mirostat_tau,
                    mirostat_eta=gen_cfg.mirostat_eta
                ).unsqueeze(0).to(device)
            else:
                logits = logits / gen_cfg.temperature
                if gen_cfg.temperature_cutoff is not None and gen_cfg.temperature_cutoff > 0:
                    logits = torch.where(logits < gen_cfg.temperature_cutoff, torch.tensor(-float('Inf')).to(logits.device), logits)
                if gen_cfg.top_k:
                    topk_values, topk_indices = torch.topk(logits, gen_cfg.top_k)
                    logits[logits < topk_values[:, -1]] = -float('Inf')
                if gen_cfg.top_p < 1.0:
                    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
                    cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)
                    sorted_indices_to_remove = cumulative_probs > gen_cfg.top_p
                    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
                    sorted_indices_to_remove[..., 0] = False
                    indices_to_remove = sorted_indices[sorted_indices_to_remove]
                    logits[:, indices_to_remove] = -float('Inf')
                if gen_cfg.typical_p < 1.0:
                    probs = torch.softmax(logits, dim=-1)
                    entropy = torch.distributions.Categorical(probs).entropy()
                    probs_sorted, indices_sorted = torch.sort(probs, dim=-1, descending=True)
                    cumsum_probs_sorted = torch.cumsum(probs_sorted, dim=-1)
                    mask = cumsum_probs_sorted < gen_cfg.typical_p * entropy.exp()
                    indices_to_remove = indices_sorted[~mask]
                    logits[:, indices_to_remove] = -float('Inf')
                if gen_cfg.epsilon_cutoff is not None and gen_cfg.epsilon_cutoff > 0:
                    probs = torch.softmax(logits, dim=-1)
                    mask = probs < gen_cfg.epsilon_cutoff
                    logits[:, mask] = -float('Inf')
                if gen_cfg.eta_cutoff is not None and gen_cfg.eta_cutoff > 0:
                    probs = torch.softmax(logits, dim=-1)
                    mask = probs > gen_cfg.eta_cutoff
                    logits[:, ~mask] = -float('Inf')
                probs = torch.softmax(logits, dim=-1)
                token = torch.multinomial(probs, 1)
        else:
            token = torch.argmax(logits, dim=-1, keepdim=True)
        token_logprob = None
        if gen_cfg.output_scores:
            log_probs = F.log_softmax(step_logits_for_criteria, dim=-1)
            if 0 <= token.squeeze().item() < log_probs.shape[-1]:
                token_logprob = float(log_probs[:, token.squeeze()].item())
            else:
                token_logprob = None
        return token, past, token_logprob, step_logits_for_criteria, attentions, hidden_states

def post_process_text(text: str, strip_trailing_whitespace: bool, remove_incomplete_sentences: bool) -> str:
    if strip_trailing_whitespace:
        text = text.rstrip()
    if remove_incomplete_sentences:
        for terminator in ['.', '!', '?', '\n']:
            last_terminator = text.rfind(terminator)
            if last_terminator != -1:
                text = text[:last_terminator + 1]
                break
    return text

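# Example behaviour of post_process_text (sketch):
#
#   post_process_text("Hola. ¿Cómo est", strip_trailing_whitespace=True,
#                     remove_incomplete_sentences=True)
#   # -> "Hola."  (everything after the last sentence terminator is dropped)
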
async def stream_generation_logic(req: GenerateRequest, initial_ids: torch.Tensor, gen_cfg: GenerationConfig, device: str) -> AsyncGenerator[Union[str, Tuple[Dict[str, Any], str]], None]:
    past = None
    generated_tokens_count = 0
    eos_token_id = req.eos_token_id_override if req.eos_token_id_override is not None else global_tokens.get("eos_token_id")
    pad_token_id = req.pad_token_id_override if req.pad_token_id_override is not None else global_tokens.get("pad_token_id", eos_token_id)
    stop_token_ids = {eos_token_id} if eos_token_id is not None else set()
    if pad_token_id is not None and pad_token_id != eos_token_id:
        stop_token_ids.add(pad_token_id)

    current_ids = initial_ids
    start_time = time.time()
    total_ids_list = initial_ids.tolist()[0]
    finish_reason = "unknown"

    stopping_criteria = get_stopping_criteria(req, initial_ids, global_tokenizer)

    last_step_logits = None
    accumulated_text_for_processing = ""

    try:
        while True:
            if generated_tokens_count >= req.max_new_tokens:
                finish_reason = "max_new_tokens"
                break
            if req.max_time is not None and (time.time() - start_time) > req.max_time:
                finish_reason = "time"
                break

            input_ids_sync = current_ids if past is None else token

            token, past, token_logprob, step_logits, attentions, hidden_states = await asyncio.to_thread(
                generate_next_token_sync,
                input_ids_sync,
                past,
                gen_cfg,
                device
            )
            last_step_logits = step_logits

            generated_token_id = token[0].item()
            total_ids_list.append(generated_token_id)

            text = global_tokenizer.decode([generated_token_id], skip_special_tokens=True)
            accumulated_text_for_processing += text

            if req.return_only_text:
                yield text
            else:
                chunk_payload: Dict[str, Any] = {
                    "type": "token",
                    "text": text,
                    "token_id": generated_token_id,
                    "generated_tokens_count": generated_tokens_count + 1,
                }
                if req.return_token_logprobs and token_logprob is not None:
                    chunk_payload["logprob"] = token_logprob

                yield json.dumps(chunk_payload) + "\n"

            if generated_token_id in stop_token_ids:
                finish_reason = "eos_token"
                break

            current_full_ids_tensor = torch.tensor([total_ids_list], device=device)
            if stopping_criteria(current_full_ids_tensor, step_logits):
                finish_reason = "stopping_criteria"
                current_len = len(total_ids_list)
                initial_len = initial_ids.shape[-1]

                max_len_crit_met = any(isinstance(c, MaxLengthCriteria) for c in stopping_criteria) and \
                    ((req.max_new_tokens is not None and current_len >= (initial_len + req.max_new_tokens)) or
                     (req.max_length is not None and current_len >= req.max_length))
                stop_seq_crit_met = any(isinstance(c, StopSequenceCriteria) for c in stopping_criteria) and req.stop_sequences and \
                    any(seq in global_tokenizer.decode(total_ids_list[initial_len:], skip_special_tokens=True) for seq in req.stop_sequences)

                if max_len_crit_met:
                    if req.max_new_tokens is not None and current_len >= (initial_len + req.max_new_tokens):
                        finish_reason = "max_new_tokens"
                    elif req.max_length is not None and current_len >= req.max_length:
                        finish_reason = "max_length"

                if stop_seq_crit_met:
                    finish_reason = "stop_sequence"

                break

            current_ids = token
            generated_tokens_count += 1

        final_text_raw = global_tokenizer.decode(total_ids_list[initial_ids.shape[-1]:], skip_special_tokens=True)
        if req.stop_sequences and finish_reason == "stop_sequence":
            for stop_seq in req.stop_sequences:
                if stop_seq and stop_seq in final_text_raw:
                    final_text_raw = final_text_raw.split(stop_seq, 1)[0]
                    break

        final_text_processed = post_process_text(final_text_raw, req.strip_trailing_whitespace, req.remove_incomplete_sentences)

        if not req.return_only_text:
            final_payload: Dict[str, Any] = {
                "type": "done",
                "total_prompt_tokens": initial_ids.shape[-1],
                "total_generated_tokens": generated_tokens_count,
                "total_sequence_tokens": len(total_ids_list),
                "final_text": final_text_processed,
                "finish_reason": finish_reason
            }
            yield json.dumps(final_payload) + "\n"

    except Exception as e:
        logging.exception("Streaming generation error:")
        if req.return_only_text:
            yield f"Error: {e}\n"
        else:
            error_payload = {"type": "error", "message": str(e)}
            yield json.dumps(error_payload) + "\n"

    finally:
        await cleanup(device)

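# Shape of the newline-delimited JSON stream produced above, one object per line
# (token ids, logprobs and counts are illustrative):
#
#   {"type": "token", "text": "Hola", "token_id": 15043, "generated_tokens_count": 1, "logprob": -0.42}
#   {"type": "token", "text": "!", "token_id": 29991, "generated_tokens_count": 2}
#   {"type": "done", "total_prompt_tokens": 24, "total_generated_tokens": 2,
#    "total_sequence_tokens": 26, "final_text": "Hola!", "finish_reason": "eos_token"}
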
async def non_stream_generation_logic(req: GenerateRequest, initial_ids: torch.Tensor, gen_cfg: GenerationConfig, device: str) -> Dict[str, Any]:
    try:
        logits_processor_list = LogitsProcessorList()

        stopping_criteria_list = get_stopping_criteria(req, initial_ids, global_tokenizer)

        with torch.no_grad():
            out = global_model.generate(
                input_ids=initial_ids,
                generation_config=gen_cfg,
                return_dict_in_generate=True,
                output_scores=req.output_scores,
                output_attentions=req.output_attentions,
                output_hidden_states=req.output_hidden_states,
                num_return_sequences=req.num_return_sequences,
                bad_words_ids=req.bad_words_ids,
                suppress_tokens=req.suppress_tokens,
                begin_suppress_tokens=req.begin_suppress_tokens,
                end_suppress_tokens=req.end_suppress_tokens,
                logits_processor=logits_processor_list if logits_processor_list else None,
                stopping_criteria=stopping_criteria_list if stopping_criteria_list else None,
            )

        generated_data = []
        for i in range(req.num_return_sequences):
            if i >= len(out.sequences):
                break

            sequence = out.sequences[i]
            start_index = initial_ids.shape[-1]
            generated_ids_tensor = sequence[start_index:]
            full_sequence_ids = sequence.tolist()

            text = global_tokenizer.decode(generated_ids_tensor, skip_special_tokens=True)

            if req.stop_sequences:
                for stop_seq in req.stop_sequences:
                    if stop_seq and stop_seq in text:
                        text = text.split(stop_seq, 1)[0]
                        break

            text = post_process_text(text, req.strip_trailing_whitespace, req.remove_incomplete_sentences)

            finish_reason = "length"
            eos_token_id = req.eos_token_id_override if req.eos_token_id_override is not None else global_tokens.get("eos_token_id")
            if len(generated_ids_tensor) > 0 and eos_token_id is not None and generated_ids_tensor[-1] == eos_token_id:
                finish_reason = "eos_token"
            elif len(generated_ids_tensor) >= gen_cfg.max_new_tokens:
                finish_reason = "max_new_tokens"
            elif req.max_length is not None and len(full_sequence_ids) >= req.max_length:
                finish_reason = "max_length"
            elif hasattr(out, 'max_time_exceeded') and out.max_time_exceeded:
                finish_reason = "time"

            if req.stop_sequences and finish_reason == "length":
                decoded_full_output = global_tokenizer.decode(full_sequence_ids, skip_special_tokens=True)
                if any(seq in decoded_full_output for seq in req.stop_sequences):
                    finish_reason = "stop_sequence"

            item_data: Dict[str, Any] = {
                "text": text if req.return_text_from_sequence else None,
                "token_ids": generated_ids_tensor.tolist(),
                "generated_tokens_count": len(generated_ids_tensor),
                "finish_reason": finish_reason
            }
            if not req.remove_input_from_output:
                item_data["full_sequence_token_ids"] = full_sequence_ids

            if req.output_scores and hasattr(out, 'scores') and out.scores is not None:
                item_data["scores"] = "Scores output needs custom handling (complex structure)."

            if req.return_token_logprobs:
                item_data["token_logprobs"] = "Token logprobs require parsing scores output which is complex for batched/beamed generation."

            if req.output_attentions and hasattr(out, 'attentions') and out.attentions is not None:
                item_data["attentions"] = "Attentions output needs custom handling (too large)."
            if req.output_hidden_states and hasattr(out, 'hidden_states') and out.hidden_states is not None:
                item_data["hidden_states"] = "Hidden states output needs custom handling (too large)."
            if hasattr(out, 'watermark') and out.watermark is not None:
                item_data["watermark"] = out.watermark[i] if isinstance(out.watermark, list) and len(out.watermark) > i else out.watermark

            generated_data.append(item_data)

        response_payload: Dict[str, Any] = {
            "prompt_tokens": initial_ids.shape[-1],
            "generated_sequences": generated_data,
        }
        if req.num_return_sequences == 1 and generated_data:
            response_payload["total_tokens"] = response_payload["prompt_tokens"] + generated_data[0]["generated_tokens_count"]

        if req.return_dict_in_generate:
            raw_out_dict = {}
            for key in out.keys():
                if key not in ['sequences', 'scores', 'attentions', 'hidden_states', 'past_key_values', 'watermark', 'sequences_scores']:
                    value = out[key]
                    if isinstance(value, torch.Tensor):
                        raw_out_dict[key] = value.tolist()
                    else:
                        raw_out_dict[key] = value

            response_payload["raw_generate_output"] = raw_out_dict

        return response_payload

    except Exception as e:
        logging.exception("Non-streaming generation error:")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Generation error: {e}")

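# Shape of the non-streaming response assembled above (values are illustrative):
#
#   {
#     "prompt_tokens": 24,
#     "generated_sequences": [
#       {"text": "¡Hola!", "token_ids": [31587, 29991], "generated_tokens_count": 2,
#        "finish_reason": "eos_token", "full_sequence_token_ids": [...]}
#     ],
#     "total_tokens": 26
#   }
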
async def cleanup(device: str):
    if device == "cuda" and torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()


@app.on_event("startup")
async def load_model():
    global global_model, global_tokenizer, global_tokens, MODEL_NAME, TRUST_REMOTE_CODE, ENABLE_FLASH_ATTENTION_2, TORCH_DTYPE, TORCH_DTYPE_STR, TRUST_REMOTE_CODE_ENV

    torch.set_num_threads(max(1, os.cpu_count() // 2))
    torch.set_num_interop_threads(max(1, os.cpu_count() // 4))

    torch.backends.cuda.preferred_linalg_backend = "fused" if torch.backends.cuda.is_built() else None
    torch.backends.cudnn.benchmark = True if torch.cuda.is_available() else False

    try:
        TORCH_DTYPE = getattr(torch, TORCH_DTYPE_STR.lower(), torch.float32)
        if TORCH_DTYPE != torch.float32:
            logging.warning(f"Requested dtype {TORCH_DTYPE_STR} might not be fully performant on CPU. Using float32.")
            TORCH_DTYPE = torch.float32
    except AttributeError:
        logging.warning(f"Invalid TORCH_DTYPE specified: {TORCH_DTYPE_STR}. Falling back to float32.")
        TORCH_DTYPE = torch.float32

    current_model_name = MODEL_NAME
    current_trust_remote_code = TRUST_REMOTE_CODE_ENV or (current_model_name == DEFAULT_MODEL_NAME)
    device = "cpu"

    try:
        logging.info(f"Loading config for model: {current_model_name}")
        config = AutoConfig.from_pretrained(current_model_name, trust_remote_code=current_trust_remote_code)
        original_config = copy.deepcopy(config)

        logging.info(f"Modifying config for simplified model.")

        if hasattr(config, 'num_hidden_layers'):
            config.num_hidden_layers = 1
        elif hasattr(config, 'num_layers'):
            config.num_layers = 1

        if hasattr(config, 'bos_token_id'):
            config.bos_token_id = 1

        if hasattr(config, 'do_sample'):
            config.do_sample = None

        if hasattr(config, 'eos_token_id'):
            config.eos_token_id = 2

        if hasattr(config, 'head_dim'):
            config.head_dim = 96

        if hasattr(config, 'hidden_size'):
            config.hidden_size = 192

        if hasattr(config, 'initializer_range'):
            config.initializer_range = 0.02

        if hasattr(config, 'intermediate_size'):
            config.intermediate_size = 512

        if hasattr(config, 'max_position_embeddings'):
            config.max_position_embeddings = MAX_CONTEXT_TOKENS

        if hasattr(config, 'n_positions'):
            config.n_positions = MAX_CONTEXT_TOKENS

        if hasattr(config, 'seq_len'):
            config.seq_len = MAX_CONTEXT_TOKENS

        if hasattr(config, 'ctx'):
            config.ctx = MAX_CONTEXT_TOKENS

        if hasattr(config, 'n_ctx'):
            config.n_ctx = MAX_CONTEXT_TOKENS

        if hasattr(config, 'max_seq_length'):
            config.max_seq_length = MAX_CONTEXT_TOKENS

        if hasattr(config, 'max_sequence_length'):
            config.max_sequence_length = MAX_CONTEXT_TOKENS

        if hasattr(config, 'max_length'):
            config.max_length = MAX_CONTEXT_TOKENS

        if hasattr(config, 'block_size'):
            config.block_size = MAX_CONTEXT_TOKENS

        if hasattr(config, 'use_cache'):
            config.use_cache = False

        if hasattr(config, 'gradient_checkpointing'):
            config.gradient_checkpointing = True

        if hasattr(config, 'torch_dtype'):
            if torch.cuda.is_available() and torch.cuda.get_device_properties(0).has_bfloat16:
                config.torch_dtype = 'bfloat16'
            else:
                config.torch_dtype = 'float16'

        if hasattr(config, 'use_bfloat16'):
            if torch.cuda.is_available() and torch.cuda.get_device_properties(0).has_bfloat16:
                config.use_bfloat16 = True
            else:
                config.use_bfloat16 = False

        if hasattr(config, 'attention_probs_dropout_prob'):
            config.attention_probs_dropout_prob = 0.1

        if hasattr(config, 'hidden_dropout_prob'):
            config.hidden_dropout_prob = 0.1

        if hasattr(config, 'layerdrop'):
            config.layerdrop = 0.1

        if hasattr(config, 'layer_norm_eps'):
            config.layer_norm_eps = 1e-5

        if hasattr(config, 'initializer_range'):
            config.initializer_range = 0.02

        if hasattr(config, 'rotary_pct'):
            config.rotary_pct = 0.25

        if hasattr(config, 'rotary_emb_base'):
            config.rotary_emb_base = 10000

        if hasattr(config, 'position_embedding_type'):
            config.position_embedding_type = 'rotary'

        if hasattr(config, 'activation_function'):
            config.activation_function = 'gelu_new'

        if hasattr(config, 'vocab_size'):
            config.vocab_size = 32000

        if hasattr(config, 'quantization_config'):
            if torch.cuda.is_available():
                config.quantization_config = {
                    'load_in_8bit': True,
                    'load_in_4bit': False,
                    'bnb_4bit_compute_dtype': 'float16',
                    'bnb_4bit_use_double_quant': True,
                    'bnb_4bit_quant_type': 'nf4'
                }
            else:
                logging.warning("Quantization config requested but CUDA not available. Skipping quantization config modification.")
                config.quantization_config = {}

        if hasattr(config, 'load_in_8bit'):
            if torch.cuda.is_available():
                config.load_in_8bit = True
            else:
                config.load_in_8bit = False

        if hasattr(config, 'load_in_4bit'):
            if torch.cuda.is_available():
                config.load_in_4bit = False
            else:
                config.load_in_4bit = False

        if hasattr(config, 'tie_word_embeddings'):
            config.tie_word_embeddings = True

        if hasattr(config, 'output_attentions'):
            config.output_attentions = False

        if hasattr(config, 'output_hidden_states'):
            config.output_hidden_states = False

        if hasattr(config, 'use_cache'):
            config.use_cache = False

        logging.info(f"Loading tokenizer for model: {current_model_name}")
        tokenizer_kwargs = {"config": original_config, "trust_remote_code": current_trust_remote_code}
        global_tokenizer = AutoTokenizer.from_pretrained(current_model_name, **tokenizer_kwargs)
        logging.info("Tokenizer loaded.")

        logging.info(f"Loading model: {current_model_name} with modified config and dtype {TORCH_DTYPE} onto {device}")

        model_kwargs = {"config": config, "torch_dtype": TORCH_DTYPE, "trust_remote_code": current_trust_remote_code}

        global_model = AutoModelForCausalLM.from_pretrained(current_model_name, **model_kwargs)
        global_model.to(device)

        try:
            global_model = torch.compile(global_model, mode="max-autotune")
            logging.info("Model compiled with torch.compile (max-autotune mode).")
        except Exception as e:
            logging.warning(f"Failed to compile model with torch.compile: {e}")
            pass

        global_model.eval()
        logging.info("Model loaded successfully.")

        global_tokens["eos_token_id"] = global_tokenizer.eos_token_id
        global_tokens["pad_token_id"] = global_tokenizer.pad_token_id
        if global_tokens["pad_token_id"] is None and global_tokens["eos_token_id"] is not None:
            global_tokens["pad_token_id"] = global_tokens["eos_token_id"]
            if global_model.config.pad_token_id is None:
                global_model.config.pad_token_id = global_tokens["pad_token_id"]
        elif global_tokens["pad_token_id"] is None and global_tokens["eos_token_id"] is None:
            logging.warning("Neither EOS nor PAD token is defined for this tokenizer/model.")
        if global_model.config.pad_token_id is None and global_tokens.get("pad_token_id") is not None:
            global_model.config.pad_token_id = global_tokens["pad_token_id"]

    except Exception as e:
        logging.exception("Failed to load model or tokenizer:")
        global_model = None
        global_tokenizer = None
        global_tokens = {}

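# Example of how the service could be configured and launched (a sketch; the exact run
# command depends on how the app is started elsewhere and is an assumption here):
#
#   MODEL_NAME=hghghgkskdmskdms/xddd MAX_CONTEXT_TOKENS=1024 MAX_GENERATION_TOKENS=512 \
#   MAX_CONCURRENT_GENERATIONS=4 TORCH_DTYPE=float32 API_KEY=<optional-key> \
#   uvicorn app:app --host 0.0.0.0 --port 8000
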
| 854 |
+
html_code = """
|
| 855 |
+
<!DOCTYPE html>
|
| 856 |
+
<html lang="es">
|
| 857 |
+
<head>
|
| 858 |
+
<meta charset="UTF-8" />
|
| 859 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
| 860 |
+
<title>Chatbot Profesional</title>
|
| 861 |
+
<style>
|
| 862 |
+
body { font-family: Arial, sans-serif; margin: 20px; }
|
| 863 |
+
#chatbox { width: 100%; height: 400px; border: 1px solid #ccc; padding: 10px; overflow-y: scroll; margin-bottom: 10px; }
|
| 864 |
+
#user-input { width: calc(100% - 100px); padding: 8px; box-sizing: border-box;}
|
| 865 |
+
#send-btn { width: 90px; padding: 8px 0;}
|
| 866 |
+
#input-area { display: flex;}
|
| 867 |
+
</style>
|
| 868 |
+
</head>
|
| 869 |
+
<body>
|
| 870 |
+
<h1>Chatbot Profesional (POST API)</h1>
|
| 871 |
+
<div id="chatbox"></div>
|
| 872 |
+
<div id="input-area">
|
| 873 |
+
<input type="text" id="user-input" placeholder="Escribe tu mensaje aquí..." autocomplete="off"/>
|
| 874 |
+
<button id="send-btn">Enviar</button>
|
| 875 |
+
</div>
|
| 876 |
+
<script>
|
| 877 |
+
const chatbox = document.getElementById('chatbox');
|
| 878 |
+
const userInput = document.getElementById('user-input');
|
| 879 |
+
const sendBtn = document.getElementById('send-btn');
|
| 880 |
+
|
| 881 |
+
let conversationHistory = [];
|
| 882 |
+
const DEFAULT_SYSTEM_PROMPT = "Eres un asistente profesional y servicial.";
|
| 883 |
+
let currentSystemPrompt = DEFAULT_SYSTEM_PROMPT;
|
| 884 |
+
let botMessageElement = null;
|
| 885 |
+
|
| 886 |
+
function appendMessage(sender, text, isStreaming = false) {
|
| 887 |
+
let msg;
|
| 888 |
+
if (isStreaming && botMessageElement) {
|
| 889 |
+
botMessageElement.textContent += text;
|
| 890 |
+
} else {
|
| 891 |
+
msg = document.createElement('p');
|
| 892 |
+
msg.innerHTML = `<strong>${sender}:</strong> `;
|
| 893 |
+
const textNode = document.createTextNode(text);
|
| 894 |
+
msg.appendChild(textNode);
|
| 895 |
+
chatbox.appendChild(msg);
|
| 896 |
+
if (sender === 'Bot' && isStreaming) {
|
| 897 |
+
botMessageElement = textNode;
|
| 898 |
+
} else {
|
| 899 |
+
botMessageElement = null;
|
| 900 |
+
}
|
| 901 |
+
}
|
| 902 |
+
chatbox.scrollTop = chatbox.scrollHeight;
|
| 903 |
+
}
|
| 904 |
+
|
| 905 |
+
function updateHistory(role, content) {
|
| 906 |
+
conversationHistory.push({ "role": role, "content": content });
|
| 907 |
+
const maxHistorySize = 10;
|
| 908 |
+
if (conversationHistory.length > maxHistorySize * 2) {
|
| 909 |
+
conversationHistory = conversationHistory.slice(-(maxHistorySize * 2));
|
| 910 |
+
}
|
| 911 |
+
}
|
| 912 |
+
|
| 913 |
+
async function sendMessage() {
|
| 914 |
+
const text = userInput.value;
|
| 915 |
+
if (!text) {
|
| 916 |
+
return;
|
| 917 |
+
}
|
| 918 |
+
appendMessage('Usuario', text);
|
| 919 |
+
updateHistory("user", text);
|
| 920 |
+
userInput.value = '';
|
| 921 |
+
sendBtn.disabled = true;
|
| 922 |
+
|
| 923 |
+
botMessageElement = null;
|
| 924 |
+
|
| 925 |
+
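            // Request body for POST /generate: the current message, the rolling history
            // and system prompt (combined server-side by format_conversation), plus the
            // per-request sampling and generation parameters.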
            const messagePayload = {
                input_text: text,
                history: conversationHistory,
                system_prompt: currentSystemPrompt,
                stream: true,
                temperature: 1.0,
                top_k: 50,
                top_p: 1.0,
                repetition_penalty: 1.0,
                frequency_penalty: 0.0,
                presence_penalty: 0.0,
                num_beams: 1,
                length_penalty: 1.0,
                no_repeat_ngram_size: 0,
                early_stopping: false,
                do_sample: true,
                use_mirostat: false,
                mirostat_tau: 5.0,
                mirostat_eta: 0.1,
                max_new_tokens: 512,
                num_return_sequences: 1,
                return_token_logprobs: true
            };

            try {
                const response = await fetch('/generate', {
                    method: 'POST',
                    headers: {
                        'Content-Type': 'application/json',
                        // Add API Key header if needed
                        // 'X-API-Key': 'YOUR_API_KEY_HERE'
                    },
                    body: JSON.stringify(messagePayload),
                });

                if (!response.ok) {
                    const errorData = await response.json();
                    throw new Error(`API Error: ${response.status} ${response.statusText} - ${errorData.detail || errorData.error}`);
                }

                const reader = response.body.getReader();
                const decoder = new TextDecoder();
                let buffer = '';
                let currentBotResponse = "";

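                // The response is newline-delimited JSON: split each decoded chunk on
                // '\n' and keep any trailing partial line in `buffer` until the rest
                // of it arrives with the next chunk.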
                while (true) {
                    const { value, done } = await reader.read();
                    if (done) break;

                    buffer += decoder.decode(value, { stream: true });

                    const lines = buffer.split('\n');
                    buffer = lines.pop();

                    for (const line of lines) {
                        if (line.trim() === '') continue;
                        try {
                            const data = JSON.parse(line);
                            if (data.type === 'token') {
                                currentBotResponse += data.text;
                                appendMessage('Bot', data.text, true);
                                console.log('Token:', data.token_id, 'Text:', data.text, 'Logprob:', data.logprob);
                            } else if (data.type === 'done') {
                                console.log('Generation done', data);
                                if (data.total_tokens !== undefined) {
                                    appendMessage('System', `Generated ${data.total_tokens} tokens. Finish reason: ${data.finish_reason}`);
                                }
                                if (data.final_text !== undefined) {
                                    updateHistory("assistant", data.final_text);
                                } else if (currentBotResponse) {
                                    updateHistory("assistant", currentBotResponse);
                                }
                            } else if (data.type === 'error') {
                                appendMessage('Error', data.message);
                                currentBotResponse = "";
                            }
                        } catch (e) {
                            console.error('Failed to parse stream chunk:', e, line);
                            appendMessage('Error', 'Failed to process stream.');
                            currentBotResponse = "";
                            reader.cancel();
                            return;
                        }
                    }
                }

                if (buffer.trim() !== '') {
                    try {
                        const data = JSON.parse(buffer);
                        if (data.type === 'token') {
                            currentBotResponse += data.text;
                            appendMessage('Bot', data.text, true);
                            console.log('Token:', data.token_id, 'Text:', data.text, 'Logprob:', data.logprob);
                        } else if (data.type === 'done') {
                            console.log('Generation done', data);
                            if (data.total_tokens !== undefined) {
                                appendMessage('System', `Generated ${data.total_tokens} tokens. Finish reason: ${data.finish_reason}`);
                            }
                            if (data.final_text !== undefined) {
                                updateHistory("assistant", data.final_text);
                            } else if (currentBotResponse) {
                                updateHistory("assistant", currentBotResponse);
                            }
                        } else if (data.type === 'error') {
                            appendMessage('Error', data.message);
                            currentBotResponse = "";
                        }
                    } catch (e) {
                        console.error('Failed to parse remaining buffer:', e, buffer);
                        appendMessage('Error', 'Failed to process remaining stream data.');
                        currentBotResponse = "";
                    }
                }

                if (currentBotResponse && !botMessageElement) {
                    updateHistory("assistant", currentBotResponse);
                }
                botMessageElement = null;
                currentBotResponse = "";

            } catch (error) {
                console.error('Send message error:', error);
                appendMessage('Error', error.message || 'An unknown error occurred.');
                botMessageElement = null;
                currentBotResponse = "";
            } finally {
                sendBtn.disabled = false;
            }
        }

        sendBtn.onclick = sendMessage;

        userInput.addEventListener('keypress', function(event) {
            if (event.key === 'Enter') {
                event.preventDefault();
                sendMessage();
            }
        });
    </script>
</body>
</html>
"""

@app.get("/", response_class=HTMLResponse, summary="Interactive HTML interface")
async def root():
    return HTMLResponse(content=html_code)

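# The helpers below (health, config, model info, tokenize/decode, system-prompt
# update, reload/unload) are plain async functions; no route decorators are
# attached to them in this part of the file.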
async def check_health():
    model_loaded = global_model is not None
    tokenizer_loaded = global_tokenizer is not None
    status_data = {
        "model_loaded": model_loaded,
        "tokenizer_loaded": tokenizer_loaded,
        "status": "ok" if model_loaded and tokenizer_loaded else "loading model",
        "cuda_available": torch.cuda.is_available(),
        "cpu_cores": os.cpu_count(),
        "max_concurrent_generations": MAX_CONCURRENT_GENERATIONS,
        "currently_running_generations": MAX_CONCURRENT_GENERATIONS - generation_semaphore._value,
        "available_slots": generation_semaphore._value,
    }
    if torch.cuda.is_available():
        device_count = torch.cuda.device_count()
        status_data["device_count"] = device_count
        status_data["devices"] = []
        for i in range(device_count):
            try:
                device_status = {
                    "id": i,
                    "name": torch.cuda.get_device_name(i),
                    "total_memory_mib": round(torch.cuda.get_device_properties(i).total_memory / (1024 * 1024), 2),
                    "allocated_memory_mib": round(torch.cuda.memory_allocated(i) / (1024 * 1024), 2),
                    "cached_memory_mib": round(torch.cuda.memory_reserved(i) / (1024 * 1024), 2),
                }
                status_data["devices"].append(device_status)
            except Exception as e:
                logging.error(f"Error getting GPU memory info for device {i}: {e}")
                status_data["devices"].append({"id": i, "error": str(e)})
    else:
        status_data["message"] = "CUDA not available. GPU resource info is not applicable."
    return status_data

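# Snapshot of the effective runtime configuration (model name, context limits,
# dtype, special-token ids, and whether an API key is required).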
async def get_config_data():
    torch_dtype_str_out = str(TORCH_DTYPE).split('.')[-1] if isinstance(TORCH_DTYPE, torch.dtype) else str(TORCH_DTYPE)
    return {
        "model_name": MODEL_NAME,
        "system_prompt_default": SYSTEM_PROMPT,
        "max_context_tokens": MAX_CONTEXT_TOKENS,
        "max_generation_tokens": MAX_GENERATION_TOKENS,
        "cuda_available": torch.cuda.is_available(),
        "model_loaded": global_model is not None,
        "tokenizer_loaded": global_tokenizer is not None,
        "max_concurrent_generations": MAX_CONCURRENT_GENERATIONS,
        "trust_remote_code_startup_env": TRUST_REMOTE_CODE_ENV,
        "trust_remote_code_effective": TRUST_REMOTE_CODE,
        "enable_flash_attention_2": ENABLE_FLASH_ATTENTION_2,
        "torch_dtype": torch_dtype_str_out,
        "eos_token_id": global_tokens.get("eos_token_id"),
        "pad_token_id": global_tokens.get("pad_token_id"),
        "bos_token_id": global_tokenizer.bos_token_id if global_tokenizer else None,
        "api_key_required": API_KEY is not None
    }

async def get_model_info_data():
    if global_model is None:
        return {"model_name": MODEL_NAME, "is_loaded": False, "message": "Model is not loaded."}
    try:
        config_dict = global_model.config.to_dict()
        keys_to_remove = ['torch_dtype', '_attn_implementation', 'architectures', 'id2label', 'label2id']
        for key in keys_to_remove:
            config_dict.pop(key, None)
        return {
            "model_name": MODEL_NAME,
            "is_loaded": True,
            "device": str(global_model.device),
            "torch_dtype": str(global_model.dtype),
            "config": config_dict
        }
    except Exception as e:
        logging.exception("Error getting model info:")
        return {"model_name": MODEL_NAME, "is_loaded": True, "error": f"Error getting model info: {e}"}

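# Thin wrappers around the tokenizer for encode/decode requests; failures are
# surfaced as HTTP 503 (tokenizer missing) or HTTP 500 (tokenizer error).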
async def internal_tokenize(text: Union[str, List[str]], add_special_tokens: bool = True, is_split_into_words: bool = False, return_token_type_ids: bool = False, padding: Union[bool, str] = False, truncation: Union[bool, str] = False, max_length: Optional[int] = None, return_tensors: Optional[str] = None, return_attention_mask: Optional[bool] = None, return_offsets_mapping: Optional[bool] = None, return_length: Optional[bool] = None, verbose: bool = False, tokenizer_kwargs: Optional[Dict[str, Any]] = None):
    if global_tokenizer is None:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Tokenizer is not loaded.")
    try:
        tokenizer_kwargs_final = tokenizer_kwargs or {}
        return_tensors_final = return_tensors
        if return_tensors_final is None and (return_attention_mask or return_offsets_mapping or return_length):
            return_tensors_final = "pt"
        encoded = global_tokenizer(
            text,
            add_special_tokens=add_special_tokens,
            return_token_type_ids=return_token_type_ids,
            padding=padding,
            truncation=truncation,
            max_length=max_length,
            is_split_into_words=is_split_into_words,
            return_tensors=return_tensors_final,
            return_attention_mask=return_attention_mask,
            return_offsets_mapping=return_offsets_mapping,
            return_length=return_length,
            verbose=verbose,
            **tokenizer_kwargs_final
        )
        return encoded
    except Exception as e:
        logging.exception("Tokenization error:")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Tokenization error: {e}")

async def internal_decode(token_ids: List[int], skip_special_tokens: bool = True, clean_up_tokenization_spaces: bool = True, decode_kwargs: Optional[Dict[str, Any]] = None):
    if global_tokenizer is None:
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Tokenizer is not loaded.")
    try:
        decode_kwargs_final = decode_kwargs or {}
        text = global_tokenizer.decode(
            token_ids,
            skip_special_tokens=skip_special_tokens,
            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
            **decode_kwargs_final
        )
        return {"text": text}
    except Exception as e:
        logging.exception("Decoding error:")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Decoding error: {e}")

def update_global_system_prompt(new_prompt: str):
    global SYSTEM_PROMPT
    if new_prompt is not None:
        SYSTEM_PROMPT = new_prompt.strip()
        return {"status": "success", "message": "Global system prompt updated"}
    else:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="System prompt cannot be null")

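# Reload runs as a fire-and-forget asyncio task so the caller returns immediately;
# while the reload is in flight the model/tokenizer globals may temporarily be None.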
async def internal_reload_model(req: ModelReloadRequest):
    global global_model, global_tokenizer, global_tokens, MODEL_NAME, TRUST_REMOTE_CODE, ENABLE_FLASH_ATTENTION_2, TORCH_DTYPE, TORCH_DTYPE_STR, TRUST_REMOTE_CODE_ENV
    new_model_name = req.model_name if req.model_name else MODEL_NAME
    new_trust_remote_code = req.trust_remote_code if req.trust_remote_code is not None else (TRUST_REMOTE_CODE_ENV or (new_model_name == DEFAULT_MODEL_NAME))
    new_enable_flash_attention_2 = req.enable_flash_attention_2 if req.enable_flash_attention_2 is not None else ENABLE_FLASH_ATTENTION_2
    new_torch_dtype_str_req = req.torch_dtype if req.torch_dtype else TORCH_DTYPE_STR
    try:
        new_torch_dtype = getattr(torch, new_torch_dtype_str_req.lower())
        if new_torch_dtype != torch.float32:
            logging.warning(f"Requested dtype {new_torch_dtype_str_req} might not be fully performant on CPU. Using float32.")
            new_torch_dtype = torch.float32
        elif not isinstance(new_torch_dtype, torch.dtype):
            raise AttributeError
        new_torch_dtype_str = str(new_torch_dtype).split('.')[-1]
    except AttributeError:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Invalid or unsupported torch_dtype: {new_torch_dtype_str_req}")
    device = "cpu"

    async def _reload():
        global global_model, global_tokenizer, global_tokens, MODEL_NAME, TRUST_REMOTE_CODE, ENABLE_FLASH_ATTENTION_2, TORCH_DTYPE, TORCH_DTYPE_STR
        logging.info(f"Attempting to load model: {new_model_name}")
        try:
            logging.info("Unloading current model...")
            await cleanup(device)
            if global_model is not None:
                del global_model
                global_model = None
            if global_tokenizer is not None:
                del global_tokenizer
                global_tokenizer = None
            global_tokens = {}
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            gc.collect()
            logging.info("Current model unloaded.")
            logging.info(f"Loading config for model: {new_model_name}")
            config = AutoConfig.from_pretrained(new_model_name, trust_remote_code=new_trust_remote_code)
            original_config = copy.deepcopy(config)

            logging.info("Modifying config for simplified model.")

            config_modifications = {
                'num_hidden_layers': 1,
                'num_layers': 1,
                'bos_token_id': 1,
                'do_sample': None,
                'eos_token_id': 2,
                'head_dim': 96,
                'hidden_size': 192,
                'initializer_range': 0.02,
                'intermediate_size': 512,
                'max_position_embeddings': MAX_CONTEXT_TOKENS,
                'n_positions': MAX_CONTEXT_TOKENS,
                'seq_len': MAX_CONTEXT_TOKENS,
                'ctx': MAX_CONTEXT_TOKENS,
                'n_ctx': MAX_CONTEXT_TOKENS,
                'max_seq_length': MAX_CONTEXT_TOKENS,
                'max_sequence_length': MAX_CONTEXT_TOKENS,
                'max_length': MAX_CONTEXT_TOKENS,
                'block_size': MAX_CONTEXT_TOKENS,
                'use_cache': False,
                'gradient_checkpointing': True,
                'attention_probs_dropout_prob': 0.1,
                'hidden_dropout_prob': 0.1,
                'layerdrop': 0.1,
                'layer_norm_eps': 1e-5,
                'rotary_pct': 0.25,
                'rotary_emb_base': 10000,
                'position_embedding_type': 'rotary',
                'activation_function': 'gelu_new',
                'vocab_size': 32000,
                'tie_word_embeddings': True,
                'output_attentions': False,
                'output_hidden_states': False,
            }

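            # Each override below is applied only if the attribute already exists on this
            # architecture's config; the context/layer-count keys that matter most get a
            # warning when they are missing instead of being added blindly.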
            for attr, new_val in config_modifications.items():
                if hasattr(config, attr):
                    if attr == 'torch_dtype':
                        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
                            setattr(config, attr, torch.bfloat16)
                        else:
                            setattr(config, attr, torch.float16)
                    elif attr == 'use_bfloat16':
                        if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
                            setattr(config, attr, True)
                        else:
                            setattr(config, attr, False)
                    elif attr == 'quantization_config':
                        if torch.cuda.is_available():
                            setattr(config, attr, new_val)
                        else:
                            logging.warning(f"Quantization config requested for '{attr}' but CUDA not available. Skipping modification.")
                    else:
                        setattr(config, attr, new_val)
                elif attr in ['num_hidden_layers', 'num_layers', 'max_position_embeddings', 'n_positions', 'seq_len', 'ctx', 'n_ctx', 'max_seq_length', 'max_sequence_length', 'max_length', 'block_size']:
                    logging.warning(f"Could not find a standard parameter '{attr}' in config for {new_model_name}. Max context/layer logic might not be fully effective.")

            logging.info(f"Loading tokenizer for model: {new_model_name}")
            tokenizer_kwargs = {"config": original_config, "trust_remote_code": new_trust_remote_code}
            if req.tokenizer_kwargs:
                tokenizer_kwargs.update(req.tokenizer_kwargs)
            tokenizer = AutoTokenizer.from_pretrained(new_model_name, **tokenizer_kwargs)
            logging.info("Tokenizer loaded.")

            logging.info(f"Loading model: {new_model_name} with modified config and dtype {new_torch_dtype_str} onto {device}")
            model_kwargs = {"config": config, "torch_dtype": new_torch_dtype, "trust_remote_code": new_trust_remote_code}
            model = AutoModelForCausalLM.from_pretrained(new_model_name, **model_kwargs)
            model.to(device)

            try:
                model = torch.compile(model, mode="max-autotune")
                logging.info("New model compiled with torch.compile (max-autotune mode).")
            except Exception as e:
                logging.warning(f"Failed to compile new model with torch.compile: {e}")
            model.eval()
            logging.info("New model loaded successfully.")
            global_model = model
            global_tokenizer = tokenizer
            global_tokens["eos_token_id"] = global_tokenizer.eos_token_id
            global_tokens["pad_token_id"] = global_tokenizer.pad_token_id
            if global_tokens["pad_token_id"] is None and global_tokens["eos_token_id"] is not None:
                global_tokens["pad_token_id"] = global_tokens["eos_token_id"]
                if global_model.config.pad_token_id is None:
                    global_model.config.pad_token_id = global_tokens["pad_token_id"]
            elif global_tokens["pad_token_id"] is None and global_tokens["eos_token_id"] is None:
                logging.warning("Neither EOS nor PAD token defined for new model.")
            if global_model.config.pad_token_id is None and global_tokens.get("pad_token_id") is not None:
                global_model.config.pad_token_id = global_tokens["pad_token_id"]
            MODEL_NAME = new_model_name
            TRUST_REMOTE_CODE = new_trust_remote_code
            ENABLE_FLASH_ATTENTION_2 = new_enable_flash_attention_2
            TORCH_DTYPE = new_torch_dtype
            TORCH_DTYPE_STR = new_torch_dtype_str
            logging.info(f"Model successfully reloaded to: {MODEL_NAME}")
            logging.info({"status": "success", "message": f"Model {new_model_name} loaded successfully."})
        except Exception as e:
            logging.exception(f"Failed to load model {new_model_name}:")
            global_model = None
            global_tokenizer = None
            global_tokens = {}
            logging.error({"status": "error", "message": f"Failed to load model {new_model_name}: {e}. Model is now unloaded."})

    asyncio.create_task(_reload())
    return {"status": "info", "message": f"Attempting to load model {new_model_name} in background. Check logs for status."}

async def internal_unload_model():
    global global_model, global_tokenizer, global_tokens
    device = "cpu"
    logging.info("Attempting to unload model.")
    try:
        await cleanup(device)
        if global_model is not None:
            del global_model
            global_model = None
        if global_tokenizer is not None:
            del global_tokenizer
            global_tokenizer = None
        global_tokens = {}
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        gc.collect()
        logging.info("Model unloaded successfully.")
        return {"status": "success", "message": "Model unloaded successfully."}
    except Exception as e:
        logging.exception("Failed to unload model:")
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to unload model: {e}")

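# POST /generate: the main inference endpoint. With stream=true the response is
# written incrementally (the HTML client above reads it line by line); otherwise a
# single JSON payload (or plain text when return_only_text is set) is returned.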
@app.post("/generate", summary="Generate text", dependencies=[Depends(get_api_key)])
|
| 1373 |
+
async def generate_endpoint(req: GenerateRequest):
|
| 1374 |
+
if global_model is None or global_tokenizer is None:
|
| 1375 |
+
raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Model is not loaded. It may still be loading or failed to load.")
|
| 1376 |
+
device = "cpu"
|
| 1377 |
+
apply_seed(req.seed)
|
| 1378 |
+
try:
|
| 1379 |
+
initial_prompt_text = format_conversation(req.input_text, req.history, req.system_prompt)
|
| 1380 |
+
except Exception as e:
|
| 1381 |
+
logging.exception("Error formatting conversation:")
|
| 1382 |
+
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=f"Error formatting conversation: {e}")
|
| 1383 |
+
try:
|
| 1384 |
+
tokenizer_encoding_kwargs = req.tokenizer_kwargs or {}
|
| 1385 |
+
|
| 1386 |
+
encoded = global_tokenizer(initial_prompt_text, return_tensors="pt", add_special_tokens=False, **tokenizer_encoding_kwargs).to(device)
|
| 1387 |
+
initial_ids_before_trunc = encoded.input_ids
|
| 1388 |
+
initial_prompt_tokens_count_before_trunc = initial_ids_before_trunc.shape[-1]
|
| 1389 |
+
|
| 1390 |
+
ids = truncate_encoded_ids(initial_ids_before_trunc, MAX_CONTEXT_TOKENS)
|
| 1391 |
+
current_prompt_tokens_count = ids.shape[-1]
|
| 1392 |
+
|
| 1393 |
+
except Exception as e:
|
| 1394 |
+
logging.exception("Tokenizer error during encoding:")
|
| 1395 |
+
await cleanup(device)
|
| 1396 |
+
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Tokenizer encoding error: {e}")
|
| 1397 |
+
if req.tokenize_only:
|
| 1398 |
+
await cleanup(device)
|
| 1399 |
+
return JSONResponse({
|
| 1400 |
+
"prompt_tokens_count": initial_prompt_tokens_count_before_trunc,
|
| 1401 |
+
"max_context_tokens": MAX_CONTEXT_TOKENS,
|
| 1402 |
+
"truncated": initial_prompt_tokens_count_before_trunc > MAX_CONTEXT_TOKENS,
|
| 1403 |
+
"input_text_processed": initial_prompt_text,
|
| 1404 |
+
"input_ids_truncated": ids.tolist()[0]
|
| 1405 |
+
})
|
| 1406 |
+
total_capacity = MAX_CONTEXT_TOKENS + MAX_GENERATION_TOKENS
|
| 1407 |
+
total_requested_seq_len = current_prompt_tokens_count + req.max_new_tokens
|
| 1408 |
+
if not req.stream and total_requested_seq_len > total_capacity:
|
| 1409 |
+
await cleanup(device)
|
| 1410 |
+
raise HTTPException(
|
| 1411 |
+
status_code=status.HTTP_400_BAD_REQUEST,
|
| 1412 |
+
detail=f"Requested sequence length ({total_requested_seq_len} tokens = {current_prompt_tokens_count} prompt + {req.max_new_tokens} new) exceeds model capacity ({total_capacity} tokens) and non-streaming is requested. Consider enabling streaming or reducing max_new_tokens."
|
| 1413 |
+
)
|
| 1414 |
+
async with generation_semaphore:
|
| 1415 |
+
try:
|
| 1416 |
+
gen_cfg = GenerationConfig(
|
| 1417 |
+
temperature=req.temperature,
|
| 1418 |
+
top_k=req.top_k,
|
| 1419 |
+
top_p=req.top_p,
|
| 1420 |
+
repetition_penalty=req.repetition_penalty,
|
| 1421 |
+
frequency_penalty=req.frequency_penalty,
|
| 1422 |
+
presence_penalty=req.presence_penalty,
|
| 1423 |
+
num_beams=req.num_beams if not req.stream else 1,
|
| 1424 |
+
length_penalty=req.length_penalty,
|
| 1425 |
+
no_repeat_ngram_size=req.no_repeat_ngram_size,
|
| 1426 |
+
early_stopping=req.early_stopping,
|
| 1427 |
+
do_sample=req.do_sample,
|
| 1428 |
+
use_mirostat_mode=1 if req.use_mirostat else 0,
|
| 1429 |
+
mirostat_tau=req.mirostat_tau,
|
| 1430 |
+
mirostat_eta=req.mirostat_eta,
|
| 1431 |
+
max_new_tokens=req.max_new_tokens,
|
| 1432 |
+
eos_token_id=req.eos_token_id_override if req.eos_token_id_override is not None else global_tokens.get("eos_token_id"),
|
| 1433 |
+
pad_token_id=req.pad_token_id_override if req.pad_token_id_override is not None else global_tokens.get("pad_token_id"),
|
| 1434 |
+
bos_token_id=req.bos_token_id_override if req.bos_token_id_override is not None else global_tokenizer.bos_token_id,
|
| 1435 |
+
num_return_sequences=req.num_return_sequences if not req.stream else 1,
|
| 1436 |
+
bad_words_ids=req.bad_words_ids,
|
| 1437 |
+
forced_bos_token_id=req.forced_bos_token_id,
|
| 1438 |
+
forced_eos_token_id=req.forced_eos_token_id,
|
| 1439 |
+
renormalize_logits=req.renormalize_logits,
|
| 1440 |
+
suppress_tokens=req.suppress_tokens,
|
| 1441 |
+
begin_suppress_tokens=req.begin_suppress_tokens,
|
| 1442 |
+
end_suppress_tokens=req.end_suppress_tokens,
|
| 1443 |
+
encoder_no_repeat_ngram_size=req.encoder_no_repeat_ngram_size,
|
| 1444 |
+
min_length=req.min_length,
|
| 1445 |
+
max_length=req.max_length,
|
| 1446 |
+
exponential_decay_length_penalty=req.exponential_decay_length_penalty,
|
| 1447 |
+
use_cache=req.use_cache,
|
| 1448 |
+
typical_p=req.typical_p,
|
| 1449 |
+
epsilon_cutoff=req.epsilon_cutoff,
|
| 1450 |
+
eta_cutoff=req.eta_cutoff,
|
| 1451 |
+
temperature_cutoff=req.temperature_cutoff,
|
| 1452 |
+
encoder_repetition_penalty=req.encoder_repetition_penalty,
|
| 1453 |
+
max_time=req.max_time,
|
| 1454 |
+
output_watermark=req.output_watermark,
|
| 1455 |
+
diversity_penalty=req.diversity_penalty,
|
| 1456 |
+
num_beam_groups=req.num_beam_groups if not req.stream else 1,
|
| 1457 |
+
length_normalization_factor=req.length_normalization_factor,
|
| 1458 |
+
min_new_tokens=req.min_new_tokens,
|
| 1459 |
+
do_normalize_logits=req.do_normalize_logits,
|
| 1460 |
+
output_scores=req.output_scores,
|
| 1461 |
+
output_attentions=req.output_attentions,
|
| 1462 |
+
output_hidden_states=req.output_hidden_states,
|
| 1463 |
+
)
|
| 1464 |
+
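            # Note: several of the kwargs above (use_mirostat_mode, frequency_penalty,
            # presence_penalty, temperature_cutoff, ...) are not standard Hugging Face
            # GenerationConfig fields; GenerationConfig keeps unknown kwargs as extra
            # attributes, so the generation helpers must interpret them themselves.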
            if req.stream:
                gen_cfg.use_cache = True
                gen_cfg.num_beams = 1
                gen_cfg.num_return_sequences = 1
                gen_cfg.num_beam_groups = 1
                return StreamingResponse(stream_generation_logic(req, ids, gen_cfg, device), media_type="text/plain" if req.return_only_text else "application/json")
            else:
                response_payload = await non_stream_generation_logic(req, ids, gen_cfg, device)
                if req.return_only_text:
                    texts = [seq["text"] for seq in response_payload.get("generated_sequences", []) if seq.get("text") is not None]
                    if req.num_return_sequences == 1 and texts:
                        return PlainTextResponse(texts[0])
                    else:
                        return JSONResponse(texts)
                else:
                    return JSONResponse(response_payload)
        except Exception as e:
            logging.exception("Generation error:")
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Generation error: {e}")
        finally:
            await cleanup(device)

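# Example request (illustrative only): stream a short completion from a local run of
# this app. The field names mirror the payload built by the HTML client above; the
# authoritative schema is the GenerateRequest model defined earlier in the file.
#
#   curl -N -X POST http://localhost:7860/generate \
#     -H "Content-Type: application/json" \
#     -d '{"input_text": "Hola", "history": [], "stream": true, "max_new_tokens": 64}'
#
# If API-key auth is enabled, also pass: -H "X-API-Key: <key>".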
if __name__ == "__main__":
|
| 1487 |
+
uvicorn.run(
|
| 1488 |
+
app, host="0.0.0.0", port=7860,
|
| 1489 |
+
log_level="critical",
|
| 1490 |
+
access_log=False
|
| 1491 |
+
)
|