File size: 4,748 Bytes
034af7a
 
81cbe70
7634ce7
034af7a
 
a2bac2c
 
81cbe70
034af7a
 
4311917
81cbe70
 
955e737
 
81cbe70
 
 
 
034af7a
 
926fadf
81cbe70
034af7a
955e737
 
 
 
 
 
81cbe70
 
034af7a
81cbe70
 
034af7a
81cbe70
 
034af7a
81cbe70
034af7a
81cbe70
034af7a
 
 
 
a3a01d7
926fadf
034af7a
 
955e737
81cbe70
 
 
 
 
 
 
034af7a
 
 
81cbe70
034af7a
81cbe70
 
 
 
 
 
 
 
955e737
81cbe70
 
 
 
 
 
 
a3a01d7
81cbe70
 
 
034af7a
81cbe70
 
926fadf
81cbe70
 
926fadf
81cbe70
 
 
034af7a
 
 
81cbe70
 
 
 
 
034af7a
81cbe70
 
a3a01d7
034af7a
 
 
81cbe70
955e737
81cbe70
034af7a
 
a3a01d7
81cbe70
 
 
 
 
 
034af7a
81cbe70
034af7a
 
 
955e737
81cbe70
 
 
 
 
 
 
034af7a
 
 
a3a01d7
81cbe70
a2bac2c
81cbe70
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from fastapi import FastAPI, Request, HTTPException
from fastapi.responses import JSONResponse, HTMLResponse
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import os
import logging
import uvicorn

# Configure logging: module-level logger, INFO and above to stderr
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize FastAPI application (title/description/version show up in /docs)
app = FastAPI(
    title="Trigger Chatbot API",
    description="Chatbot API using TinyLlama-1.1B-Chat model",
    version="1.0",
)

# Get base path from environment (for Hugging Face Spaces).
# Trailing "/" is stripped so f"{BASE_PATH}/route" never doubles slashes;
# an unset variable yields "" (app served at the root).
BASE_PATH = os.getenv("SPACE_APP_PATH", "").rstrip("/")
logger.info(f"Using base path: '{BASE_PATH}'")

# Load model and tokenizer once at import time, so request handlers never
# pay the (multi-second) load cost. Downloads weights on first run.
try:
    logger.info("Loading TinyLlama tokenizer and model...")
    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
    model = AutoModelForCausalLM.from_pretrained(
        "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
        torch_dtype=torch.float16  # Reduces RAM usage
    )
    model.eval()  # inference mode: disables dropout and similar train-time layers
    logger.info("Model loaded successfully!")
except Exception as e:
    logger.error(f"Model loading failed: {str(e)}")
    # Fail fast at startup: the service is useless without the model.
    raise RuntimeError("Model initialization failed") from e

# In-memory chat memory: user_id -> list of (user_message, bot_reply) tuples.
# NOTE(review): process-local and lost on restart; unbounded in user count
# (each user's list is capped at 5 in chat()) — acceptable for a demo.
chat_history = {}

# Middleware for base path
@app.middleware("http")
async def add_base_path(request: Request, call_next):
    """Strip the deployment base-path prefix before routing.

    Hugging Face Spaces may serve the app under a sub-path (SPACE_APP_PATH),
    but routes here are declared without it, so the prefix is removed from
    the ASGI scope before the router sees the request.
    """
    path = request.scope["path"]
    if BASE_PATH and path.startswith(BASE_PATH):
        # A request to exactly BASE_PATH would leave an empty string, which
        # is not a valid ASGI path — normalize it to "/".
        request.scope["path"] = path[len(BASE_PATH):] or "/"
    return await call_next(request)

@app.get("/")
async def root():
    """Landing endpoint: confirm the API is up and advertise its routes."""
    endpoint_map = {
        "chat": f"{BASE_PATH}/ai?query=Hello&user_id=yourname",
        "health": f"{BASE_PATH}/health",
        "reset": f"{BASE_PATH}/reset?user_id=yourname",
        "test": f"{BASE_PATH}/test",
        "docs": f"{BASE_PATH}/docs",
    }
    return {"message": "🟢 Trigger API is running", "endpoints": endpoint_map}

