adityabalaji committed
Commit c884159 · verified · 1 Parent(s): 82fc3eb

Update app.py

Files changed (1):
  1. app.py +15 -43
app.py CHANGED
@@ -1,14 +1,5 @@
-# app.py — EduPrompt API (final: per-task load, Spaces-safe caches, smart retries)
-
-import os, time
-from fastapi import FastAPI
-from pydantic import BaseModel
-from fastapi.middleware.cors import CORSMiddleware
-from transformers import pipeline
-
-# =========================
-# Hard-force ALL caches to /tmp (writable on Spaces)
-# =========================
+import os
+os.environ["HOME"] = "/tmp"
 BASE = "/tmp"
 os.environ["HF_HOME"] = f"{BASE}/hf"
 os.environ["HF_HUB_CACHE"] = f"{BASE}/hf"
@@ -28,20 +19,22 @@ for d in (
 ):
     os.makedirs(d, exist_ok=True)
 
-# =========================
-# FastAPI app + CORS
-# =========================
+import time
+from fastapi import FastAPI
+from pydantic import BaseModel
+from fastapi.middleware.cors import CORSMiddleware
+from transformers import pipeline
+
 app = FastAPI(title="EduPrompt API")
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["*"],  # tighten in prod
+    allow_origins=["*"],
     allow_methods=["*"],
     allow_headers=["*"],
 )
 
 @app.get("/")
 def health():
-    # prove /tmp is writable and show cache path
     writable = True
     try:
         with open(f"{BASE}/eduprompt_write_test.txt", "w") as f:
@@ -53,45 +46,33 @@ def health():
         "service": "eduprompt-api",
         "tmpWritable": writable,
         "TRANSFORMERS_CACHE": os.environ["TRANSFORMERS_CACHE"],
+        "HOME": os.environ["HOME"],
     }
 
-# =========================
-# Lazy singletons (loaded per task)
-# =========================
 _summarizer = None
 _rewriter = None
 _proofreader = None
 _code_explainer = None
 
 def _model_cache_dir(model_id: str) -> str:
-    # each model gets its own directory to avoid lock fights
     p = os.path.join(os.environ["TRANSFORMERS_CACHE"], model_id.replace("/", "_"))
     os.makedirs(p, exist_ok=True)
     return p
 
 def safe_pipeline(task: str, model_id: str):
-    """
-    Build a pipeline that caches to /tmp per model.
-    Some pipelines reject 'cache_dir' -> retry without it.
-    Also handles rare permission/lock races by a short retry.
-    """
     cache_dir = _model_cache_dir(model_id)
     print(f"[init] task={task} model={model_id} cache={cache_dir}")
-    # Try with cache_dir
     try:
         return pipeline(task, model=model_id, cache_dir=cache_dir,
                         trust_remote_code=True, device=-1)
     except ValueError as e:
-        # Some models complain: "model_kwargs not used: ['cache_dir']"
         if "cache_dir" in str(e):
             print(f"[init] {model_id} rejects cache_dir, retrying without it")
             return pipeline(task, model=model_id, trust_remote_code=True, device=-1)
         raise
     except OSError as e:
-        # Permission/lock race — wait and retry once
         print(f"[init] OSError on {model_id}: {e}; retrying once")
         time.sleep(1.5)
-        # Re-assert env (some libs re-read)
         os.environ["HF_HOME"] = f"{BASE}/hf"
         os.environ["HF_HUB_CACHE"] = f"{BASE}/hf"
         os.environ["TRANSFORMERS_CACHE"] = f"{BASE}/hf/transformers"
@@ -107,9 +88,6 @@ def safe_pipeline(task: str, model_id: str):
     raise
 
 def get_model(task: str):
-    """
-    Load ONLY the model needed for this task.
-    """
     global _summarizer, _rewriter, _proofreader, _code_explainer
     if task == "summarize":
         if _summarizer is None:
@@ -129,24 +107,20 @@ def get_model(task: str):
         return _code_explainer, "Salesforce/codet5p-220m"
     raise ValueError(f"Unsupported task '{task}'")
 
-# =========================
-# Request schema
-# =========================
 class InputData(BaseModel):
-    task: str  # summarize | rewrite | proofread | explain_code
+    task: str
     input: str
     params: dict | None = None
 
 def _clean_params(params: dict | None):
-    # Block params that some pipelines reject in generate/forward
     forbidden = {"cache_dir"}
     return {k: v for k, v in (params or {}).items() if k not in forbidden}
 
-# =========================
-# Core endpoint
-# =========================
 @app.post("/run")
 async def run_task(data: InputData):
+    print("TRANSFORMERS_CACHE:", os.environ.get("TRANSFORMERS_CACHE"))
+    print("HOME:", os.environ.get("HOME"))
+    print("Current user:", os.getuid() if hasattr(os, "getuid") else "unknown")
     t0 = time.time()
     task = (data.task or "").strip().lower()
     text = (data.input or "").strip()
@@ -156,7 +130,6 @@ async def run_task(data: InputData):
     if task not in {"summarize", "rewrite", "proofread", "explain_code"}:
         return {"error": f"Unsupported task '{task}'."}
 
-    # load only what we need
     try:
         model, model_used = get_model(task)
     except Exception as e:
@@ -183,7 +156,6 @@ async def run_task(data: InputData):
         out = model(prompt, max_new_tokens=200, truncation=True, **params)[0]["generated_text"]
 
     except Exception as e:
-        # print full stack to logs for debugging; return friendly message to client
         import traceback
         print(traceback.format_exc())
         return {"error": f"inference_failed: {type(e).__name__}: {str(e)}"}
@@ -193,4 +165,4 @@ async def run_task(data: InputData):
         "output": out,
         "model": model_used,
         "latencyMs": round((time.time() - t0) * 1000, 2),
-    }
+    }
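Note on the change: the whole point of this commit is ordering. HOME and the Hugging Face cache variables are now set before transformers is imported, where the old revision imported it first. huggingface_hub resolves its default cache paths from HOME/HF_HOME when the module is first imported, so variables exported afterwards can be silently ignored and the library falls back to the unwritable /.cache. A minimal sketch of the resulting preamble; the directory list of the "for d in (...)" loop is elided in the diff, so the one below is a placeholder:

# Sketch only: reconstructs the intended ordering, not the full file.
import os

# 1) Point HOME and every HF cache at a writable location *first*.
os.environ["HOME"] = "/tmp"
BASE = "/tmp"
os.environ["HF_HOME"] = f"{BASE}/hf"
os.environ["HF_HUB_CACHE"] = f"{BASE}/hf"
os.environ["TRANSFORMERS_CACHE"] = f"{BASE}/hf/transformers"  # value taken from the retry block above

# 2) Pre-create the cache directories (placeholder subset; the real list is elided in the diff).
for d in (f"{BASE}/hf", f"{BASE}/hf/transformers"):
    os.makedirs(d, exist_ok=True)

# 3) Only now import libraries that read those variables at import time.
from transformers import pipeline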
 
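For completeness, a client call against the updated endpoint would look roughly like this. The URL is a placeholder (a Space exposes its own host), and the example text is invented:

# Hypothetical client for POST /run; URL and sample input are placeholders.
import requests

resp = requests.post(
    "http://localhost:7860/run",  # replace with the actual Space URL
    json={
        "task": "summarize",  # one of: summarize | rewrite | proofread | explain_code
        "input": "FastAPI is a modern, fast web framework for building APIs with Python.",
        # "params" is optional; run_task already fixes max_new_tokens=200,
        # and _clean_params strips "cache_dir" before kwargs reach the pipeline.
    },
)
body = resp.json()
print(body.get("output") or body.get("error"), body.get("model"), body.get("latencyMs"))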
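The health route is the quickest way to confirm the rewiring took effect, since this commit adds HOME to its payload. A smoke test under the same placeholder URL; the handler may return additional keys that the hunks elide:

# Hypothetical smoke test for GET /; checks only the fields visible in the diff.
import requests

health = requests.get("http://localhost:7860/").json()
assert health["service"] == "eduprompt-api"
assert health["tmpWritable"] is True  # /tmp write test performed by health()
print(health["TRANSFORMERS_CACHE"])   # should be /tmp/hf/transformers
print(health["HOME"])                 # should be /tmp after this commit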