# Nutnell's picture
# Update to strip echoing in answers
# 79529dc verified
import logging
import os
import tempfile

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
# Cache directory for downloaded model files: HF_CACHE_DIR when set, otherwise
# an "hf_cache" folder under the system temp directory.
TMP_CACHE = os.environ.get("HF_CACHE_DIR", os.path.join(tempfile.gettempdir(), "hf_cache"))
try:
    os.makedirs(TMP_CACHE, exist_ok=True)
except (OSError, ValueError) as exc:
    # Best effort: an unusable cache path must not stop the service from
    # starting, but the fallback is reported instead of happening silently.
    logging.getLogger(__name__).warning(
        "Could not create cache directory %r (%s); falling back to %r",
        TMP_CACHE,
        exc,
        tempfile.gettempdir(),
    )
    TMP_CACHE = tempfile.gettempdir()

# Point every cache-related environment variable at the same directory. This
# runs at import time, before load_model() imports transformers/peft.
for _cache_var in ("TRANSFORMERS_CACHE", "HF_HOME", "HF_DATASETS_CACHE", "HF_METRICS_CACHE"):
    os.environ[_cache_var] = TMP_CACHE
# FastAPI application instance; the route and startup handlers in this module
# register themselves on it via decorators.
app = FastAPI(title="DirectEd LoRA API")
class PromptRequest(BaseModel):
    """Request body accepted by the ``/generate`` endpoint."""

    # Text passed to the text-generation pipeline as the prompt.
    prompt: str
    # Forwarded to the pipeline as ``max_new_tokens``. Constrained to be
    # positive so a zero or negative value is rejected during request
    # validation instead of reaching generation and surfacing as a 500.
    # NOTE(review): no upper bound is enforced; consider capping it if this
    # endpoint is exposed to untrusted clients.
    max_new_tokens: int = Field(default=2048, gt=0)
@app.get("/")
def root():
    """Report that the service process is up.

    Returns a fixed status payload; it does not look at whether the model
    pipeline has been loaded.
    """
    status_payload = {"status": "AI backend is running"}
    return status_payload
# Text-generation pipeline shared with the request handlers. It stays None
# until load_model() completes successfully.
pipe = None


@app.on_event("startup")
def load_model():
    """Build the LoRA-adapted text-generation pipeline at application startup.

    Loads the tokenizer and base model, applies the PEFT adapter on top, and
    stores the resulting pipeline in the module-level ``pipe``. Any failure
    is logged and leaves ``pipe`` as ``None`` so the server still starts;
    ``/generate`` then answers with a 503.
    """
    global pipe
    try:
        # Imported inside the try block so that a failing import is handled
        # like any other load failure.
        from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
        from peft import PeftModel

        base_model_id = "unsloth/llama-3-8b-Instruct-bnb-4bit"
        adapter_repo_id = "rayymaxx/DirectEd-AI-LoRA"

        tok = AutoTokenizer.from_pretrained(base_model_id)
        backbone = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            device_map="auto",
            torch_dtype="auto",
        )

        # Layer the adapter over the base model and switch to eval mode.
        adapted = PeftModel.from_pretrained(backbone, adapter_repo_id)
        adapted.eval()

        pipe = pipeline(
            "text-generation",
            model=adapted,
            tokenizer=tok,
            device_map="auto",
        )
        logging.info("Model and adapter loaded successfully.")
    except Exception as exc:
        logging.exception("Failed to load model at startup: %s", exc)
        pipe = None
def _strip_prompt_echo(prompt, generated):
    """Return *generated* without a leading copy of *prompt*, whitespace-trimmed."""
    # Check the raw output first: stripping it before the comparison would
    # hide an echo whenever the prompt itself starts with whitespace.
    if generated.startswith(prompt):
        return generated[len(prompt):].strip()

    # Otherwise compare the trimmed forms, which also catches an echo that
    # differs from the prompt only in surrounding whitespace.
    trimmed_output = generated.strip()
    trimmed_prompt = prompt.strip()
    if trimmed_prompt and trimmed_output.startswith(trimmed_prompt):
        return trimmed_output[len(trimmed_prompt):].strip()
    return trimmed_output


@app.post("/generate")
def generate(req: PromptRequest):
    """Generate a completion for ``req.prompt`` with the loaded pipeline.

    Returns ``{"response": <text>}`` where the text has any leading echo of
    the prompt removed; a fixed placeholder message is returned when nothing
    remains after that.

    Raises:
        HTTPException: 503 when the model pipeline is not loaded, 500 when
            generation or post-processing fails.
    """
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded. Check logs.")
    try:
        output = pipe(req.prompt, max_new_tokens=req.max_new_tokens, do_sample=True, temperature=0.7)
        text = _strip_prompt_echo(req.prompt, output[0].get("generated_text", ""))
        if not text:
            text = "No response generated by the model."
        return {"response": text}
    except Exception as exc:
        logging.exception("Generation failed for prompt '%s': %s", req.prompt, exc)
        # Chain the cause so the original traceback stays attached.
        raise HTTPException(status_code=500, detail=f"Generation failed: {exc}") from exc