File size: 2,817 Bytes
dde69ed
093ce27
 
dde69ed
 
4a6c839
dde69ed
 
 
093ce27
 
 
 
 
 
dde69ed
 
093ce27
 
 
4a6c839
 
093ce27
4a6c839
 
 
 
 
dde69ed
 
093ce27
dde69ed
 
 
 
 
 
 
 
 
 
 
 
 
093ce27
 
 
 
 
 
 
 
 
 
dde69ed
093ce27
 
 
 
 
 
dde69ed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
093ce27
 
dde69ed
093ce27
 
 
 
dde69ed
 
 
 
 
 
093ce27
 
 
 
 
 
 
 
 
dde69ed
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torch
import time
import logging
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread

# -----------------------
# Basic Logging Setup
# -----------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# -----------------------
# CORS (allow Netlify)
# -----------------------
# NOTE(review): the CORS spec forbids the wildcard origin "*" together with
# allow_credentials=True; Starlette copes by echoing back the request Origin,
# which effectively lets ANY site send credentialed requests. Replace "*"
# with the concrete Netlify domain before production — confirm intent.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Replace with your Netlify domain later
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Hugging Face Hub model id; weights are downloaded/loaded at import time,
# so the first startup can take a while.
MODEL_ID = "AshokGakr/model-tiny"

logger.info("Loading model...")

# Prefer GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# float32 keeps CPU inference numerically safe; low_cpu_mem_usage limits
# peak RAM while the weights are being materialized.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True
).to(device)

# Inference-only mode (disables dropout etc.).
model.eval()

logger.info(f"Model loaded on {device}")


# -----------------------
# Root Health Check
# -----------------------
@app.get("/")
def root():
    """Liveness probe: confirms the API process is up and serving."""
    payload = {"status": "API is running"}
    return payload


# -----------------------
# Streaming Generator
# -----------------------
def generate_stream(prompt: str):
    """Yield decoded text chunks for *prompt* as generation proceeds.

    model.generate runs on a background thread; TextIteratorStreamer hands
    decoded tokens to this generator so the HTTP response can start before
    generation has finished.

    Args:
        prompt: Fully assembled prompt string (system + history + user turn).

    Yields:
        str: Incremental chunks of newly generated text (prompt excluded).
    """
    logger.info("Starting generation...")
    start_time = time.time()

    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,          # do not echo the prompt back to the client
        skip_special_tokens=True
    )

    generation_kwargs = dict(
        **inputs,
        max_new_tokens=120,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True,
        streamer=streamer
    )

    # Generation must happen off-thread: iterating the streamer below blocks
    # until the generating side pushes tokens into it.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    try:
        for new_text in streamer:
            yield new_text
    finally:
        # Fix: the original never joined the worker. Reap it even when the
        # client disconnects mid-stream (generator gets closed), otherwise
        # generate() threads accumulate across aborted requests.
        thread.join()

    duration = round(time.time() - start_time, 2)
    logger.info(f"Generation finished in {duration} seconds.")


# -----------------------
# Chat Endpoint
# -----------------------
@app.post("/chat")
async def chat(data: dict):
    """Build the full prompt from the posted payload and stream the reply."""
    system_prompt = data.get("system", "You are a helpful AI assistant.")
    history = data.get("history", "")
    message = data.get("message", "")

    # Keep only the most recent characters of the history so the prompt
    # cannot grow without bound (prevents memory overflow).
    max_history_chars = 2000
    if len(history) > max_history_chars:
        history = history[len(history) - max_history_chars:]

    logger.info("----- NEW REQUEST -----")
    logger.info(f"User message: {message}")
    logger.info(f"History length: {len(history)}")

    full_prompt = f"{system_prompt}\n{history}\nUser: {message}\nAssistant:"

    stream = generate_stream(full_prompt)
    return StreamingResponse(stream, media_type="text/plain")