Update app.py
Browse files
app.py
CHANGED
|
@@ -12,8 +12,8 @@ logger = logging.getLogger(__name__)
|
|
| 12 |
|
| 13 |
# Initialize FastAPI
|
| 14 |
app = FastAPI(
|
| 15 |
-
title="
|
| 16 |
-
description="Chatbot API using
|
| 17 |
version="1.0",
|
| 18 |
)
|
| 19 |
|
|
@@ -23,9 +23,12 @@ logger.info(f"Using base path: '{BASE_PATH}'")
|
|
| 23 |
|
| 24 |
# Load model and tokenizer
|
| 25 |
try:
|
| 26 |
-
logger.info("Loading tokenizer and model...")
|
| 27 |
-
tokenizer = AutoTokenizer.from_pretrained("
|
| 28 |
-
model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
|
|
|
|
|
|
| 29 |
model.eval()
|
| 30 |
logger.info("Model loaded successfully!")
|
| 31 |
except Exception as e:
|
|
@@ -46,7 +49,7 @@ async def add_base_path(request: Request, call_next):
|
|
| 46 |
@app.get("/")
|
| 47 |
async def root():
|
| 48 |
return {
|
| 49 |
-
"message": "🟢
|
| 50 |
"endpoints": {
|
| 51 |
"chat": f"{BASE_PATH}/ai?query=Hello&user_id=yourname",
|
| 52 |
"health": f"{BASE_PATH}/health",
|
|
@@ -67,7 +70,7 @@ async def chat(request: Request):
|
|
| 67 |
if len(user_input) > 200:
|
| 68 |
raise HTTPException(status_code=400, detail="Query too long (max 200 characters)")
|
| 69 |
|
| 70 |
-
# Prompt style:
|
| 71 |
memory = chat_history.get(user_id, [])
|
| 72 |
prompt = "You are a friendly, funny AI assistant called Trigger.\n\n"
|
| 73 |
for q, a in memory:
|
|
@@ -105,7 +108,7 @@ async def chat(request: Request):
|
|
| 105 |
async def health():
|
| 106 |
return {
|
| 107 |
"status": "healthy",
|
| 108 |
-
"model": "
|
| 109 |
"users": len(chat_history),
|
| 110 |
"base_path": BASE_PATH
|
| 111 |
}
|
|
@@ -121,7 +124,7 @@ async def test_page():
|
|
| 121 |
return f"""
|
| 122 |
<html>
|
| 123 |
<body>
|
| 124 |
-
<h1>
|
| 125 |
<p>Base path: {BASE_PATH}</p>
|
| 126 |
<ul>
|
| 127 |
<li><a href="{BASE_PATH}/">Root endpoint</a></li>
|
|
|
|
| 12 |
|
| 13 |
# Initialize FastAPI
|
| 14 |
app = FastAPI(
|
| 15 |
+
title="Trigger Chatbot API",
|
| 16 |
+
description="Chatbot API using TinyLlama-1.1B-Chat model",
|
| 17 |
version="1.0",
|
| 18 |
)
|
| 19 |
|
|
|
|
| 23 |
|
| 24 |
# Load model and tokenizer
|
| 25 |
try:
|
| 26 |
+
logger.info("Loading TinyLlama tokenizer and model...")
|
| 27 |
+
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
|
| 28 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 29 |
+
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
| 30 |
+
torch_dtype=torch.float16 # Reduces RAM usage
|
| 31 |
+
)
|
| 32 |
model.eval()
|
| 33 |
logger.info("Model loaded successfully!")
|
| 34 |
except Exception as e:
|
|
|
|
| 49 |
@app.get("/")
|
| 50 |
async def root():
|
| 51 |
return {
|
| 52 |
+
"message": "🟢 Trigger API is running",
|
| 53 |
"endpoints": {
|
| 54 |
"chat": f"{BASE_PATH}/ai?query=Hello&user_id=yourname",
|
| 55 |
"health": f"{BASE_PATH}/health",
|
|
|
|
| 70 |
if len(user_input) > 200:
|
| 71 |
raise HTTPException(status_code=400, detail="Query too long (max 200 characters)")
|
| 72 |
|
| 73 |
+
# Prompt style: natural chat history
|
| 74 |
memory = chat_history.get(user_id, [])
|
| 75 |
prompt = "You are a friendly, funny AI assistant called Trigger.\n\n"
|
| 76 |
for q, a in memory:
|
|
|
|
| 108 |
async def health():
|
| 109 |
return {
|
| 110 |
"status": "healthy",
|
| 111 |
+
"model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
|
| 112 |
"users": len(chat_history),
|
| 113 |
"base_path": BASE_PATH
|
| 114 |
}
|
|
|
|
| 124 |
return f"""
|
| 125 |
<html>
|
| 126 |
<body>
|
| 127 |
+
<h1>Trigger Chatbot Test</h1>
|
| 128 |
<p>Base path: {BASE_PATH}</p>
|
| 129 |
<ul>
|
| 130 |
<li><a href="{BASE_PATH}/">Root endpoint</a></li>
|