import torch
import time
import logging
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
# -----------------------
# Basic Logging Setup
# -----------------------
# Root logger at INFO so the load/request messages below are visible.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
# -----------------------
# CORS (allow Netlify)
# -----------------------
# NOTE(review): wildcard origins combined with allow_credentials=True is
# rejected by browsers for credentialed requests — confirm and replace "*"
# with the concrete frontend origin before relying on cookies/auth headers.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Replace with your Netlify domain later
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Hugging Face model id pulled at startup; downloads on first run.
MODEL_ID = "AshokGakr/model-tiny"
logger.info("Loading model...")
# Prefer GPU when available; everything below (and generate_stream) uses
# this module-level `device`.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# float32 + low_cpu_mem_usage keeps CPU-only deployments workable.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float32,
    low_cpu_mem_usage=True
).to(device)
# Inference-only server: disable dropout/batch-norm training behavior.
model.eval()
logger.info(f"Model loaded on {device}")
# -----------------------
# Root Health Check
# -----------------------
@app.get("/")
def root():
    """Health-check endpoint: reports that the API process is up."""
    payload = {"status": "API is running"}
    return payload
# -----------------------
# Streaming Generator
# -----------------------
def generate_stream(prompt: str):
    """Yield decoded text chunks for *prompt* as generation produces them.

    Runs ``model.generate`` on a background thread and streams tokens
    through a ``TextIteratorStreamer`` so the HTTP response can be sent
    incrementally.

    Args:
        prompt: Fully assembled prompt string (system + history + turn).

    Yields:
        str: Decoded text fragments, prompt and special tokens stripped.
    """
    logger.info("Starting generation...")
    start_time = time.time()

    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=120,
        temperature=0.7,
        top_p=0.9,
        repetition_penalty=1.1,
        do_sample=True,
        streamer=streamer,
    )

    # generate() blocks, so it runs on a worker thread while we consume
    # the streamer on this (response) side.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    try:
        for new_text in streamer:
            yield new_text
    finally:
        # Fix: the original never joined the thread and, on a client
        # disconnect (GeneratorExit), both orphaned the worker and skipped
        # the timing log. Joining here guarantees the generation thread has
        # finished before the generator is finalized.
        thread.join()
        duration = round(time.time() - start_time, 2)
        logger.info(f"Generation finished in {duration} seconds.")
# -----------------------
# Chat Endpoint
# -----------------------
@app.post("/chat")
async def chat(data: dict):
    """Stream a model completion for one chat turn.

    Expects a JSON body with optional "system", "history" and "message"
    keys; responds with the completion as incrementally streamed plain
    text.
    """
    system_text = data.get("system", "You are a helpful AI assistant.")
    chat_history = data.get("history", "")
    user_message = data.get("message", "")

    # Bound the prompt: keep only the most recent slice of the history
    # (prevents memory overflow on very long conversations).
    max_history_chars = 2000
    chat_history = (
        chat_history
        if len(chat_history) <= max_history_chars
        else chat_history[-max_history_chars:]
    )

    logger.info("----- NEW REQUEST -----")
    logger.info(f"User message: {user_message}")
    logger.info(f"History length: {len(chat_history)}")

    full_prompt = f"{system_text}\n{chat_history}\nUser: {user_message}\nAssistant:"
    return StreamingResponse(
        generate_stream(full_prompt),
        media_type="text/plain",
    )