adityabalaji committed on
Commit
12e8c98
·
verified ·
1 Parent(s): 83abf71

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -57
app.py CHANGED
@@ -1,108 +1,119 @@
1
- # app.py — EduPrompt FastAPI backend (final HF Spaces fix)
2
 
3
  from fastapi import FastAPI
4
  from pydantic import BaseModel
5
- from transformers import pipeline
6
  from fastapi.middleware.cors import CORSMiddleware
 
7
  import os, time
8
 
9
- # ---- Fix for Hugging Face Spaces cache permissions ----
10
- CACHE_DIR = "/tmp"
11
- os.environ["HF_HOME"] = CACHE_DIR
12
- os.environ["HF_HUB_CACHE"] = CACHE_DIR
13
- os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
 
 
 
 
 
 
 
 
 
 
14
 
15
  app = FastAPI(title="EduPrompt API")
16
-
17
  app.add_middleware(
18
  CORSMiddleware,
19
- allow_origins=["*"],
20
  allow_methods=["*"],
21
  allow_headers=["*"],
22
  )
23
 
24
  @app.get("/")
25
  def health():
26
- return {"ok": True, "service": "eduprompt-api"}
27
-
28
- # Lazy models (loaded on first use)
 
 
 
 
 
 
 
29
  _summarizer = None
30
  _rewriter = None
31
  _proofreader = None
32
  _code_explainer = None
33
 
34
  def safe_pipeline(task: str, model_id: str):
35
- """Always download/cache models inside /tmp (Spaces-safe)."""
36
- try:
37
- return pipeline(task, model=model_id, cache_dir=CACHE_DIR)
38
- except OSError:
39
- # Force reset and retry once if cache issue
40
- time.sleep(1)
41
- os.environ["HF_HOME"] = CACHE_DIR
42
- os.environ["HF_HUB_CACHE"] = CACHE_DIR
43
- os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
44
- return pipeline(task, model=model_id, cache_dir=CACHE_DIR)
45
 
46
- def get_models():
 
47
  global _summarizer, _rewriter, _proofreader, _code_explainer
48
- if _summarizer is None:
49
- _summarizer = safe_pipeline("summarization", "t5-small")
50
- if _rewriter is None:
51
- _rewriter = safe_pipeline("text2text-generation", "google/flan-t5-small")
52
- if _proofreader is None:
53
- _proofreader = safe_pipeline("text2text-generation", "google/flan-t5-small")
54
- if _code_explainer is None:
55
- _code_explainer = safe_pipeline("text2text-generation", "Salesforce/codet5p-220m")
56
- return _summarizer, _rewriter, _proofreader, _code_explainer
 
 
 
 
 
 
 
 
57
 
58
  class InputData(BaseModel):
59
- task: str
60
  input: str
 
61
 
62
  @app.post("/run")
63
  async def run_task(data: InputData):
64
  start = time.time()
65
- task = data.task.strip().lower()
66
- text = data.input.strip()
67
  if not text:
68
  return {"error": "Empty input text."}
 
 
69
 
 
70
  try:
71
- summarizer, rewriter, proofreader, code_explainer = get_models()
72
  except Exception as e:
73
  return {"error": f"model_load_failed: {type(e).__name__}: {str(e)}"}
74
 
75
  try:
76
  if task == "summarize":
77
  prompt = f"You are an expert explainer. Summarize clearly and concisely:\n{text}"
78
- result = summarizer(prompt, max_length=120, min_length=30, truncation=True)[0]["summary_text"]
79
- model_used = "t5-small"
80
-
81
  elif task == "rewrite":
82
- prompt = f"Rewrite this text for clarity and tone:\n{text}"
83
- result = rewriter(prompt, max_new_tokens=150, truncation=True)[0]["generated_text"]
84
- model_used = "google/flan-t5-small"
85
-
86
  elif task == "proofread":
87
  prompt = f"Correct and improve grammar and style:\n{text}"
88
- result = proofreader(prompt, max_new_tokens=150, truncation=True)[0]["generated_text"]
89
- model_used = "google/flan-t5-small"
90
-
91
- elif task == "explain_code":
92
  prompt = f"Explain what this code does in simple language:\n{text}"
93
- result = code_explainer(prompt, max_new_tokens=200, truncation=True)[0]["generated_text"]
94
- model_used = "Salesforce/codet5p-220m"
95
-
96
- else:
97
- return {"error": f"Unsupported task '{task}'."}
98
-
99
  except Exception as e:
100
  return {"error": f"inference_failed: {type(e).__name__}: {str(e)}"}
101
 
102
- latency = round((time.time() - start) * 1000, 2)
103
  return {
104
  "enhancedPrompt": prompt,
105
- "output": result,
106
  "model": model_used,
107
- "latencyMs": latency
108
- }
 
1
+ # app.py — EduPrompt API (per-task lazy load + cache-safe on Spaces)
2
 
3
  from fastapi import FastAPI
4
  from pydantic import BaseModel
 
5
  from fastapi.middleware.cors import CORSMiddleware
6
+ from transformers import pipeline
7
  import os, time
8
 
9
# ---------- Force ALL caches to /tmp ----------
# HF Spaces containers only guarantee write access under /tmp, so every
# model/tokenizer cache must be redirected there before any library
# resolves its default cache location.
BASE = "/tmp"
_CACHE_DIRS = {
    "HF_HOME": f"{BASE}/hf",
    "HF_HUB_CACHE": f"{BASE}/hf",
    "HUGGINGFACE_HUB_CACHE": f"{BASE}/hf",
    "TRANSFORMERS_CACHE": f"{BASE}/hf/transformers",
    "XDG_CACHE_HOME": f"{BASE}/xdg",
    "TORCH_HOME": f"{BASE}/torch",
    "SENTENCEPIECE_CACHE": f"{BASE}/sp",
}
for _var, _path in _CACHE_DIRS.items():
    os.environ[_var] = _path
    os.makedirs(_path, exist_ok=True)
 
