|
|
from fastapi import FastAPI, Request, HTTPException |
|
|
from fastapi.responses import JSONResponse, HTMLResponse |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import torch |
|
|
import os |
|
|
import logging |
|
|
import uvicorn |
|
|
|
|
|
|
|
|
# Root logging config: INFO and above go to stderr (uvicorn inherits this).
logging.basicConfig(level=logging.INFO)

# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
# FastAPI application object; all endpoint handlers below register onto it.
# title/description/version surface in the auto-generated /docs UI.
app = FastAPI(
    title="Trigger Chatbot API",
    description="Chatbot API using TinyLlama-1.1B-Chat model",
    version="1.0",
)
|
|
|
|
|
|
|
|
# Optional URL prefix injected by the hosting platform (e.g. a Hugging Face
# Space mount point). Trailing slashes are dropped so later f-string
# concatenation like f"{BASE_PATH}/ai" never produces a double slash.
BASE_PATH = os.environ.get("SPACE_APP_PATH", "").rstrip("/")
logger.info("Using base path: '%s'", BASE_PATH)
|
|
|
|
|
|
|
|
# Load the TinyLlama chat model and tokenizer once at import time so the
# first request does not pay the download/initialization cost. A failure
# here is fatal: the service is useless without the model.
try:
    logger.info("Loading TinyLlama tokenizer and model...")
    _MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_ID)
    # float16 halves the memory footprint relative to the default fp32.
    model = AutoModelForCausalLM.from_pretrained(
        _MODEL_ID,
        torch_dtype=torch.float16,
    )
    model.eval()  # inference mode: disables dropout etc.
    logger.info("Model loaded successfully!")
except Exception as e:
    logger.error(f"Model loading failed: {str(e)}")
    raise RuntimeError("Model initialization failed") from e
|
|
|
|
|
|
|
|
# Per-user conversation memory: user_id -> list of (question, answer) pairs,
# capped to the 5 most recent turns by the /ai handler. In-process only —
# lost on restart and not shared across multiple workers.
chat_history: dict[str, list[tuple[str, str]]] = {}
|
|
|
|
|
|
|
|
@app.middleware("http")
async def add_base_path(request: Request, call_next):
    """Strip the deployment base path prefix from incoming request paths.

    Fixes over the naive ``startswith`` check:
    - only strips on a path-segment boundary, so BASE_PATH="/app" no longer
      mangles a request to "/application";
    - normalizes the stripped result to "/" when the request path equals
      BASE_PATH exactly — an empty path would never match any route.
    """
    path = request.scope["path"]
    if BASE_PATH and (path == BASE_PATH or path.startswith(BASE_PATH + "/")):
        request.scope["path"] = path[len(BASE_PATH):] or "/"
    return await call_next(request)
|
|
|
|
|
@app.get("/")
async def root():
    """Landing endpoint: confirms liveness and advertises the API surface."""
    endpoints = {
        "chat": f"{BASE_PATH}/ai?query=Hello&user_id=yourname",
        "health": f"{BASE_PATH}/health",
        "reset": f"{BASE_PATH}/reset?user_id=yourname",
        "test": f"{BASE_PATH}/test",
        "docs": f"{BASE_PATH}/docs",
    }
    return {
        "message": "🟢 Trigger API is running",
        "endpoints": endpoints,
    }
|
|
|
|
|
@app.get("/ai")
async def chat(request: Request):
    """Generate a chat reply from the TinyLlama model.

    Query params:
        query:   the user's message (required, max 200 characters).
        user_id: key for per-user conversation memory (default "default").

    Returns:
        {"reply": <model response>}

    Raises:
        HTTPException 400 for a missing or over-long query,
        HTTPException 500 for generation failures.
    """
    # Pre-bind so the OOM handler below can reference it even if an
    # exception fires before the query-param parsing completes.
    user_id = "default"
    try:
        user_input = request.query_params.get("query", "").strip()
        user_id = request.query_params.get("user_id", "default").strip()

        if not user_input:
            raise HTTPException(status_code=400, detail="Missing 'query'")
        if len(user_input) > 200:
            raise HTTPException(status_code=400, detail="Query too long (max 200 characters)")

        # Rebuild the prompt from this user's recent turns (at most 5).
        memory = chat_history.get(user_id, [])
        prompt = "You are a friendly, funny AI assistant called Trigger.\n\n"
        for q, a in memory:
            prompt += f"User: {q}\nTrigger: {a}\n"
        prompt += f"User: {user_input}\nTrigger:"

        # Forward the attention mask alongside input_ids so generate() does
        # not have to infer padding; run under no_grad so no autograd state
        # is allocated during inference.
        inputs = tokenizer(prompt, return_tensors="pt")
        with torch.no_grad():
            output = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,
                max_new_tokens=128,
                pad_token_id=tokenizer.eos_token_id,
                # do_sample=True is required for temperature/top_k/top_p to
                # take effect; without it generation is greedy and these
                # knobs are silently ignored.
                do_sample=True,
                temperature=0.8,
                top_k=50,
                top_p=0.95,
            )
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        # Keep only the first line of the continuation so the model cannot
        # fabricate additional "User:" turns inside a single reply.
        response = generated_text[len(prompt):].strip().split("\n")[0]

        # Remember this exchange, capped to the 5 most recent turns.
        memory.append((user_input, response))
        chat_history[user_id] = memory[-5:]

        return {"reply": response}

    except HTTPException:
        # Re-raise client errors unchanged — the generic handler below would
        # otherwise convert a deliberate 400 into a misleading 500.
        raise
    except torch.cuda.OutOfMemoryError:
        logger.error("CUDA out of memory error")
        # Drop this user's history: a shorter prompt next time may fit.
        if user_id in chat_history:
            del chat_history[user_id]
        raise HTTPException(status_code=500, detail="Memory error. Try again.")
    except Exception as e:
        logger.error(f"Processing error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
|
|
|
|
|
@app.get("/health")
async def health():
    """Liveness probe: reports model identity and basic runtime stats."""
    report = {"status": "healthy"}
    report["model"] = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    report["users"] = len(chat_history)
    report["base_path"] = BASE_PATH
    return report
|
|
|
|
|
@app.get("/reset")
async def reset_history(user_id: str = "default"):
    """Forget the stored conversation history for one user.

    Succeeds whether or not any history existed for *user_id*.
    """
    chat_history.pop(user_id, None)
    return {"status": "success", "message": f"History cleared for user {user_id}"}
|
|
|
|
|
@app.get("/test", response_class=HTMLResponse)
async def test_page():
    """Render a minimal HTML page linking to every endpoint.

    Handy for manually smoke-testing a deployment behind BASE_PATH.
    """
    return f"""
    <html>
    <body>
        <h1>Trigger Chatbot Test</h1>
        <p>Base path: {BASE_PATH}</p>
        <ul>
            <li><a href="{BASE_PATH}/">Root endpoint</a></li>
            <li><a href="{BASE_PATH}/ai?query=Hello&user_id=test">Chat endpoint</a></li>
            <li><a href="{BASE_PATH}/health">Health check</a></li>
            <li><a href="{BASE_PATH}/docs">API Docs</a></li>
        </ul>
    </body>
    </html>
    """
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Dev entry point. NOTE(review): the "app:app" import string assumes this
    # file is named app.py — confirm, since reload=True re-imports the app by
    # that string path rather than using the already-constructed object.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info", reload=True)