@app.get("/ai")
async def chat(request: Request):
    """Generate a chat reply for `query`, with per-user rolling memory.

    Query params:
        query:   user message (required, max 200 chars).
        user_id: key for the conversation memory (default "default").

    Returns:
        {"reply": <first line of the model's continuation>}

    Raises:
        HTTPException: 400 on missing/oversized query, 500 on generation failure.
    """
    # Extract params before the try block so `user_id` is always bound
    # (the OOM handler below references it) and so the 400s are never
    # swallowed by the generic Exception handler.
    user_input = request.query_params.get("query", "").strip()
    user_id = request.query_params.get("user_id", "default").strip()

    if not user_input:
        raise HTTPException(status_code=400, detail="Missing 'query'")
    if len(user_input) > 200:
        raise HTTPException(status_code=400, detail="Query too long (max 200 characters)")

    try:
        # Prompt style: natural chat history
        memory = chat_history.get(user_id, [])
        prompt = "You are a friendly, funny AI assistant called Trigger.\n\n"
        for q, a in memory:
            prompt += f"User: {q}\nTrigger: {a}\n"
        prompt += f"User: {user_input}\nTrigger:"

        inputs = tokenizer(prompt, return_tensors="pt")
        # no_grad: inference only — skip autograd bookkeeping and its memory cost
        with torch.no_grad():
            output = model.generate(
                inputs.input_ids,
                attention_mask=inputs.attention_mask,  # explicit mask avoids pad/attn warning
                max_new_tokens=128,
                pad_token_id=tokenizer.eos_token_id,
                do_sample=True,  # required: temperature/top_k/top_p are ignored under greedy decoding
                temperature=0.8,
                top_k=50,
                top_p=0.95,
            )
        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)
        # Keep only the first line of the continuation — the model may go on
        # to hallucinate further "User:" turns.
        response = generated_text[len(prompt):].strip().split("\n")[0]

        # Save history (limit to last 5 exchanges)
        memory.append((user_input, response))
        chat_history[user_id] = memory[-5:]

        return {"reply": response}

    except torch.cuda.OutOfMemoryError:
        logger.error("CUDA out of memory error")
        # Drop this user's memory so the next prompt is shorter.
        if user_id in chat_history:
            del chat_history[user_id]
        raise HTTPException(status_code=500, detail="Memory error. Try again.")
    except HTTPException:
        # Never re-wrap deliberate HTTP errors (e.g. the 400s) as 500s.
        raise
    except Exception as e:
        logger.error(f"Processing error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Error: {str(e)}")

@app.get("/health")
async def health():
    """Liveness probe: report model name, active user count, and base path."""
    report = {"status": "healthy"}
    report["model"] = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
    report["users"] = len(chat_history)
    report["base_path"] = BASE_PATH
    return report

@app.get("/reset")
async def reset_history(user_id: str = "default"):
    """Drop the stored conversation memory for one user (no-op if absent)."""
    chat_history.pop(user_id, None)
    return {"status": "success", "message": f"History cleared for user {user_id}"}

@app.get("/test", response_class=HTMLResponse)
async def test_page():
    """Serve a minimal HTML smoke-test page with links to every endpoint.

    Links are prefixed with BASE_PATH so they work when the app is mounted
    under a sub-path (e.g. on Hugging Face Spaces).
    """
    return f"""
    <html>
        <body>
            <h1>Trigger Chatbot Test</h1>
            <p>Base path: {BASE_PATH}</p>
            <ul>
                <li><a href="{BASE_PATH}/">Root endpoint</a></li>
                <li><a href="{BASE_PATH}/ai?query=Hello&user_id=test">Chat endpoint</a></li>
                <li><a href="{BASE_PATH}/health">Health check</a></li>
                <li><a href="{BASE_PATH}/docs">API Docs</a></li>
            </ul>
        </body>
    </html>
    """

# Run locally. The "app:app" import string assumes this file is named
# app.py — TODO confirm, since reload=True requires an import string.
if __name__ == "__main__":
    # reload=True: auto-restart on code change — development convenience only.
    uvicorn.run("app:app", host="0.0.0.0", port=7860, log_level="info", reload=True)