rayymaxx committed on
Commit
3e2fd2f
·
1 Parent(s): 97cf393

Modified the response structure

Browse files
Files changed (2) hide show
  1. app.py +50 -7
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,9 +1,10 @@
1
- # app.py (simplified generate endpoint)
2
  import os
3
  import logging
4
  from fastapi import FastAPI, HTTPException
5
  from pydantic import BaseModel
6
  import tempfile
 
7
 
8
  # --- Use writable temp dir for Hugging Face caches ---
9
  TMP_CACHE = os.environ.get("HF_CACHE_DIR", os.path.join(tempfile.gettempdir(), "hf_cache"))
@@ -17,7 +18,7 @@ os.environ["HF_HOME"] = TMP_CACHE
17
  os.environ["HF_DATASETS_CACHE"] = TMP_CACHE
18
  os.environ["HF_METRICS_CACHE"] = TMP_CACHE
19
 
20
- app = FastAPI(title="DirectEd LoRA API (simplified)")
21
 
22
  @app.get("/health")
23
  def health():
@@ -25,15 +26,24 @@ def health():
25
 
26
  @app.get("/")
27
  def root():
28
- return {"Status": "AI backend is running"}
29
 
30
  class PromptRequest(BaseModel):
31
  prompt: str
32
 
 
 
 
 
 
 
 
 
33
  pipe = None
34
 
35
  @app.on_event("startup")
36
  def load_model():
 
37
  global pipe
38
  try:
39
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
@@ -53,19 +63,52 @@ def load_model():
53
  model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)
54
  model.eval()
55
 
56
- pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, device_map="auto")
 
 
 
 
 
57
  logging.info("Model and adapter loaded successfully.")
58
  except Exception as e:
59
  logging.exception("Failed to load model at startup: %s", e)
60
  pipe = None
61
 
62
- @app.post("/generate")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  def generate(req: PromptRequest):
 
64
  if pipe is None:
65
  raise HTTPException(status_code=503, detail="Model not loaded. Check logs.")
66
  try:
67
  output = pipe(req.prompt, max_new_tokens=150, do_sample=True)
68
- return {"response": output[0]["generated_text"]}
 
 
 
 
 
69
  except Exception as e:
70
  logging.exception("Generation failed: %s", e)
71
- raise HTTPException(status_code=500, detail=str(e))
 
1
+ # app.py (refined with clean metadata)
2
  import os
3
  import logging
4
  from fastapi import FastAPI, HTTPException
5
  from pydantic import BaseModel
6
  import tempfile
7
+ from typing import List, Dict
8
 
9
  # --- Use writable temp dir for Hugging Face caches ---
10
  TMP_CACHE = os.environ.get("HF_CACHE_DIR", os.path.join(tempfile.gettempdir(), "hf_cache"))
 
18
  os.environ["HF_DATASETS_CACHE"] = TMP_CACHE
19
  os.environ["HF_METRICS_CACHE"] = TMP_CACHE
20
 
21
+ app = FastAPI(title="DirectEd LoRA API with metadata")
22
 
23
  @app.get("/health")
24
  def health():
 
26
 
27
@app.get("/")
def root():
    """Report that the AI backend is up."""
    message = {"status": "AI backend is running"}
    return message
30
 
31
class PromptRequest(BaseModel):
    """Request body for the /generate endpoint."""

    # Free-form text passed straight to the text-generation pipeline.
    prompt: str
33
 
34
class Source(BaseModel):
    """A single cited reference attached to a generated answer."""

    # Display label for the reference.
    name: str
    # URL the reference points at.
    url: str
37
+
38
class ResponseWithMetadata(BaseModel):
    """Structured /generate response: cleaned answer text plus any sources."""

    # The generated answer with source URLs stripped out.
    answer: str
    # Zero or more references extracted from the raw model output.
    # (Pydantic copies this default per instance, so the shared-[] pitfall
    # does not apply here.)
    sources: List[Source] = []
41
+
42
  pipe = None
43
 
44
  @app.on_event("startup")
45
  def load_model():
46
+ """Load base + LoRA adapter model at startup."""
47
  global pipe
48
  try:
49
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
 
63
  model = PeftModel.from_pretrained(base_model, ADAPTER_REPO)
64
  model.eval()
65
 
66
+ pipe = pipeline(
67
+ "text-generation",
68
+ model=model,
69
+ tokenizer=tokenizer,
70
+ device_map="auto",
71
+ )
72
  logging.info("Model and adapter loaded successfully.")
73
  except Exception as e:
74
  logging.exception("Failed to load model at startup: %s", e)
75
  pipe = None
76
 
77
def parse_response(raw_text: str) -> ResponseWithMetadata:
    """Extract answer text and source URLs from raw model output.

    Finds every http(s) URL in *raw_text*, turns each unique URL into a
    ``Source`` entry (first-seen order preserved), and returns the text
    with the URLs removed as the answer.
    """
    # Fix: dropped the unused `from collections import OrderedDict` import;
    # deduplication is done with a plain set below.
    import re

    # Attempt to extract sources if present (looking for URLs).
    source_pattern = r"(https?://[^\s]+)"
    urls = re.findall(source_pattern, raw_text)

    # Deduplicate while keeping the order URLs first appeared in.
    seen = set()
    sources: List[Source] = []
    for url in urls:
        if url not in seen:
            seen.add(url)
            sources.append(Source(name="Reference", url=url))

    # Remove sources from the text to keep the answer clean.
    answer = re.sub(source_pattern, "", raw_text).strip()

    return ResponseWithMetadata(answer=answer, sources=sources)
+
99
@app.post("/generate", response_model=ResponseWithMetadata)
def generate(req: PromptRequest):
    """Generate a concise response with optional metadata.

    Raises 503 when the model pipeline is not loaded, 500 when generation
    fails or the model returns an empty string.
    """
    if pipe is None:
        raise HTTPException(status_code=503, detail="Model not loaded. Check logs.")
    try:
        output = pipe(req.prompt, max_new_tokens=150, do_sample=True)
        full_text = output[0].get("generated_text", "").strip()
        if not full_text:
            raise HTTPException(status_code=500, detail="Model returned empty response.")

        return parse_response(full_text)

    except HTTPException:
        # Fix: let deliberate HTTP errors (e.g. the empty-response 500 above)
        # propagate unchanged instead of being caught by the broad handler
        # below and re-wrapped as "Generation failed: 500: ...".
        raise
    except Exception as e:
        logging.exception("Generation failed: %s", e)
        raise HTTPException(status_code=500, detail=f"Generation failed: {e}")
requirements.txt CHANGED
@@ -5,4 +5,5 @@ accelerate
5
  bitsandbytes
6
  fastapi
7
  uvicorn
8
- bitsandbytes
 
 
5
  bitsandbytes
6
  fastapi
7
  uvicorn
8
+ bitsandbytes
9
+ requests