# LumaAI-API / app.py
# Source: natalieparker's Hugging Face Space (commit ec891af, verified)
from fastapi import FastAPI
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import os

# ==========================================
# 1. SETUP & MODEL LOADING
# ==========================================
app = FastAPI()

# Hugging Face Hub repository this API serves.
MODEL_ID = "natalieparker/LumaAI-160M-v3"

# Force CPU for deployment (no GPU available on this host).
DEVICE = "cpu"

try:
    print(f"🔄 Downloading and loading tokenizer from {MODEL_ID}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

    print(f"🔄 Downloading and loading model from {MODEL_ID} (CPU Optimized)...")
    # Load in float16 to roughly halve memory consumption (441MB -> 220MB).
    # NOTE(review): fp16 on CPU is slow/unsupported for some ops in older
    # torch releases — confirm generation actually runs on this host.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,  # stream weights instead of a full fp32 copy
    )
    model.to(DEVICE)
    # FIX: put the model in eval mode so dropout/norm layers behave
    # deterministically during inference (was missing in the original).
    model.eval()
    print("✅ Model loaded successfully on CPU!")
except Exception as e:
    # Keep the app alive so "/" can report the failure instead of crashing
    # at import time; endpoints check for model is None.
    print(f"FATAL MODEL LOAD ERROR: {e}")
    model = None
    tokenizer = None
# ==========================================
# 2. ENDPOINTS
# ==========================================
@app.get("/")
def root():
    """Health check: report liveness and whether the model finished loading."""
    loaded = model is not None
    return {"status": "LumaAI API is live", "model_loaded": loaded}
@app.post("/generate")
def generate(prompt: str):
    """Generate a character reply for *prompt* with the loaded model.

    Returns {"response": text} on success, or {"error": ...} if the model
    never loaded at startup.
    """
    if model is None:
        return {"error": "Model failed to load during startup."}

    # Chat template the model is prompted with: a single user turn followed
    # by the "Character:" cue the model is expected to complete.
    formatted_prompt = f"User: {prompt}\nCharacter:"
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)

    # FIX: run generation under no_grad — the original comment claimed it is
    # "not needed for inference", which is backwards: without it autograd
    # tracks the forward pass, wasting memory and CPU on every request.
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=150,
            temperature=0.75,
            top_p=0.9,
            repetition_penalty=1.2,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    text = tokenizer.decode(output[0], skip_special_tokens=True)

    # Keep only the text after the last "Character:" marker and stop at the
    # first hallucinated "User:" turn, if any.
    response = text.split("Character:")[-1].split("User:")[0].strip()
    # Undo detokenization artifacts (stray space before punctuation).
    response = response.replace(" .", ".").replace(" ,", ",").replace(" ?", "?").replace(" !", "!")
    return {"response": response}