import os
import logging
import tempfile

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

# Point every Hugging Face cache variable at a writable directory. Some hosts
# (e.g. read-only containers) only allow writes under the temp directory, so
# fall back to it if the preferred cache path cannot be created.
TMP_CACHE = os.environ.get("HF_CACHE_DIR", os.path.join(tempfile.gettempdir(), "hf_cache"))
try:
    os.makedirs(TMP_CACHE, exist_ok=True)
except Exception:
    TMP_CACHE = tempfile.gettempdir()

os.environ["TRANSFORMERS_CACHE"] = TMP_CACHE
os.environ["HF_HOME"] = TMP_CACHE
os.environ["HF_DATASETS_CACHE"] = TMP_CACHE
os.environ["HF_METRICS_CACHE"] = TMP_CACHE

app = FastAPI(title="DirectEd LoRA API")


class PromptRequest(BaseModel):
    """Request body for /generate."""
    prompt: str
    max_new_tokens: int = 2048


@app.get("/")
def root():
    """Lightweight health check."""
    return {"status": "AI backend is running"}


# Populated at startup; stays None if model loading fails.
pipe = None


@app.on_event("startup")
def load_model():
    """Load the 4-bit base model, apply the LoRA adapter, and build a text-generation pipeline."""
    global pipe
    try:
        # Heavy imports are deferred so the app can start (and report errors)
        # even when the ML dependencies are missing or misconfigured.
        from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
        from peft import PeftModel

        BASE_MODEL = "unsloth/llama-3-8b-Instruct-bnb-4bit"
        ADAPTER_REPO = "rayymaxx/DirectEd-AI-LoRA"

        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map="auto",
            torch_dtype="auto",
        )

        # Attach the fine-tuned LoRA weights on top of the base model.
        model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)
        model.eval()

        pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
        logging.info("Model and adapter loaded successfully.")
    except Exception as e:
        logging.exception("Failed to load model at startup: %s", e)
        pipe = None


@app.post("/generate")
def generate(req: PromptRequest):
    """Generate a completion for the prompt, stripping the echoed prompt from the output."""
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded. Check logs.")
    try:
        output = pipe(req.prompt, max_new_tokens=req.max_new_tokens, do_sample=True, temperature=0.7)
        text = output[0].get("generated_text", "").strip()

        # The pipeline returns the prompt followed by the completion; keep only the completion.
        if text.startswith(req.prompt):
            text = text[len(req.prompt):].strip()
        if not text:
            text = "No response generated by the model."
        return {"response": text}
    except Exception as e:
        logging.exception("Generation failed for prompt '%s': %s", req.prompt, e)
        raise HTTPException(status_code=500, detail=f"Generation failed: {e}")
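

# A minimal local-run sketch (an addition, not part of the original service
# configuration): it assumes uvicorn is installed and that the host/port below
# are acceptable defaults; adjust for your deployment.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)

# Example request once the server is running (hypothetical prompt and values):
#   curl -X POST http://localhost:8000/generate \
#     -H "Content-Type: application/json" \
#     -d '{"prompt": "Explain LoRA in one sentence.", "max_new_tokens": 128}'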