from fastapi import FastAPI, HTTPException
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# ==========================================
# 1. SETUP & MODEL LOADING
# ==========================================
app = FastAPI()

MODEL_ID = "natalieparker/LumaAI-160M-v3"

# Force CPU device for deployment
DEVICE = "cpu"

try:
    print(f"🔄 Downloading and loading tokenizer from {MODEL_ID}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)

    print(f"🔄 Downloading and loading model from {MODEL_ID} (CPU optimized)...")
    # Load in float16 to roughly halve memory consumption (441 MB -> ~220 MB).
    # Caveat: float16 support on CPU is limited in older PyTorch releases; if
    # generation raises a "not implemented for 'Half'" error, fall back to
    # torch.float32 (or torch.bfloat16 where available).
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_ID,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,  # memory-efficient loading
    )
    model.to(DEVICE)
    model.eval()  # inference mode: disables dropout and training-only behavior
    print("✅ Model loaded successfully on CPU!")
except Exception as e:
    print(f"FATAL MODEL LOAD ERROR: {e}")
    # Leave both as None so the endpoints can report the failure.
    model = None
    tokenizer = None

# ==========================================
# 2. ENDPOINTS
# ==========================================
@app.get("/")
def root():
    # model_loaded is True only if startup loading succeeded
    return {"status": "LumaAI API is live", "model_loaded": model is not None}

@app.post("/generate")
def generate(prompt: str):  # FastAPI treats a bare str parameter as a query parameter
    if model is None:
        raise HTTPException(status_code=503, detail="Model failed to load during startup.")

    formatted_prompt = f"User: {prompt}\nCharacter:"
    inputs = tokenizer(formatted_prompt, return_tensors="pt").to(DEVICE)

    # Run generation under torch.no_grad(): gradients are never needed for
    # inference, and skipping them saves memory and time.
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=150,
            temperature=0.75,
            top_p=0.9,
            repetition_penalty=1.2,
            do_sample=True,
            pad_token_id=tokenizer.eos_token_id,
        )

    text = tokenizer.decode(output[0], skip_special_tokens=True)

    # Keep only the Character's reply: take the text after the last
    # "Character:" marker and stop before any follow-up "User:" turn.
    response = text.split("Character:")[-1].split("User:")[0].strip()
    # Tidy spacing artifacts the tokenizer can leave before punctuation.
    for punct in (".", ",", "?", "!"):
        response = response.replace(f" {punct}", punct)

    return {"response": response}
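
# ==========================================
# 3. LOCAL ENTRYPOINT (SKETCH)
# ==========================================
# A minimal sketch of serving the app locally, assuming uvicorn is installed.
# The module name implied here, plus the host and port, are illustrative
# assumptions and not part of the original deployment setup.
if __name__ == "__main__":
    import uvicorn

    # Equivalent to running: uvicorn main:app --host 0.0.0.0 --port 8000
    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request once the server is up (localhost:8000 is an assumption).
# Because `prompt` is a bare str parameter, it is passed as a query string:
#   curl -X POST "http://localhost:8000/generate?prompt=Hello"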