rayymaxx committed on
Commit
7c89c4e
·
1 Parent(s): 3e2fd2f

Updated app

Browse files
Files changed (1) hide show
  1. app.py +37 -52
app.py CHANGED
@@ -1,10 +1,9 @@
1
- # app.py (refined with clean metadata)
2
  import os
3
  import logging
 
4
  from fastapi import FastAPI, HTTPException
5
  from pydantic import BaseModel
6
- import tempfile
7
- from typing import List, Dict
8
 
9
  # --- Use writable temp dir for Hugging Face caches ---
10
  TMP_CACHE = os.environ.get("HF_CACHE_DIR", os.path.join(tempfile.gettempdir(), "hf_cache"))
@@ -18,8 +17,17 @@ os.environ["HF_HOME"] = TMP_CACHE
18
  os.environ["HF_DATASETS_CACHE"] = TMP_CACHE
19
  os.environ["HF_METRICS_CACHE"] = TMP_CACHE
20
 
21
- app = FastAPI(title="DirectEd LoRA API with metadata")
 
 
 
 
 
 
22
 
 
 
 
23
  @app.get("/health")
24
  def health():
25
  return {"ok": True}
@@ -28,22 +36,13 @@ def health():
28
  def root():
29
  return {"status": "AI backend is running"}
30
 
31
- class PromptRequest(BaseModel):
32
- prompt: str
33
-
34
- class Source(BaseModel):
35
- name: str
36
- url: str
37
-
38
- class ResponseWithMetadata(BaseModel):
39
- answer: str
40
- sources: List[Source] = []
41
-
42
  pipe = None
43
 
44
  @app.on_event("startup")
45
  def load_model():
46
- """Load base + LoRA adapter model at startup."""
47
  global pipe
48
  try:
49
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
@@ -63,52 +62,38 @@ def load_model():
63
  model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)
64
  model.eval()
65
 
66
- pipe = pipeline(
67
- "text-generation",
68
- model=model,
69
- tokenizer=tokenizer,
70
- device_map="auto",
71
- )
72
  logging.info("Model and adapter loaded successfully.")
 
73
  except Exception as e:
74
  logging.exception("Failed to load model at startup: %s", e)
75
  pipe = None
76
 
77
- def parse_response(raw_text: str) -> ResponseWithMetadata:
78
- """Extract answer and sources from raw model output."""
79
- import re
80
- from collections import OrderedDict
81
-
82
- # Attempt to extract sources if present (looking for URLs)
83
- source_pattern = r"(https?://[^\s]+)"
84
- urls = re.findall(source_pattern, raw_text)
85
-
86
- # Deduplicate and create simple source list
87
- seen = set()
88
- sources: List[Source] = []
89
- for url in urls:
90
- if url not in seen:
91
- seen.add(url)
92
- sources.append(Source(name="Reference", url=url))
93
-
94
- # Remove sources from the text to keep answer clean
95
- answer = re.sub(source_pattern, "", raw_text).strip()
96
-
97
- return ResponseWithMetadata(answer=answer, sources=sources)
98
-
99
- @app.post("/generate", response_model=ResponseWithMetadata)
100
  def generate(req: PromptRequest):
101
- """Generate a concise response with optional metadata."""
102
  if pipe is None:
103
  raise HTTPException(status_code=503, detail="Model not loaded. Check logs.")
 
104
  try:
105
- output = pipe(req.prompt, max_new_tokens=150, do_sample=True)
106
- full_text = output[0].get("generated_text", "").strip()
107
- if not full_text:
108
- raise HTTPException(status_code=500, detail="Model returned empty response.")
 
 
 
 
 
 
 
 
 
109
 
110
- return parse_response(full_text)
111
 
112
  except Exception as e:
113
- logging.exception("Generation failed: %s", e)
114
  raise HTTPException(status_code=500, detail=f"Generation failed: {e}")
 
1
+ # app.py
2
  import os
3
  import logging
4
+ import tempfile
5
  from fastapi import FastAPI, HTTPException
6
  from pydantic import BaseModel
 
 
7
 
8
  # --- Use writable temp dir for Hugging Face caches ---
9
  TMP_CACHE = os.environ.get("HF_CACHE_DIR", os.path.join(tempfile.gettempdir(), "hf_cache"))
 
17
  os.environ["HF_DATASETS_CACHE"] = TMP_CACHE
18
  os.environ["HF_METRICS_CACHE"] = TMP_CACHE
19
 
20
+ app = FastAPI(title="DirectEd LoRA API (concise)")
21
+
22
# ---------------------
# Request Model
# ---------------------
class PromptRequest(BaseModel):
    """Request body for the /generate endpoint.

    Attributes:
        prompt: the raw user text to feed to the text-generation pipeline.
    """

    prompt: str
27
 
28
# ---------------------
# Health & Root
# ---------------------
@app.get("/health")
def health():
    """Liveness probe: reports the process as up without touching the model."""
    return dict(ok=True)
 
36
def root():
    """Root endpoint: confirm the AI backend process is running."""
    status_payload = {"status": "AI backend is running"}
    return status_payload
38
 
39
+ # ---------------------
40
+ # Load Model on Startup
41
+ # ---------------------
 
 
 
 
 
 
 
 
42
  pipe = None
43
 
44
  @app.on_event("startup")
45
  def load_model():
 
46
  global pipe
47
  try:
48
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
62
  model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)
63
  model.eval()
64
 
65
+ pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
 
 
 
 
 
66
  logging.info("Model and adapter loaded successfully.")
67
+
68
  except Exception as e:
69
  logging.exception("Failed to load model at startup: %s", e)
70
  pipe = None
71
 
72
# ---------------------
# Generate Endpoint
# ---------------------
@app.post("/generate")
def generate(req: PromptRequest):
    """Run the loaded text-generation pipeline on the request prompt.

    Returns a {"response": text} payload. Raises HTTP 503 when the model
    failed to load at startup, and HTTP 500 when generation itself fails.
    """
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded. Check logs.")

    try:
        # Cap new tokens so a runaway sample cannot produce a huge payload.
        result = pipe(req.prompt, max_new_tokens=200, do_sample=True)
        generated = result[0].get("generated_text", "").strip()

        # text-generation pipelines echo the prompt; strip it off when present.
        prompt_text = req.prompt
        if generated.startswith(prompt_text):
            generated = generated[len(prompt_text):].strip()

        # Degrade gracefully to a fixed message instead of erroring on empty output.
        if not generated:
            logging.warning("Model returned empty response for prompt: %s", req.prompt)
            generated = "No response generated by the model."

        return {"response": generated}

    except Exception as e:
        logging.exception("Generation failed for prompt '%s': %s", req.prompt, e)
        raise HTTPException(status_code=500, detail=f"Generation failed: {e}")