"""Trigger Chatbot API.

FastAPI service that wraps the TinyLlama-1.1B-Chat model and keeps a small
per-user in-memory chat history. Designed to run on Hugging Face Spaces
(honors the SPACE_APP_PATH base-path prefix).
"""

from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import logging
import uvicorn

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI
app = FastAPI(
    title="Trigger Chatbot API",
    description="Chatbot API using TinyLlama-1.1B-Chat model",
    version="1.0",
)

# Get base path from environment (for Hugging Face Spaces)
BASE_PATH = os.getenv("SPACE_APP_PATH", "").rstrip("/")
logger.info(f"Using base path: '{BASE_PATH}'")

# Load model and tokenizer once at import time; fail fast if unavailable.
try:
    logger.info("Loading TinyLlama tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    model = AutoModelForCausalLM.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        torch_dtype=torch.float16,  # Reduces RAM usage
    )
    model.eval()
    logger.info("Model loaded successfully!")
except Exception as e:
    logger.error(f"Model loading failed: {str(e)}")
    raise RuntimeError("Model initialization failed") from e

# In-memory chat memory: user_id -> list of (question, answer) tuples.
# NOTE(review): not persisted and not safe across multiple worker processes.
chat_history = {}


# Middleware for base path: strip the Spaces prefix so routes match.
@app.middleware("http")
async def add_base_path(request: Request, call_next):
    path = request.scope["path"]
    if BASE_PATH and path.startswith(BASE_PATH):
        request.scope["path"] = path[len(BASE_PATH):]
    return await call_next(request)


@app.get("/")
async def root():
    """Landing endpoint listing the available routes (prefixed with BASE_PATH)."""
    return {
        "message": "🟢 Trigger API is running",
        "endpoints": {
            "chat": f"{BASE_PATH}/ai?query=Hello&user_id=yourname",
            "health": f"{BASE_PATH}/health",
            "reset": f"{BASE_PATH}/reset?user_id=yourname",
            "test": f"{BASE_PATH}/test",
            "docs": f"{BASE_PATH}/docs",
        },
    }


@app.get("/ai")
async def chat(request: Request):
    """Generate a chat reply for ?query=...&user_id=...

    Keeps the last 5 (question, answer) exchanges per user and prepends them
    to the prompt so the model has short-term conversational memory.
    """
    # Bound before the try so the OOM handler below can always reference it.
    user_id = request.query_params.get("user_id", "default").strip()
    try:
        user_input = request.query_params.get("query", "").strip()
        if not user_input:
            raise HTTPException(status_code=400, detail="Missing 'query'")
        if len(user_input) > 200:
            raise HTTPException(status_code=400, detail="Query too long (max 200 characters)")

        # Prompt style: natural chat history
        memory = chat_history.get(user_id, [])
        prompt = "You are a friendly, funny AI assistant called Trigger.\n\n"
        for q, a in memory:
            prompt += f"User: {q}\nTrigger: {a}\n"
        prompt += f"User: {user_input}\nTrigger:"

        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        # no_grad: inference only — avoids building an autograd graph.
        with torch.no_grad():
            output = model.generate(
                input_ids,
                max_new_tokens=128,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True,  # FIX: temperature/top_k/top_p are ignored in greedy mode
                temperature=0.8,
                top_k=50,
                top_p=0.95,
            )
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        # Take only the first line after the prompt as the assistant's turn.
        response = generated_text[len(prompt):].strip().split("\n")[0]

        # Save history (limit to last 5 exchanges)
        memory.append((user_input, response))
        chat_history[user_id] = memory[-5:]

        return {"reply": response}
    except HTTPException:
        # FIX: re-raise validation errors as-is instead of letting the
        # generic handler below convert 400s into 500s.
        raise
    except torch.cuda.OutOfMemoryError:
        logger.error("CUDA out of memory error")
        if user_id in chat_history:
            del chat_history[user_id]
        raise HTTPException(status_code=500, detail="Memory error. Try again.")
    except Exception as e:
        logger.error(f"Processing error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")


@app.get("/health")
async def health():
    """Liveness probe: reports model id, tracked users, and base path."""
    return {
        "status": "healthy",
        "model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        "users": len(chat_history),
        "base_path": BASE_PATH,
    }


@app.get("/reset")
async def reset_history(user_id: str = "default"):
    """Clear the stored conversation history for one user (idempotent)."""
    if user_id in chat_history:
        del chat_history[user_id]
    return {"status": "success", "message": f"History cleared for user {user_id}"}


@app.get("/test", response_class=HTMLResponse)
async def test_page():
    """Minimal HTML page showing the configured base path."""
    return f"""
Base path: {BASE_PATH}
"""


# Run locally
if __name__ == "__main__":
    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info", reload=True)