Nutnell committed on
Commit
79529dc
·
verified ·
1 Parent(s): bfe0287

Update to strip echoing in answers

Browse files
Files changed (1) hide show
  1. app.py +4 -12
app.py CHANGED
@@ -15,14 +15,11 @@ os.environ["HF_HOME"] = TMP_CACHE
15
  os.environ["HF_DATASETS_CACHE"] = TMP_CACHE
16
  os.environ["HF_METRICS_CACHE"] = TMP_CACHE
17
 
18
- app = FastAPI(title="DirectEd LoRA API (concise)")
19
 
20
  class PromptRequest(BaseModel):
21
  prompt: str
22
-
23
- @app.get("/health")
24
- def health():
25
- return {"ok": True}
26
 
27
  @app.get("/")
28
  def root():
@@ -44,7 +41,6 @@ def load_model():
44
  base_model = AutoModelForCausalLM.from_pretrained(
45
  BASE_MODEL,
46
  device_map="auto",
47
- low_cpu_mem_usage=True,
48
  torch_dtype="auto",
49
  )
50
 
@@ -58,24 +54,19 @@ def load_model():
58
  logging.exception("Failed to load model at startup: %s", e)
59
  pipe = None
60
 
61
-
62
  @app.post("/generate")
63
  def generate(req: PromptRequest):
64
  if pipe is None:
65
  raise HTTPException(status_code=503, detail="Model not loaded. Check logs.")
66
 
67
  try:
68
-
69
- max_tokens = 2048
70
-
71
- output = pipe(req.prompt, max_new_tokens=max_tokens, do_sample=True)
72
  text = output[0].get("generated_text", "").strip()
73
 
74
  if text.startswith(req.prompt):
75
  text = text[len(req.prompt):].strip()
76
 
77
  if not text:
78
- logging.warning("Model returned empty response for prompt: %s", req.prompt)
79
  text = "No response generated by the model."
80
 
81
  return {"response": text}
@@ -83,3 +74,4 @@ def generate(req: PromptRequest):
83
  except Exception as e:
84
  logging.exception("Generation failed for prompt '%s': %s", req.prompt, e)
85
  raise HTTPException(status_code=500, detail=f"Generation failed: {e}")
 
 
15
  os.environ["HF_DATASETS_CACHE"] = TMP_CACHE
16
  os.environ["HF_METRICS_CACHE"] = TMP_CACHE
17
 
18
+ app = FastAPI(title="DirectEd LoRA API")
19
 
20
  class PromptRequest(BaseModel):
21
  prompt: str
22
+ max_new_tokens: int = 2048
 
 
 
23
 
24
  @app.get("/")
25
  def root():
 
41
  base_model = AutoModelForCausalLM.from_pretrained(
42
  BASE_MODEL,
43
  device_map="auto",
 
44
  torch_dtype="auto",
45
  )
46
 
 
54
  logging.exception("Failed to load model at startup: %s", e)
55
  pipe = None
56
 
 
57
  @app.post("/generate")
58
  def generate(req: PromptRequest):
59
  if pipe is None:
60
  raise HTTPException(status_code=503, detail="Model not loaded. Check logs.")
61
 
62
  try:
63
+ output = pipe(req.prompt, max_new_tokens=req.max_new_tokens, do_sample=True, temperature=0.7)
 
 
 
64
  text = output[0].get("generated_text", "").strip()
65
 
66
  if text.startswith(req.prompt):
67
  text = text[len(req.prompt):].strip()
68
 
69
  if not text:
 
70
  text = "No response generated by the model."
71
 
72
  return {"response": text}
 
74
  except Exception as e:
75
  logging.exception("Generation failed for prompt '%s': %s", req.prompt, e)
76
  raise HTTPException(status_code=500, detail=f"Generation failed: {e}")
77
+