# app.py — Trigger Chatbot API (Hugging Face Space by Trigger82, commit 955e737)
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import logging
import uvicorn
# Module-level logger so all handlers share one configuration.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application object; metadata shows up in the generated /docs UI.
app = FastAPI(
    title="Trigger Chatbot API",
    description="Chatbot API using TinyLlama-1.1B-Chat model",
    version="1.0",
)

# Base path under which the app is mounted (set by Hugging Face Spaces);
# trailing slashes are stripped so f"{BASE_PATH}/x" never doubles them.
BASE_PATH = os.getenv("SPACE_APP_PATH", "").rstrip("/")
logger.info(f"Using base path: '{BASE_PATH}'")
# Load tokenizer and model once at import time; fail fast if they can't load.
try:
    logger.info("Loading TinyLlama tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    model = AutoModelForCausalLM.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        torch_dtype=torch.float16,  # half-precision weights: roughly halves RAM
    )
    model.eval()  # inference mode: disables dropout etc.
    logger.info("Model loaded successfully!")
except Exception as e:
    logger.error(f"Model loading failed: {str(e)}")
    raise RuntimeError("Model initialization failed") from e

# In-memory per-user history: user_id -> list of (question, answer) pairs.
# NOTE(review): not persisted and not shared across workers — confirm this
# is acceptable for the deployment.
chat_history = {}
# Middleware for base path
# Middleware for base path
@app.middleware("http")
async def add_base_path(request: Request, call_next):
    """Strip the Space's base path prefix from incoming requests.

    Allows the routes below to be declared without the mount prefix when
    the app is served under a sub-path (e.g. on Hugging Face Spaces).
    """
    path = request.scope["path"]
    if BASE_PATH and path.startswith(BASE_PATH):
        # Fall back to "/" when the request targets the base path itself:
        # an empty path would otherwise match no route and 404.
        request.scope["path"] = path[len(BASE_PATH):] or "/"
    return await call_next(request)
@app.get("/")
async def root():
    """Landing endpoint: reports liveness and lists the available routes."""
    endpoints = {
        "chat": f"{BASE_PATH}/ai?query=Hello&user_id=yourname",
        "health": f"{BASE_PATH}/health",
        "reset": f"{BASE_PATH}/reset?user_id=yourname",
        "test": f"{BASE_PATH}/test",
        "docs": f"{BASE_PATH}/docs",
    }
    return {"message": "🟢 Trigger API is running", "endpoints": endpoints}
@app.get("/ai")
async def chat(request: Request):
    """Generate a reply to `query`, keeping short per-user chat history.

    Query params:
        query:   the user's message (required, max 200 characters).
        user_id: key for the conversation history (defaults to "default").

    Returns:
        {"reply": <first line of the model's continuation>}

    Raises:
        HTTPException 400 for missing/too-long input, 500 on generation errors.
    """
    # Bound before the try so the except handlers below can always use it.
    user_id = "default"
    try:
        user_input = request.query_params.get("query", "").strip()
        user_id = request.query_params.get("user_id", "default").strip() or "default"
        if not user_input:
            raise HTTPException(status_code=400, detail="Missing 'query'")
        if len(user_input) > 200:
            raise HTTPException(status_code=400, detail="Query too long (max 200 characters)")

        # Rebuild a plain-text transcript from the stored history.
        memory = chat_history.get(user_id, [])
        prompt = "You are a friendly, funny AI assistant called Trigger.\n\n"
        for q, a in memory:
            prompt += f"User: {q}\nTrigger: {a}\n"
        prompt += f"User: {user_input}\nTrigger:"

        inputs = tokenizer(prompt, return_tensors="pt")
        # no_grad: inference only — avoids accumulating autograd buffers.
        with torch.no_grad():
            output = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,  # avoids the missing-mask warning
                max_new_tokens=128,
                do_sample=True,  # required: temperature/top_k/top_p are ignored under greedy decoding
                pad_token_id=tokenizer.eos_token_id,
                temperature=0.8,
                top_k=50,
                top_p=0.95,
            )
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        # Keep only the first line of the continuation — the model tends to
        # invent further "User:" turns after its answer.
        response = generated_text[len(prompt):].strip().split("\n")[0]

        # Save history (limit to last 5 exchanges)
        memory.append((user_input, response))
        chat_history[user_id] = memory[-5:]
        return {"reply": response}
    except HTTPException:
        # Re-raise unchanged: otherwise the generic handler below would
        # convert the 400 responses above into 500s.
        raise
    except torch.cuda.OutOfMemoryError:
        logger.error("CUDA out of memory error")
        chat_history.pop(user_id, None)  # free this user's history to recover
        raise HTTPException(status_code=500, detail="Memory error. Try again.")
    except Exception as e:
        logger.exception("Processing error")  # includes the traceback in the log
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
@app.get("/health")
async def health():
    """Liveness probe: reports model name, active user count and base path."""
    return {
        "status": "healthy",
        "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "users": len(chat_history),
        "base_path": BASE_PATH,
    }
@app.get("/reset")
async def reset_history(user_id: str = "default"):
    """Drop the stored conversation for `user_id` (no-op if none exists)."""
    chat_history.pop(user_id, None)
    return {"status": "success", "message": f"History cleared for user {user_id}"}
@app.get("/test", response_class=HTMLResponse)
async def test_page():
    # Human-browsable smoke-test page: links to each endpoint, with
    # BASE_PATH baked in so they work when the app is mounted on a sub-path.
    return f"""
<html>
<body>
<h1>Trigger Chatbot Test</h1>
<p>Base path: {BASE_PATH}</p>
<ul>
<li><a href="{BASE_PATH}/">Root endpoint</a></li>
<li><a href="{BASE_PATH}/ai?query=Hello&user_id=test">Chat endpoint</a></li>
<li><a href="{BASE_PATH}/health">Health check</a></li>
<li><a href="{BASE_PATH}/docs">API Docs</a></li>
</ul>
</body>
</html>
"""
# Run locally
if __name__ == "__main__":
    # 7860 is the port Hugging Face Spaces expects the app to listen on.
    # NOTE(review): reload=True is a dev convenience (auto-restarts on file
    # change) — presumably not intended for production; confirm before deploy.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info", reload=True)