# app.py — DirectEd LoRA inference API, returning answers with structured source metadata
import logging
import os
import re
import tempfile
from typing import List

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# --- Use a writable temp dir for Hugging Face caches ---
TMP_CACHE = os.environ.get("HF_CACHE_DIR", os.path.join(tempfile.gettempdir(), "hf_cache"))
try:
    os.makedirs(TMP_CACHE, exist_ok=True)
except Exception:
    # Fall back to the bare temp dir if the cache subdirectory cannot be created
    TMP_CACHE = tempfile.gettempdir()

os.environ["TRANSFORMERS_CACHE"] = TMP_CACHE
os.environ["HF_HOME"] = TMP_CACHE
os.environ["HF_DATASETS_CACHE"] = TMP_CACHE
os.environ["HF_METRICS_CACHE"] = TMP_CACHE
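
# Note: on hosted runtimes (e.g. HF Spaces) the default cache path is often
# read-only, so pointing every cache at a writable temp dir avoids permission
# errors at model-download time. TRANSFORMERS_CACHE is a legacy variable kept
# here for older transformers versions; newer releases read HF_HOME instead.
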
app = FastAPI(title="DirectEd LoRA API with metadata")

@app.get("/health")
def health():
    return {"ok": True}


@app.get("/")
def root():
    return {"status": "AI backend is running"}

class PromptRequest(BaseModel):
    prompt: str


class Source(BaseModel):
    name: str
    url: str


class ResponseWithMetadata(BaseModel):
    answer: str
    sources: List[Source] = []
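
# Illustrative request/response shapes (example values only, nothing beyond
# the Pydantic models above is enforced):
#   POST /generate  {"prompt": "What is DirectEd?"}
#   -> {"answer": "...", "sources": [{"name": "Reference", "url": "https://..."}]}
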
pipe = None  # set by load_model() at startup; stays None if loading fails

@app.on_event("startup")
def load_model():
    """Load base + LoRA adapter model at startup."""
    global pipe
    try:
        # Heavy imports are deferred so the app can still boot (and report a
        # clear 503 from /generate) when the ML dependencies are missing.
        from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
        from peft import PeftModel

        BASE_MODEL = "unsloth/llama-3-8b-Instruct-bnb-4bit"
        ADAPTER_REPO = "rayymaxx/DirectEd-AI-LoRA"
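        # NOTE: this bnb-4bit base checkpoint is pre-quantized, so loading it
        # presumably requires bitsandbytes (plus accelerate for device_map).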
        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map="auto",
            low_cpu_mem_usage=True,
            torch_dtype="auto",
        )
        # Attach the LoRA adapter weights on top of the frozen base model
        model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)
        model.eval()

        # The model is already dispatched by device_map above, so the
        # pipeline is handed the instantiated objects directly.
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
        )
        logging.info("Model and adapter loaded successfully.")
    except Exception as e:
        logging.exception("Failed to load model at startup: %s", e)
        pipe = None

def parse_response(raw_text: str) -> ResponseWithMetadata:
    """Extract the answer and any cited source URLs from raw model output."""
    # Collect any URLs in the output so they can be returned as structured sources
    source_pattern = r"(https?://[^\s]+)"
    urls = re.findall(source_pattern, raw_text)

    # Deduplicate while preserving first-seen order
    seen = set()
    sources: List[Source] = []
    for url in urls:
        if url not in seen:
            seen.add(url)
            sources.append(Source(name="Reference", url=url))

    # Strip the URLs out of the text so the answer reads cleanly
    answer = re.sub(source_pattern, "", raw_text).strip()
    return ResponseWithMetadata(answer=answer, sources=sources)
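
# Example (illustrative, made-up URL):
#   parse_response("See https://example.org docs")
# returns answer="See  docs" (with the double space left where the URL was
# removed) and sources=[Source(name="Reference", url="https://example.org")].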

@app.post("/generate", response_model=ResponseWithMetadata)
def generate(req: PromptRequest):
    """Generate a concise response with optional source metadata."""
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded. Check logs.")
    try:
        output = pipe(req.prompt, max_new_tokens=150, do_sample=True)
        # Note: by default the text-generation pipeline echoes the prompt at
        # the start of "generated_text", so the prompt appears in the answer.
        full_text = output[0].get("generated_text", "").strip()
        if not full_text:
            raise HTTPException(status_code=500, detail="Model returned empty response.")
        return parse_response(full_text)
    except HTTPException:
        # Propagate intentional HTTP errors unchanged rather than wrapping
        # them in the generic 500 below.
        raise
    except Exception as e:
        logging.exception("Generation failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Generation failed: {e}") from e
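

# Local dev entry point (a convenience sketch; on Hugging Face Spaces the
# server is usually launched externally, so this block is optional):
if __name__ == "__main__":
    import uvicorn  # assumed available alongside fastapi

    uvicorn.run(app, host="0.0.0.0", port=7860)  # 7860 is the Spaces convention

# Example request (illustrative):
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Explain LoRA fine-tuning in one sentence."}'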