Update app.py
Browse files
app.py
CHANGED
|
@@ -12,6 +12,7 @@ from contextlib import asynccontextmanager
|
|
| 12 |
from fastapi import FastAPI, HTTPException
|
| 13 |
from fastapi.responses import StreamingResponse
|
| 14 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
|
| 15 |
from pydantic import BaseModel
|
| 16 |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
|
| 17 |
from threading import Thread
|
|
@@ -19,12 +20,13 @@ import torch
|
|
| 19 |
|
| 20 |
# Model configuration
|
| 21 |
MODEL_NAME = "Nanbeige/Nanbeige4.1-3B"
|
| 22 |
-
MAX_LENGTH =
|
| 23 |
|
| 24 |
# Global model and tokenizer
|
| 25 |
model = None
|
| 26 |
tokenizer = None
|
| 27 |
|
|
|
|
| 28 |
|
| 29 |
class Message(BaseModel):
|
| 30 |
role: str
|
|
@@ -147,6 +149,12 @@ async def root():
|
|
| 147 |
}
|
| 148 |
|
| 149 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
@app.get("/health")
|
| 151 |
async def health():
|
| 152 |
"""Health check endpoint."""
|
|
|
|
| 12 |
from fastapi import FastAPI, HTTPException
|
| 13 |
from fastapi.responses import StreamingResponse
|
| 14 |
from fastapi.middleware.cors import CORSMiddleware
|
| 15 |
+
from fastapi.responses import FileResponse
|
| 16 |
from pydantic import BaseModel
|
| 17 |
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
|
| 18 |
from threading import Thread
|
|
|
|
| 20 |
|
| 21 |
# Model configuration
|
| 22 |
MODEL_NAME = "Nanbeige/Nanbeige4.1-3B"
|
| 23 |
+
MAX_LENGTH = 32768
|
| 24 |
|
| 25 |
# Global model and tokenizer
|
| 26 |
model = None
|
| 27 |
tokenizer = None
|
| 28 |
|
| 29 |
+
BASE_DIR = os.path.dirname(os.path.abspath(__file__))
|
| 30 |
|
| 31 |
class Message(BaseModel):
|
| 32 |
role: str
|
|
|
|
| 149 |
}
|
| 150 |
|
| 151 |
|
| 152 |
+
@app.get("/index", response_class=FileResponse)
|
| 153 |
+
async def serve_chat():
|
| 154 |
+
"""Serve index.html as the chat page."""
|
| 155 |
+
return FileResponse(os.path.join(BASE_DIR, "index.html"))
|
| 156 |
+
|
| 157 |
+
|
| 158 |
@app.get("/health")
|
| 159 |
async def health():
|
| 160 |
"""Health check endpoint."""
|