Nutnell committed on
Commit
bfe0287
·
verified ·
1 Parent(s): 53098a5

Changed the max output tokens (raised `max_tokens` from 150 to 2048 in the `/generate` endpoint).

Browse files
Files changed (1) hide show
  1. app.py +3 -17
app.py CHANGED
@@ -1,11 +1,9 @@
1
- # app.py
2
  import os
3
  import logging
4
  import tempfile
5
  from fastapi import FastAPI, HTTPException
6
  from pydantic import BaseModel
7
 
8
- # --- Use writable temp dir for Hugging Face caches ---
9
  TMP_CACHE = os.environ.get("HF_CACHE_DIR", os.path.join(tempfile.gettempdir(), "hf_cache"))
10
  try:
11
  os.makedirs(TMP_CACHE, exist_ok=True)
@@ -19,15 +17,9 @@ os.environ["HF_METRICS_CACHE"] = TMP_CACHE
19
 
20
  app = FastAPI(title="DirectEd LoRA API (concise)")
21
 
22
- # ---------------------
23
- # Request Model
24
- # ---------------------
25
  class PromptRequest(BaseModel):
26
  prompt: str
27
 
28
- # ---------------------
29
- # Health & Root
30
- # ---------------------
31
  @app.get("/health")
32
  def health():
33
  return {"ok": True}
@@ -36,9 +28,6 @@ def health():
36
  def root():
37
  return {"status": "AI backend is running"}
38
 
39
- # ---------------------
40
- # Load Model on Startup
41
- # ---------------------
42
  pipe = None
43
 
44
  @app.on_event("startup")
@@ -69,22 +58,19 @@ def load_model():
69
  logging.exception("Failed to load model at startup: %s", e)
70
  pipe = None
71
 
72
- # ---------------------
73
- # Generate Endpoint
74
- # ---------------------
75
  @app.post("/generate")
76
  def generate(req: PromptRequest):
77
  if pipe is None:
78
  raise HTTPException(status_code=503, detail="Model not loaded. Check logs.")
79
 
80
  try:
81
- # Limit tokens to avoid huge outputs
82
- max_tokens = 150
83
 
84
  output = pipe(req.prompt, max_new_tokens=max_tokens, do_sample=True)
85
  text = output[0].get("generated_text", "").strip()
86
 
87
- # Remove repeated context if present
88
  if text.startswith(req.prompt):
89
  text = text[len(req.prompt):].strip()
90
 
 
 
1
  import os
2
  import logging
3
  import tempfile
4
  from fastapi import FastAPI, HTTPException
5
  from pydantic import BaseModel
6
 
 
7
  TMP_CACHE = os.environ.get("HF_CACHE_DIR", os.path.join(tempfile.gettempdir(), "hf_cache"))
8
  try:
9
  os.makedirs(TMP_CACHE, exist_ok=True)
 
17
 
18
  app = FastAPI(title="DirectEd LoRA API (concise)")
19
 
 
 
 
20
class PromptRequest(BaseModel):
    """Request body for the /generate endpoint: a single free-form prompt."""

    prompt: str
22
 
 
 
 
23
@app.get("/health")
def health():
    # Liveness probe: responds with a constant payload once the server is up.
    return dict(ok=True)
 
28
def root():
    # Root status message confirming the backend process is alive.
    status_payload = {"status": "AI backend is running"}
    return status_payload
30
 
 
 
 
31
  pipe = None
32
 
33
  @app.on_event("startup")
 
58
  logging.exception("Failed to load model at startup: %s", e)
59
  pipe = None
60
 
61
+
 
 
62
  @app.post("/generate")
63
  def generate(req: PromptRequest):
64
  if pipe is None:
65
  raise HTTPException(status_code=503, detail="Model not loaded. Check logs.")
66
 
67
  try:
68
+
69
+ max_tokens = 2048
70
 
71
  output = pipe(req.prompt, max_new_tokens=max_tokens, do_sample=True)
72
  text = output[0].get("generated_text", "").strip()
73
 
 
74
  if text.startswith(req.prompt):
75
  text = text[len(req.prompt):].strip()
76