# DirectEd LoRA inference API (FastAPI service).
# NOTE(review): hosting-page extraction residue (Spaces status text, file size,
# commit hashes, line-number gutter) removed from the top of this file.
import os
import logging
import tempfile
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# Resolve a writable Hugging Face cache directory, preferring an explicit
# HF_CACHE_DIR override and falling back to a subfolder of the system temp dir.
_preferred_cache = os.path.join(tempfile.gettempdir(), "hf_cache")
TMP_CACHE = os.environ.get("HF_CACHE_DIR", _preferred_cache)
try:
    os.makedirs(TMP_CACHE, exist_ok=True)
except Exception:
    # Directory could not be created (e.g. read-only filesystem);
    # fall back to the temp dir itself, which always exists.
    TMP_CACHE = tempfile.gettempdir()

# Point every Hugging Face cache-location knob at the chosen directory so all
# downloads (models, datasets, metrics) land in a writable place.
for _cache_var in ("TRANSFORMERS_CACHE", "HF_HOME", "HF_DATASETS_CACHE", "HF_METRICS_CACHE"):
    os.environ[_cache_var] = TMP_CACHE
app = FastAPI(title="DirectEd LoRA API")
class PromptRequest(BaseModel):
    """Request body for POST /generate."""

    # User prompt, passed verbatim to the text-generation pipeline.
    prompt: str
    # Upper bound on tokens to generate for this request.
    max_new_tokens: int = 2048
@app.get("/")
def root():
    """Health check: confirms the service process is up and responding."""
    payload = {"status": "AI backend is running"}
    return payload
# Global text-generation pipeline; stays None until the startup hook succeeds,
# which lets the /generate handler report 503 instead of crashing.
pipe = None


@app.on_event("startup")
def load_model():
    """Load the 4-bit base model plus the LoRA adapter and build the pipeline.

    Heavy imports (transformers, peft) happen here rather than at module level
    so the app can start even when they are unavailable. On any failure the
    exception is logged and ``pipe`` is left as None.
    """
    global pipe
    try:
        from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
        from peft import PeftModel

        base_repo = "unsloth/llama-3-8b-Instruct-bnb-4bit"
        adapter_repo = "rayymaxx/DirectEd-AI-LoRA"

        tok = AutoTokenizer.from_pretrained(base_repo)
        base = AutoModelForCausalLM.from_pretrained(
            base_repo,
            device_map="auto",
            torch_dtype="auto",
        )

        # Attach the fine-tuned LoRA weights on top of the frozen base model.
        adapted = PeftModel.from_pretrained(base, adapter_repo)
        adapted.eval()

        pipe = pipeline("text-generation", model=adapted, tokenizer=tok, device_map="auto")
        logging.info("Model and adapter loaded successfully.")
    except Exception as exc:
        logging.exception("Failed to load model at startup: %s", exc)
        pipe = None
@app.post("/generate")
def generate(req: PromptRequest):
    """Generate text for the given prompt and return the model's reply.

    Returns 503 when the model never finished loading and 500 when the
    pipeline itself raises.
    """
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded. Check logs.")
    try:
        result = pipe(
            req.prompt,
            max_new_tokens=req.max_new_tokens,
            do_sample=True,
            temperature=0.7,
        )
        reply = result[0].get("generated_text", "").strip()
        # The pipeline echoes the prompt at the head of its output; drop it.
        if reply.startswith(req.prompt):
            reply = reply[len(req.prompt):].strip()
        return {"response": reply or "No response generated by the model."}
    except Exception as exc:
        logging.exception("Generation failed for prompt '%s': %s", req.prompt, exc)
        raise HTTPException(status_code=500, detail=f"Generation failed: {exc}")