adityabalaji commited on
Commit
82fc3eb
·
verified ·
1 Parent(s): 7aa47a6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +97 -40
app.py CHANGED
@@ -1,8 +1,14 @@
1
- # app.py — EduPrompt API (per-task lazy load, cache-safe, no cache_dir in inference)
2
 
3
- import os
 
 
 
 
4
 
5
- # ---------- Force ALL caches to /tmp ----------
 
 
6
  BASE = "/tmp"
7
  os.environ["HF_HOME"] = f"{BASE}/hf"
8
  os.environ["HF_HUB_CACHE"] = f"{BASE}/hf"
@@ -11,58 +17,99 @@ os.environ["TRANSFORMERS_CACHE"] = f"{BASE}/hf/transformers"
11
  os.environ["XDG_CACHE_HOME"] = f"{BASE}/xdg"
12
  os.environ["TORCH_HOME"] = f"{BASE}/torch"
13
  os.environ["SENTENCEPIECE_CACHE"] = f"{BASE}/sp"
14
- for d in [
15
- os.environ["HF_HOME"], os.environ["HF_HUB_CACHE"], os.environ["HUGGINGFACE_HUB_CACHE"],
16
- os.environ["TRANSFORMERS_CACHE"], os.environ["XDG_CACHE_HOME"],
17
- os.environ["TORCH_HOME"], os.environ["SENTENCEPIECE_CACHE"]
18
- ]:
 
 
 
 
19
  os.makedirs(d, exist_ok=True)
20
 
21
- import time
22
- from fastapi import FastAPI
23
- from pydantic import BaseModel
24
- from fastapi.middleware.cors import CORSMiddleware
25
- from transformers import pipeline
26
-
27
  app = FastAPI(title="EduPrompt API")
28
  app.add_middleware(
29
  CORSMiddleware,
30
- allow_origins=["*"], # tighten in prod
31
  allow_methods=["*"],
32
  allow_headers=["*"],
33
  )
34
 
35
  @app.get("/")
36
  def health():
37
- # prove /tmp is writable
 
38
  try:
39
  with open(f"{BASE}/eduprompt_write_test.txt", "w") as f:
40
  f.write("ok")
41
- writable = True
42
  except Exception:
43
  writable = False
44
  return {
45
  "ok": True,
46
  "service": "eduprompt-api",
47
  "tmpWritable": writable,
48
- "TRANSFORMERS_CACHE": os.environ["TRANSFORMERS_CACHE"]
49
  }
50
 
51
- # ---------- lazy singletons ----------
 
 
52
  _summarizer = None
53
  _rewriter = None
54
  _proofreader = None
55
  _code_explainer = None
56
 
 
 
 
 
 
 
57
  def safe_pipeline(task: str, model_id: str):
