rayymaxx committed
Commit 97cf393 · 1 Parent(s): 6c0bb59

Modified the BaseModel schema

Files changed (1):
  app.py  +9 -14
app.py CHANGED
@@ -1,25 +1,23 @@
-# app.py (safe, use /tmp for cache)
+# app.py (simplified generate endpoint)
 import os
 import logging
 from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel
 import tempfile
 
-# --- Put caches in a writable temp dir to avoid permission errors ---
+# --- Use writable temp dir for Hugging Face caches ---
 TMP_CACHE = os.environ.get("HF_CACHE_DIR", os.path.join(tempfile.gettempdir(), "hf_cache"))
 try:
     os.makedirs(TMP_CACHE, exist_ok=True)
-except Exception as e:
-    # if even this fails, fall back to tempfile.gettempdir()
+except Exception:
     TMP_CACHE = tempfile.gettempdir()
 
-# export environment vars before importing transformers
 os.environ["TRANSFORMERS_CACHE"] = TMP_CACHE
 os.environ["HF_HOME"] = TMP_CACHE
 os.environ["HF_DATASETS_CACHE"] = TMP_CACHE
 os.environ["HF_METRICS_CACHE"] = TMP_CACHE
 
-app = FastAPI(title="DirectEd LoRA API (safe startup)")
+app = FastAPI(title="DirectEd LoRA API (simplified)")
 
 @app.get("/health")
 def health():
@@ -29,10 +27,8 @@ def health():
 def root():
     return {"Status": "AI backend is running"}
 
-class Request(BaseModel):
+class PromptRequest(BaseModel):
     prompt: str
-    max_new_tokens: int = 150
-    temperature: float = 0.7
 
 pipe = None
 
@@ -40,12 +36,11 @@ pipe = None
 def load_model():
     global pipe
     try:
-        # heavy imports done during startup
         from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
         from peft import PeftModel
 
         BASE_MODEL = "unsloth/llama-3-8b-Instruct-bnb-4bit"
-        ADAPTER_REPO = "rayymaxx/DirectEd-AI-LoRA"  # <-- replace with your adapter repo
+        ADAPTER_REPO = "rayymaxx/DirectEd-AI-LoRA"
 
         tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
         base_model = AutoModelForCausalLM.from_pretrained(
@@ -65,12 +60,12 @@ def load_model():
         pipe = None
 
 @app.post("/generate")
-def generate(req: Request):
+def generate(req: PromptRequest):
     if pipe is None:
         raise HTTPException(status_code=503, detail="Model not loaded. Check logs.")
     try:
-        out = pipe(req.prompt, max_new_tokens=req.max_new_tokens, temperature=req.temperature, do_sample=True)
-        return {"response": out[0]["generated_text"]}
+        output = pipe(req.prompt, max_new_tokens=150, do_sample=True)
+        return {"response": output[0]["generated_text"]}
     except Exception as e:
         logging.exception("Generation failed: %s", e)
         raise HTTPException(status_code=500, detail=str(e))
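
With this change the request schema is reduced to a single field: PromptRequest carries only prompt, while max_new_tokens is now fixed at 150 server-side and temperature falls back to the pipeline default. A minimal client call against the new endpoint could look like the sketch below; the localhost:8000 URL assumes a default uvicorn invocation and is illustrative only.

# Sketch of a client request against the simplified schema.
# Assumes the app is served locally, e.g. `uvicorn app:app --port 8000` (illustrative).
import requests

resp = requests.post(
    "http://localhost:8000/generate",
    json={"prompt": "Explain LoRA fine-tuning in one sentence."},  # PromptRequest accepts only `prompt`
    timeout=120,  # generation on an 8B model can take a while
)
resp.raise_for_status()
print(resp.json()["response"])

Note that clients still sending the old max_new_tokens or temperature fields will not get a validation error: Pydantic models ignore extra fields by default, so those values are simply dropped.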