File size: 2,297 Bytes
6c0bb59
 
7c89c4e
6c0bb59
 
 
 
 
 
97cf393
6c0bb59
 
 
 
 
 
 
79529dc
7c89c4e
 
 
79529dc
6c0bb59
 
 
3e2fd2f
6c0bb59
 
 
 
 
 
 
 
 
 
 
97cf393
6c0bb59
 
 
 
 
 
 
 
 
 
 
7c89c4e
6c0bb59
7c89c4e
6c0bb59
 
 
 
7c89c4e
97cf393
6c0bb59
 
7c89c4e
6c0bb59
79529dc
7c89c4e
 
 
 
 
 
 
3e2fd2f
7c89c4e
3e2fd2f
6c0bb59
7c89c4e
3e2fd2f
79529dc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import os
import logging
import tempfile
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Writable Hugging Face cache location: honor HF_CACHE_DIR if set, otherwise
# use a dedicated folder under the system temp directory.
TMP_CACHE = os.environ.get("HF_CACHE_DIR", os.path.join(tempfile.gettempdir(), "hf_cache"))
try:
    os.makedirs(TMP_CACHE, exist_ok=True)
except Exception:
    # Could not create the cache folder (e.g. read-only path) — fall back to
    # the bare temp directory, which is always writable.
    TMP_CACHE = tempfile.gettempdir()

# Point every Hugging Face cache knob at the same writable directory so model,
# dataset, and metric downloads all land in TMP_CACHE.
for _cache_var in ("TRANSFORMERS_CACHE", "HF_HOME", "HF_DATASETS_CACHE", "HF_METRICS_CACHE"):
    os.environ[_cache_var] = TMP_CACHE

app = FastAPI(title="DirectEd LoRA API")

class PromptRequest(BaseModel):
    """Request body for the /generate endpoint."""
    # Text prompt passed verbatim to the text-generation pipeline.
    prompt: str
    # Upper bound on tokens generated beyond the prompt; defaults to 2048.
    max_new_tokens: int = 2048

@app.get("/")
def root():
    """Liveness probe: confirms the API process is up and serving."""
    payload = {"status": "AI backend is running"}
    return payload

# Module-level text-generation pipeline; stays None until startup succeeds,
# which lets /generate answer with a 503 instead of crashing.
pipe = None

@app.on_event("startup")
def load_model():
    """Load the 4-bit base model plus the LoRA adapter and build the pipeline.

    Heavy ML imports happen here (not at module import time) so the app can
    start even when the ML stack is unavailable. On any failure the error is
    logged and `pipe` is left as None.
    """
    global pipe
    try:
        from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
        from peft import PeftModel

        base_repo = "unsloth/llama-3-8b-Instruct-bnb-4bit"
        adapter_repo = "rayymaxx/DirectEd-AI-LoRA"

        tok = AutoTokenizer.from_pretrained(base_repo)
        base = AutoModelForCausalLM.from_pretrained(
            base_repo,
            device_map="auto",
            torch_dtype="auto",
        )

        # Attach the fine-tuned LoRA weights on top of the base model and
        # switch to inference mode.
        adapted = PeftModel.from_pretrained(base, adapter_repo)
        adapted.eval()

        pipe = pipeline("text-generation", model=adapted, tokenizer=tok, device_map="auto")
        logging.info("Model and adapter loaded successfully.")

    except Exception as e:
        # Best-effort startup: record the failure and leave pipe unset.
        logging.exception("Failed to load model at startup: %s", e)
        pipe = None

@app.post("/generate")
def generate(req: PromptRequest):
    """Generate a completion for `req.prompt` and return {"response": text}.

    Raises:
        HTTPException 503: model failed to load at startup (`pipe` is None).
        HTTPException 500: the generation call itself failed.
    """
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded. Check logs.")

    try:
        output = pipe(req.prompt, max_new_tokens=req.max_new_tokens, do_sample=True, temperature=0.7)
        text = output[0].get("generated_text", "")

        # FIX: remove the echoed prompt BEFORE stripping whitespace. The old
        # order stripped first, so a prompt with leading/trailing whitespace
        # never matched startswith() and the echo leaked into the response.
        if text.startswith(req.prompt):
            text = text[len(req.prompt):]
        text = text.strip()

        if not text:
            text = "No response generated by the model."

        return {"response": text}

    except Exception as e:
        # NOTE(review): this logs the full user prompt — confirm that is
        # acceptable for your log-retention/privacy policy.
        logging.exception("Generation failed for prompt '%s': %s", req.prompt, e)
        raise HTTPException(status_code=500, detail=f"Generation failed: {e}")