25
app = FastAPI(title="EduPrompt API")

# Allow browser front-ends on any origin to call this API (wide open on
# purpose for the demo deployment).
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # tighten in prod
    allow_methods=["*"],
    allow_headers=["*"],
)
32
 
33
@app.get("/")
def health():
    """Liveness probe that also reports whether /tmp is writable.

    Writing a small probe file proves the cache directories configured at
    startup are actually usable on this host.
    """
    probe_path = f"{BASE}/eduprompt_write_test.txt"
    try:
        with open(probe_path, "w") as fh:
            fh.write("ok")
    except Exception:
        writable = False
    else:
        writable = True
    return {"ok": True, "service": "eduprompt-api", "tmpWritable": writable}
43
+
44
# ---------- lazy singletons ----------
# Pipelines are created on first use (see get_model below) so startup stays
# fast and only the models actually requested get downloaded.
_summarizer = None
_rewriter = None
_proofreader = None
_code_explainer = None
49
 
50
def safe_pipeline(task: str, model_id: str):
    """Build a transformers pipeline with its cache isolated under /tmp.

    Each model gets its own subdirectory of TRANSFORMERS_CACHE (slashes in
    the model id replaced by underscores) so a corrupt or partial download
    for one model cannot break the others. Runs on CPU (device=-1).
    """
    model_cache = os.path.join(os.environ["TRANSFORMERS_CACHE"], model_id.replace("/", "_"))
    os.makedirs(model_cache, exist_ok=True)
    # NOTE(review): trust_remote_code=True executes Python shipped in the model
    # repo at load time; verify the models used here actually require it and
    # consider removing it for the stock t5/flan-t5 checkpoints.
    return pipeline(task, model=model_id, cache_dir=model_cache, trust_remote_code=True, device=-1)
 
 
 
 
 
 
55
 
56
def get_model(task: str):
    """Return ``(pipeline, model_id)`` for *task*, building the pipeline lazily.

    Only the model needed for the requested task is loaded; once built it is
    cached in a module-level singleton and reused on later calls.

    Raises:
        ValueError: if *task* is not one of the supported task names.
    """
    global _summarizer, _rewriter, _proofreader, _code_explainer
    # task -> (singleton global name, pipeline task, model id)
    registry = {
        "summarize": ("_summarizer", "summarization", "t5-small"),
        "rewrite": ("_rewriter", "text2text-generation", "google/flan-t5-small"),
        "proofread": ("_proofreader", "text2text-generation", "google/flan-t5-small"),
        "explain_code": ("_code_explainer", "text2text-generation", "Salesforce/codet5p-220m"),
    }
    if task not in registry:
        raise ValueError(f"Unsupported task '{task}'")
    slot, pipe_task, model_id = registry[task]
    if globals()[slot] is None:
        globals()[slot] = safe_pipeline(pipe_task, model_id)
    return globals()[slot], model_id
76
 
77
class InputData(BaseModel):
    # Request body for POST /run.
    task: str  # summarize | rewrite | proofread | explain_code
    input: str  # raw text (or code, for explain_code) to process
    params: dict | None = None  # reserved for per-task options; not read anywhere yet
81
 
82
@app.post("/run")
async def run_task(data: InputData):
    """Run the requested NLP task over ``data.input``.

    Returns a dict with the enhanced prompt, the model output, the model id
    used, and the request latency in milliseconds — or an ``{"error": ...}``
    dict when validation, model loading, or inference fails.
    """
    start = time.time()
    task = (data.task or "").strip().lower()
    text = (data.input or "").strip()
    if not text:
        return {"error": "Empty input text."}
    if task not in {"summarize", "rewrite", "proofread", "explain_code"}:
        return {"error": f"Unsupported task '{task}'."}

    # Load only what we need for this task.
    try:
        model, model_used = get_model(task)
    except Exception as e:
        return {"error": f"model_load_failed: {type(e).__name__}: {str(e)}"}

    # task -> (prompt prefix, generation kwargs, pipeline output key)
    recipes = {
        "summarize": (
            "You are an expert explainer. Summarize clearly and concisely:\n",
            {"max_length": 120, "min_length": 30, "truncation": True, "do_sample": False},
            "summary_text",
        ),
        "rewrite": (
            "You are a writing assistant. Rewrite this text for clarity and tone:\n",
            {"max_new_tokens": 150, "truncation": True},
            "generated_text",
        ),
        "proofread": (
            "Correct and improve grammar and style:\n",
            {"max_new_tokens": 150, "truncation": True},
            "generated_text",
        ),
        "explain_code": (
            "Explain what this code does in simple language:\n",
            {"max_new_tokens": 200, "truncation": True},
            "generated_text",
        ),
    }
    prefix, gen_kwargs, out_key = recipes[task]
    prompt = prefix + text

    try:
        output = model(prompt, **gen_kwargs)[0][out_key]
    except Exception as e:
        return {"error": f"inference_failed: {type(e).__name__}: {str(e)}"}

    return {
        "enhancedPrompt": prompt,
        "output": output,
        "model": model_used,
        "latencyMs": round((time.time() - start) * 1000, 2),
    }