58
- """Cache every model in its own /tmp subdir. CPU-only."""
59
- model_cache = os.path.join(os.environ["TRANSFORMERS_CACHE"], model_id.replace("/", "_"))
60
- os.makedirs(model_cache, exist_ok=True)
61
- print(f"Loading model '{model_id}' for task '{task}' into cache dir: {model_cache}")
62
- return pipeline(task, model=model_id, cache_dir=model_cache, trust_remote_code=True, device=-1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
  def get_model(task: str):
65
- """Load ONLY the model needed for this task."""
 
 
66
  global _summarizer, _rewriter, _proofreader, _code_explainer
67
  if task == "summarize":
68
  if _summarizer is None:
@@ -82,58 +129,68 @@ def get_model(task: str):
82
  return _code_explainer, "Salesforce/codet5p-220m"
83
  raise ValueError(f"Unsupported task '{task}'")
84
 
 
 
 
85
  class InputData(BaseModel):
86
  task: str # summarize | rewrite | proofread | explain_code
87
  input: str
88
  params: dict | None = None
89
 
90
- def filter_model_kwargs(params):
91
- # Remove keys not accepted by model.__call__()
92
  forbidden = {"cache_dir"}
93
  return {k: v for k, v in (params or {}).items() if k not in forbidden}
94
 
 
 
 
95
  @app.post("/run")
96
  async def run_task(data: InputData):
97
- start = time.time()
98
  task = (data.task or "").strip().lower()
99
  text = (data.input or "").strip()
 
100
  if not text:
101
  return {"error": "Empty input text."}
102
  if task not in {"summarize", "rewrite", "proofread", "explain_code"}:
103
  return {"error": f"Unsupported task '{task}'."}
104
 
105
- # Load only what we need
106
  try:
107
  model, model_used = get_model(task)
108
  except Exception as e:
109
- import traceback
110
- print(traceback.format_exc())
111
  return {"error": f"model_load_failed: {type(e).__name__}: {str(e)}"}
112
 
113
- # Filter out forbidden kwargs
114
- params = filter_model_kwargs(data.params)
115
 
116
  try:
117
  if task == "summarize":
118
  prompt = f"You are an expert explainer. Summarize clearly and concisely:\n{text}"
119
- output = model(prompt, max_length=120, min_length=30, truncation=True, do_sample=False, **params)[0]["summary_text"]
 
 
120
  elif task == "rewrite":
121
  prompt = f"You are a writing assistant. Rewrite this text for clarity and tone:\n{text}"
122
- output = model(prompt, max_new_tokens=150, truncation=True, **params)[0]["generated_text"]
 
123
  elif task == "proofread":
124
  prompt = f"Correct and improve grammar and style:\n{text}"
125
- output = model(prompt, max_new_tokens=150, truncation=True, **params)[0]["generated_text"]
 
126
  else: # explain_code
127
  prompt = f"Explain what this code does in simple language:\n{text}"
128
- output = model(prompt, max_new_tokens=200, truncation=True, **params)[0]["generated_text"]
 
129
  except Exception as e:
 
130
  import traceback
131
  print(traceback.format_exc())
132
  return {"error": f"inference_failed: {type(e).__name__}: {str(e)}"}
133
 
134
  return {
135
  "enhancedPrompt": prompt,
136
- "output": output,
137
  "model": model_used,
138
- "latencyMs": round((time.time() - start) * 1000, 2),
139
- }
 
# app.py — EduPrompt API (final: per-task load, Spaces-safe caches, smart retries)

import os
import time

# =========================
# Hard-force ALL caches to /tmp (writable on Spaces)
# =========================
# NOTE: this block MUST run before importing `transformers`, which reads
# TRANSFORMERS_CACHE / HF_HOME at import time. Setting the variables after
# the import (as before) leaves the library pointed at a read-only default
# cache on Spaces.
BASE = "/tmp"
os.environ["HF_HOME"] = f"{BASE}/hf"
os.environ["HF_HUB_CACHE"] = f"{BASE}/hf"
os.environ["HUGGINGFACE_HUB_CACHE"] = f"{BASE}/hf"
os.environ["TRANSFORMERS_CACHE"] = f"{BASE}/hf/transformers"
os.environ["XDG_CACHE_HOME"] = f"{BASE}/xdg"
os.environ["TORCH_HOME"] = f"{BASE}/torch"
os.environ["SENTENCEPIECE_CACHE"] = f"{BASE}/sp"

# Pre-create every cache directory so later library writes cannot fail.
for d in (
    os.environ["HF_HOME"],
    os.environ["HF_HUB_CACHE"],
    os.environ["HUGGINGFACE_HUB_CACHE"],
    os.environ["TRANSFORMERS_CACHE"],
    os.environ["XDG_CACHE_HOME"],
    os.environ["TORCH_HOME"],
    os.environ["SENTENCEPIECE_CACHE"],
):
    os.makedirs(d, exist_ok=True)

# Imports that consume the cache env vars come AFTER the setup above.
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline
30
 
# =========================
# FastAPI app + CORS
# =========================
app = FastAPI(title="EduPrompt API")

# Open CORS so browser clients on any host can call the API.
# NOTE(review): "*" origins should be tightened for production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # tighten in prod
    allow_methods=["*"],
    allow_headers=["*"],
)
41
 
@app.get("/")
def health():
    """Liveness probe: report service identity, whether /tmp accepts
    writes, and the transformers cache path currently in effect."""
    # Prove /tmp is writable by dropping a tiny marker file.
    probe_path = f"{BASE}/eduprompt_write_test.txt"
    writable = True
    try:
        with open(probe_path, "w") as probe:
            probe.write("ok")
    except Exception:
        writable = False
    return {
        "ok": True,
        "service": "eduprompt-api",
        "tmpWritable": writable,
        "TRANSFORMERS_CACHE": os.environ["TRANSFORMERS_CACHE"],
    }
57
 
# =========================
# Lazy singletons (loaded per task)
# =========================
_summarizer = None
_rewriter = None
_proofreader = None
_code_explainer = None

def _model_cache_dir(model_id: str) -> str:
    """Return (creating if needed) a cache directory dedicated to one model.

    Each model gets its own subdirectory of TRANSFORMERS_CACHE so that two
    downloads never fight over the same lock files.
    """
    safe_name = model_id.replace("/", "_")
    path = os.path.join(os.environ["TRANSFORMERS_CACHE"], safe_name)
    os.makedirs(path, exist_ok=True)
    return path
