# (Hugging Face Spaces page residue: Space status "Paused")
# app.py (refined with clean metadata)
import logging
import os
import tempfile
from typing import Dict, List

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
# --- Use writable temp dir for Hugging Face caches ---
# HF_CACHE_DIR may point at a preferred location; otherwise fall back to
# a "hf_cache" folder under the system temp dir.
TMP_CACHE = os.environ.get("HF_CACHE_DIR", os.path.join(tempfile.gettempdir(), "hf_cache"))
try:
    os.makedirs(TMP_CACHE, exist_ok=True)
except Exception:
    # Directory creation failed (e.g. read-only FS) — use the temp dir itself.
    TMP_CACHE = tempfile.gettempdir()
# Point every Hugging Face cache location at the writable directory.
for _cache_env in ("TRANSFORMERS_CACHE", "HF_HOME", "HF_DATASETS_CACHE", "HF_METRICS_CACHE"):
    os.environ[_cache_env] = TMP_CACHE
| app = FastAPI(title="DirectEd LoRA API with metadata") | |
@app.get("/health")
def health():
    """Liveness probe: report that the process is up.

    Fix: the handler was defined but never registered on ``app`` (no route
    decorator), so the endpoint was unreachable.
    """
    return {"ok": True}
@app.get("/")
def root():
    """Root endpoint: simple status message for smoke checks.

    Fix: the handler was defined but never registered on ``app`` (no route
    decorator), so the endpoint was unreachable.
    """
    return {"status": "AI backend is running"}
class PromptRequest(BaseModel):
    """Request body for text generation."""

    # Free-form user prompt forwarded to the language model.
    prompt: str
class Source(BaseModel):
    """A single reference extracted from the model output."""

    # Display label for the source (the extractor uses a fixed "Reference").
    name: str
    # The URL found in the generated text.
    url: str
class ResponseWithMetadata(BaseModel):
    """Generation result: cleaned answer text plus any extracted sources."""

    answer: str
    # Empty when no URLs were found in the raw output.
    sources: List[Source] = []
# Global text-generation pipeline; stays None until load_model() succeeds.
pipe = None


@app.on_event("startup")
def load_model():
    """Load base + LoRA adapter model at startup.

    Fix: the docstring promised loading "at startup" but nothing invoked this
    function — it is now registered as a FastAPI startup hook. On any failure
    the exception is logged and ``pipe`` stays None, so the service still
    starts and generate() returns 503 instead of the process crashing.
    """
    global pipe
    try:
        # Imported lazily so the API can start (and report errors) even if the
        # heavyweight ML stack is unavailable.
        from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
        from peft import PeftModel

        BASE_MODEL = "unsloth/llama-3-8b-Instruct-bnb-4bit"
        ADAPTER_REPO = "rayymaxx/DirectEd-AI-LoRA"

        tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
        base_model = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL,
            device_map="auto",
            low_cpu_mem_usage=True,
            torch_dtype="auto",
        )
        # Attach the LoRA adapter on top of the base model.
        model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)
        model.eval()
        pipe = pipeline(
            "text-generation",
            model=model,
            tokenizer=tokenizer,
            device_map="auto",
        )
        logging.info("Model and adapter loaded successfully.")
    except Exception as e:
        logging.exception("Failed to load model at startup: %s", e)
        pipe = None
def parse_response(raw_text: str) -> ResponseWithMetadata:
    """Extract answer and sources from raw model output.

    All http(s) URLs in *raw_text* are collected as sources (deduplicated,
    first-seen order preserved) and stripped from the answer text.

    Fix: removed the unused ``from collections import OrderedDict`` import.

    NOTE(review): the URL pattern matches up to the next whitespace, so it
    also captures trailing punctuation (e.g. "…example.com).") — confirm
    whether stricter matching is needed.
    """
    import re

    source_pattern = r"(https?://[^\s]+)"
    urls = re.findall(source_pattern, raw_text)

    # Deduplicate while keeping first-seen order.
    seen = set()
    sources: List[Source] = []
    for url in urls:
        if url not in seen:
            seen.add(url)
            sources.append(Source(name="Reference", url=url))

    # Remove the URLs from the text so the answer reads cleanly.
    answer = re.sub(source_pattern, "", raw_text).strip()
    return ResponseWithMetadata(answer=answer, sources=sources)
@app.post("/generate", response_model=ResponseWithMetadata)
def generate(req: PromptRequest):
    """Generate a concise response with optional metadata.

    Fixes: (1) the handler was never registered on ``app`` (missing route
    decorator); (2) the 500 "empty response" HTTPException raised inside the
    ``try`` was swallowed by the broad ``except Exception`` and re-wrapped as
    "Generation failed: …", losing its intended detail — HTTPExceptions are
    now re-raised untouched.

    Raises 503 when the model is not loaded and 500 on generation failure.
    """
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded. Check logs.")
    try:
        # NOTE(review): do_sample=True makes output nondeterministic, and the
        # pipeline's generated_text presumably includes the prompt (library
        # default) — confirm whether the prompt should be stripped.
        output = pipe(req.prompt, max_new_tokens=150, do_sample=True)
        full_text = output[0].get("generated_text", "").strip()
        if not full_text:
            raise HTTPException(status_code=500, detail="Model returned empty response.")
        return parse_response(full_text)
    except HTTPException:
        # Deliberate HTTP errors pass through unwrapped.
        raise
    except Exception as e:
        logging.exception("Generation failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Generation failed: {e}")