File size: 4,748 Bytes
034af7a 81cbe70 7634ce7 034af7a a2bac2c 81cbe70 034af7a 4311917 81cbe70 955e737 81cbe70 034af7a 926fadf 81cbe70 034af7a 955e737 81cbe70 034af7a 81cbe70 034af7a 81cbe70 034af7a 81cbe70 034af7a 81cbe70 034af7a a3a01d7 926fadf 034af7a 955e737 81cbe70 034af7a 81cbe70 034af7a 81cbe70 955e737 81cbe70 a3a01d7 81cbe70 034af7a 81cbe70 926fadf 81cbe70 926fadf 81cbe70 034af7a 81cbe70 034af7a 81cbe70 a3a01d7 034af7a 81cbe70 955e737 81cbe70 034af7a a3a01d7 81cbe70 034af7a 81cbe70 034af7a 955e737 81cbe70 034af7a a3a01d7 81cbe70 a2bac2c 81cbe70 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import logging
import uvicorn
# Configure logging: module-level logger used by every handler below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Initialize FastAPI application with OpenAPI metadata (shown at /docs).
app = FastAPI(
title="Trigger Chatbot API",
description="Chatbot API using TinyLlama-1.1B-Chat model",
version="1.0",
)
# Get base path from environment (for Hugging Face Spaces). Trailing slashes
# are stripped so prefix comparisons in the middleware are exact; empty string
# means the app is mounted at the root.
BASE_PATH = os.getenv("SPACE_APP_PATH", "").rstrip("/")
logger.info(f"Using base path: '{BASE_PATH}'")
# Load model and tokenizer at import time so the first request is not stalled
# by the (large) model download/initialization.
try:
logger.info("Loading TinyLlama tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = AutoModelForCausalLM.from_pretrained(
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
torch_dtype=torch.float16  # Reduces RAM usage
)
model.eval()  # inference mode (disables dropout); no training happens here
logger.info("Model loaded successfully!")
except Exception as e:
logger.error(f"Model loading failed: {str(e)}")
# Fail fast: the service is unusable without the model, so abort startup.
raise RuntimeError("Model initialization failed") from e
# In-memory chat memory. NOTE(review): process-local and unbounded across
# users; lost on restart. Keys appear to be user_id strings (see /ai handler).
chat_history = {}
# Middleware for base path
@app.middleware("http")
async def add_base_path(request: Request, call_next):
    """Strip the configured ``BASE_PATH`` prefix from incoming request paths.

    Lets the app serve correctly when mounted under a sub-path (e.g. on
    Hugging Face Spaces) while route decorators stay prefix-free.
    """
    path = request.scope["path"]
    if BASE_PATH and path.startswith(BASE_PATH):
        # Fall back to "/" when stripping leaves an empty string (a request to
        # the bare prefix) — an ASGI path must be non-empty.
        request.scope["path"] = path[len(BASE_PATH):] or "/"
    return await call_next(request)
@app.get("/")
async def root():
    """Landing endpoint: confirms the API is up and lists available routes."""
    relative_routes = {
        "chat": "/ai?query=Hello&user_id=yourname",
        "health": "/health",
        "reset": "/reset?user_id=yourname",
        "test": "/test",
        "docs": "/docs",
    }
    return {
        "message": "🟢 Trigger API is running",
        "endpoints": {
            label: f"{BASE_PATH}{route}"
            for label, route in relative_routes.items()
        },
    }
@app.get("/ai")
async def chat(request: Request):
    """Generate a chat reply for ``query``, keeping short per-user history.

    Query params:
        query:   the user's message (required, max 200 chars).
        user_id: history key, defaults to "default".

    Returns ``{"reply": str}``. Raises HTTPException 400 on bad input and
    500 on generation failures.
    """
    # Resolved before the try block so the OOM handler below can always
    # reference it (previously it could be unbound if an error fired early).
    user_id = request.query_params.get("user_id", "default").strip()
    try:
        user_input = request.query_params.get("query", "").strip()
        if not user_input:
            raise HTTPException(status_code=400, detail="Missing 'query'")
        if len(user_input) > 200:
            raise HTTPException(status_code=400, detail="Query too long (max 200 characters)")
        # Prompt style: natural chat history — persona line, then alternating
        # User/Trigger turns, ending with an open "Trigger:" for the model.
        memory = chat_history.get(user_id, [])
        prompt = "You are a friendly, funny AI assistant called Trigger.\n\n"
        for q, a in memory:
            prompt += f"User: {q}\nTrigger: {a}\n"
        prompt += f"User: {user_input}\nTrigger:"
        input_ids = tokenizer(prompt, return_tensors="pt").input_ids
        output = model.generate(
            input_ids,
            max_new_tokens=128,
            pad_token_id=tokenizer.eos_token_id,
            # do_sample=True is required — without it temperature/top_k/top_p
            # are silently ignored and decoding is greedy.
            do_sample=True,
            temperature=0.8,
            top_k=50,
            top_p=0.95,
        )
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        # Keep only the first line after the prompt: the model continues the
        # transcript, so later lines would be hallucinated "User:" turns.
        response = generated_text[len(prompt):].strip().split("\n")[0]
        # Save history (limit to last 5 exchanges per user)
        memory.append((user_input, response))
        chat_history[user_id] = memory[-5:]
        return {"reply": response}
    except HTTPException:
        # Re-raise client errors unchanged; the broad handler below would
        # otherwise convert our 400s into 500s.
        raise
    except torch.cuda.OutOfMemoryError:
        logger.error("CUDA out of memory error")
        # Drop this user's history to shrink the next prompt.
        chat_history.pop(user_id, None)
        raise HTTPException(status_code=500, detail="Memory error. Try again.")
    except Exception as e:
        logger.error(f"Processing error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")
@app.get("/health")
async def health():
    """Liveness probe: reports model identity and active-user count."""
    report = {"status": "healthy"}
    report["model"] = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    report["users"] = len(chat_history)
    report["base_path"] = BASE_PATH
    return report
@app.get("/reset")
async def reset_history(user_id: str = "default"):
    """Clear the stored conversation for ``user_id`` (no-op if none exists)."""
    chat_history.pop(user_id, None)
    return {"status": "success", "message": f"History cleared for user {user_id}"}
@app.get("/test", response_class=HTMLResponse)
async def test_page():
    # Minimal static HTML page with clickable links to every endpoint,
    # prefixed with BASE_PATH so links work when mounted under a sub-path.
    return f"""
    <html>
        <body>
            <h1>Trigger Chatbot Test</h1>
            <p>Base path: {BASE_PATH}</p>
            <ul>
                <li><a href="{BASE_PATH}/">Root endpoint</a></li>
                <li><a href="{BASE_PATH}/ai?query=Hello&user_id=test">Chat endpoint</a></li>
                <li><a href="{BASE_PATH}/health">Health check</a></li>
                <li><a href="{BASE_PATH}/docs">API Docs</a></li>
            </ul>
        </body>
    </html>
    """
# Run locally (removed a stray trailing "|" artifact that broke the syntax).
if __name__ == "__main__":
    # The import-string form "app:app" (rather than the app object) is
    # required for reload=True to work.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info", reload=True)