+
def safe_pipeline(task: str, model_id: str):
    """Build a CPU-only transformers pipeline cached under /tmp per model.

    Two failure modes are handled:
      * Some pipelines reject the ``cache_dir`` kwarg
        ("model_kwargs not used: ['cache_dir']") — retry without it.
      * A rare permission/lock race raises OSError while two workers touch
        the same cache — wait briefly and retry once.

    Args:
        task: transformers pipeline task name (e.g. "summarization").
        model_id: Hub model identifier (e.g. "Salesforce/codet5p-220m").

    Returns:
        The constructed pipeline object.

    Raises:
        Whatever the final pipeline attempt raises once the fallbacks are
        exhausted.
    """
    cache_dir = _model_cache_dir(model_id)
    print(f"[init] task={task} model={model_id} cache={cache_dir}")

    def _build(with_cache: bool):
        # Single attempt; transparently drops cache_dir when the model
        # rejects it, instead of duplicating the retry logic twice.
        try:
            if with_cache:
                return pipeline(task, model=model_id, cache_dir=cache_dir,
                                trust_remote_code=True, device=-1)
            return pipeline(task, model=model_id,
                            trust_remote_code=True, device=-1)
        except ValueError as e:
            if with_cache and "cache_dir" in str(e):
                print(f"[init] {model_id} rejects cache_dir, retrying without it")
                return _build(False)
            raise

    try:
        return _build(True)
    except OSError as e:
        # Permission/lock race — wait and retry once.
        print(f"[init] OSError on {model_id}: {e}; retrying once")
        time.sleep(1.5)
        # Re-assert env (some libs re-read these lazily).
        os.environ["HF_HOME"] = f"{BASE}/hf"
        os.environ["HF_HUB_CACHE"] = f"{BASE}/hf"
        os.environ["TRANSFORMERS_CACHE"] = f"{BASE}/hf/transformers"
        return _build(True)
108
 
109
  def get_model(task: str):
110
+ """
111
+ Load ONLY the model needed for this task.
112
+ """
113
  global _summarizer, _rewriter, _proofreader, _code_explainer
114
  if task == "summarize":
115
  if _summarizer is None:
 
129
  return _code_explainer, "Salesforce/codet5p-220m"
130
  raise ValueError(f"Unsupported task '{task}'")
131
 
# =========================
# Request schema
# =========================
class InputData(BaseModel):
    """Body of POST /run."""
    task: str  # one of: summarize | rewrite | proofread | explain_code
    input: str  # the raw text (or code) to process
    params: dict | None = None  # optional extra kwargs forwarded to the model call
139
 
def _clean_params(params: dict | None):
    """Return a copy of *params* stripped of kwargs some pipelines reject.

    Currently only ``cache_dir`` is blocked; ``None`` yields an empty dict.
    """
    forbidden = {"cache_dir"}
    if not params:
        return {}
    return {key: value for key, value in params.items() if key not in forbidden}
144
 
# =========================
# Core endpoint
# =========================
@app.post("/run")
async def run_task(data: InputData):
    """Run one NLP task over the submitted text.

    Validates the request, lazily loads only the model the task needs,
    filters client params the pipelines would reject, and returns the
    enhanced prompt, model output, model id and latency. Failures are
    reported as ``{"error": ...}`` payloads rather than HTTP errors.
    """
    started = time.time()
    task = (data.task or "").strip().lower()
    text = (data.input or "").strip()

    if not text:
        return {"error": "Empty input text."}
    if task not in {"summarize", "rewrite", "proofread", "explain_code"}:
        return {"error": f"Unsupported task '{task}'."}

    # Load only the model this task needs.
    try:
        model, model_used = get_model(task)
    except Exception as e:
        return {"error": f"model_load_failed: {type(e).__name__}: {str(e)}"}

    params = _clean_params(data.params)

    # Per-task prompt header, generation kwargs and pipeline result key.
    specs = {
        "summarize": (
            "You are an expert explainer. Summarize clearly and concisely:",
            {"max_length": 120, "min_length": 30, "truncation": True, "do_sample": False},
            "summary_text",
        ),
        "rewrite": (
            "You are a writing assistant. Rewrite this text for clarity and tone:",
            {"max_new_tokens": 150, "truncation": True},
            "generated_text",
        ),
        "proofread": (
            "Correct and improve grammar and style:",
            {"max_new_tokens": 150, "truncation": True},
            "generated_text",
        ),
        "explain_code": (
            "Explain what this code does in simple language:",
            {"max_new_tokens": 200, "truncation": True},
            "generated_text",
        ),
    }
    header, gen_kwargs, result_key = specs[task]
    prompt = f"{header}\n{text}"

    try:
        out = model(prompt, **gen_kwargs, **params)[0][result_key]
    except Exception as e:
        # Full stack goes to the logs; the client gets a friendly message.
        import traceback
        print(traceback.format_exc())
        return {"error": f"inference_failed: {type(e).__name__}: {str(e)}"}

    return {
        "enhancedPrompt": prompt,
        "output": out,
        "model": model_used,
        "latencyMs": round((time.time() - started) * 1000, 2),
    }