adityabalaji commited on
Commit
9753dd0
·
verified ·
1 Parent(s): aecf872

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -19
app.py CHANGED
@@ -1,23 +1,26 @@
1
- # app.py — EduPrompt FastAPI backend
2
 
3
  from fastapi import FastAPI
4
  from pydantic import BaseModel
5
  from transformers import pipeline
6
  import time
7
  import os
 
8
 
9
- # --- Hugging Face cache fix for Spaces (permission-safe) ---
10
  os.environ["HF_HOME"] = "/tmp"
11
  os.environ["HF_HUB_CACHE"] = "/tmp"
12
  os.environ["TRANSFORMERS_CACHE"] = "/tmp"
13
 
14
- # --- FastAPI app + CORS ---
 
 
15
  app = FastAPI(title="EduPrompt API")
16
 
17
  from fastapi.middleware.cors import CORSMiddleware
18
  app.add_middleware(
19
  CORSMiddleware,
20
- allow_origins=["*"],
21
  allow_methods=["*"],
22
  allow_headers=["*"],
23
  )
@@ -26,31 +29,67 @@ app.add_middleware(
26
  def health():
27
  return {"ok": True, "service": "eduprompt-api"}
28
 
29
- # --- Lazy model loaders ---
30
  _summarizer = None
31
  _rewriter = None
32
  _proofreader = None
33
  _code_explainer = None
34
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  def get_models():
36
  global _summarizer, _rewriter, _proofreader, _code_explainer
37
  if _summarizer is None:
38
- _summarizer = pipeline("summarization", model="t5-small")
39
  if _rewriter is None:
40
- _rewriter = pipeline("text2text-generation", model="google/flan-t5-small")
41
  if _proofreader is None:
42
- _proofreader = pipeline("text2text-generation", model="google/flan-t5-small")
43
  if _code_explainer is None:
44
- _code_explainer = pipeline("text2text-generation", model="Salesforce/codet5p-220m")
45
  return _summarizer, _rewriter, _proofreader, _code_explainer
46
 
47
- # --- Request schema ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  class InputData(BaseModel):
49
- task: str
50
- input: str
51
  params: dict | None = None
52
 
53
- # --- Core endpoint ---
54
  @app.post("/run")
55
  async def run_task(data: InputData):
56
  start = time.time()
@@ -62,28 +101,50 @@ async def run_task(data: InputData):
62
  if task not in {"summarize", "rewrite", "proofread", "explain_code"}:
63
  return {"error": f"Unsupported task '{task}'."}
64
 
65
- enhanced = ""
66
  try:
67
  summarizer, rewriter, proofreader, code_explainer = get_models()
 
 
68
 
 
 
69
  if task == "summarize":
70
  enhanced = f"You are an expert explainer. Summarize clearly and concisely:\n{text}"
71
- out = summarizer(enhanced, max_length=120, min_length=30, do_sample=False, truncation=True)[0]["summary_text"]
 
 
 
 
 
 
72
  model_id = "t5-small"
73
 
74
  elif task == "rewrite":
75
  enhanced = f"You are a writing assistant. Rewrite this text for clarity and tone:\n{text}"
76
- out = rewriter(enhanced, max_new_tokens=150, truncation=True)[0]["generated_text"]
 
 
 
 
77
  model_id = "google/flan-t5-small"
78
 
79
  elif task == "proofread":
80
  enhanced = f"You are a grammar and style editor. Correct and improve this text:\n{text}"
81
- out = proofreader(enhanced, max_new_tokens=150, truncation=True)[0]["generated_text"]
 
 
 
 
82
  model_id = "google/flan-t5-small"
83
 
84
- else:
85
  enhanced = f"You are a programming tutor. Explain what this code does in simple language:\n{text}"
86
- out = code_explainer(enhanced, max_new_tokens=200, truncation=True)[0]["generated_text"]
 
 
 
 
87
  model_id = "Salesforce/codet5p-220m"
88
 
89
  except Exception as e:
 
1
+ # app.py — EduPrompt FastAPI backend (HF Spaces-safe)
2
 
3
  from fastapi import FastAPI
4
  from pydantic import BaseModel
5
  from transformers import pipeline
6
  import time
7
  import os
8
+ import asyncio
9
 
10
# ---- Hugging Face cache: force writable dir on Spaces ----
# /tmp is the only reliably writable location on HF Spaces.
CACHE_DIR = "/tmp"  # single source of truth for every HF cache path

# NOTE(review): transformers reads these variables when it is first
# imported, and `from transformers import pipeline` appears earlier in
# this file — consider moving these assignments above that import so
# they actually take effect. TODO confirm against the import order.
os.environ["HF_HOME"] = CACHE_DIR
os.environ["HF_HUB_CACHE"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
16
+
17
# ---- FastAPI application + CORS ----
app = FastAPI(title="EduPrompt API")

from fastapi.middleware.cors import CORSMiddleware

# Wide-open CORS so any frontend can call the API while developing;
# restrict allow_origins to your real domain(s) in production.
_cors_options = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
 
29
def health():
    """Liveness check: confirm the API process is up and responding."""
    return dict(ok=True, service="eduprompt-api")
31
 
32
# ---- Lazy models (loaded on first use), with cache_dir enforced ----
# Module-level caches; each slot is populated at most once by get_models().
_summarizer = _rewriter = _proofreader = _code_explainer = None
37
 
38
def _make_pipeline(task: str, model_id: str):
    """Build a transformers pipeline whose downloads cache under CACHE_DIR.

    Retries exactly once when the failure looks like the known HF cache
    permission race on Spaces, after re-forcing the cache env vars.

    Args:
        task: transformers task name, e.g. "summarization".
        model_id: Hub model identifier, e.g. "t5-small".

    Returns:
        The constructed pipeline object.

    Raises:
        OSError: if loading still fails after the single retry, or the
            error is unrelated to the cache.
    """
    # cache_dir must travel inside model_kwargs: pipeline() forwards
    # model_kwargs to from_pretrained(); a bare cache_dir kwarg is not
    # part of that plumbing and can be silently ignored.
    kwargs = {"model_kwargs": {"cache_dir": CACHE_DIR}}
    try:
        return pipeline(task, model=model_id, **kwargs)
    except OSError as e:
        # PermissionError is an OSError subclass, so it lands here too.
        # (The old check `"PermissionError" in str(e)` could never match:
        # str(e) is the message, not the exception class name.)
        if isinstance(e, PermissionError) or "/.cache" in str(e):
            time.sleep(1.5)
            os.environ["HF_HOME"] = CACHE_DIR
            os.environ["HF_HUB_CACHE"] = CACHE_DIR
            os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
            return pipeline(task, model=model_id, **kwargs)
        raise
54
+
55
def get_models():
    """Return the four task pipelines, constructing each only on first use.

    Returns:
        Tuple (summarizer, rewriter, proofreader, code_explainer).
    """
    # (module-global name, pipeline task, Hub model id); creation order
    # matches the original per-name checks and is idempotent across calls.
    specs = (
        ("_summarizer", "summarization", "t5-small"),  # CPU-friendly
        ("_rewriter", "text2text-generation", "google/flan-t5-small"),
        ("_proofreader", "text2text-generation", "google/flan-t5-small"),
        ("_code_explainer", "text2text-generation", "Salesforce/codet5p-220m"),
    )
    for name, task, model_id in specs:
        if globals()[name] is None:
            globals()[name] = _make_pipeline(task, model_id)
    return _summarizer, _rewriter, _proofreader, _code_explainer
66
 
67
# (Optional) warmup hook — intentionally a no-op.
@app.on_event("startup")
async def _post_start_note():
    """Do nothing at startup so the container boots fast.

    Models are not downloaded here; they load lazily on the first /run
    call. To pre-warm them in the background instead, schedule:
    ``asyncio.create_task(_warm_once())``.
    """
    pass
74
+
75
async def _warm_once():
    """Best-effort pre-warm: push one tiny input through each pipeline.

    Triggers the model downloads ahead of real traffic. Any failure is
    ignored on purpose — real requests will load the models on demand.
    """
    try:
        summarizer, rewriter, proofreader, explainer = get_models()
        summarizer("warm up", max_length=10, min_length=5, do_sample=False)
        rewriter("rewrite: warm up", max_new_tokens=8)
        proofreader("proofread: warm up", max_new_tokens=8)
        explainer("explain: print(1)", max_new_tokens=12)
    except Exception:
        pass  # deliberate: warm-up is optional, never fatal
85
+
86
# ---- Request schema ----
class InputData(BaseModel):
    """Body of POST /run."""
    # One of: "summarize" | "rewrite" | "proofread" | "explain_code"
    task: str
    # The user-supplied text or source code to process
    input: str
    # Optional per-task options — TODO confirm how run_task consumes these
    params: dict | None = None
91
 
92
+ # ---- Core endpoint ----
93
  @app.post("/run")
94
  async def run_task(data: InputData):
95
  start = time.time()
 
101
  if task not in {"summarize", "rewrite", "proofread", "explain_code"}:
102
  return {"error": f"Unsupported task '{task}'."}
103
 
104
+ # Load models (lazy, cached to /tmp)
105
  try:
106
  summarizer, rewriter, proofreader, code_explainer = get_models()
107
+ except Exception as e:
108
+ return {"error": f"model_load_failed: {type(e).__name__}: {str(e)}"}
109
 
110
+ enhanced = ""
111
+ try:
112
  if task == "summarize":
113
  enhanced = f"You are an expert explainer. Summarize clearly and concisely:\n{text}"
114
+ out = summarizer(
115
+ enhanced,
116
+ max_length=120,
117
+ min_length=30,
118
+ do_sample=False,
119
+ truncation=True
120
+ )[0]["summary_text"]
121
  model_id = "t5-small"
122
 
123
  elif task == "rewrite":
124
  enhanced = f"You are a writing assistant. Rewrite this text for clarity and tone:\n{text}"
125
+ out = rewriter(
126
+ enhanced,
127
+ max_new_tokens=150,
128
+ truncation=True
129
+ )[0]["generated_text"]
130
  model_id = "google/flan-t5-small"
131
 
132
  elif task == "proofread":
133
  enhanced = f"You are a grammar and style editor. Correct and improve this text:\n{text}"
134
+ out = proofreader(
135
+ enhanced,
136
+ max_new_tokens=150,
137
+ truncation=True
138
+ )[0]["generated_text"]
139
  model_id = "google/flan-t5-small"
140
 
141
+ else: # explain_code
142
  enhanced = f"You are a programming tutor. Explain what this code does in simple language:\n{text}"
143
+ out = code_explainer(
144
+ enhanced,
145
+ max_new_tokens=200,
146
+ truncation=True
147
+ )[0]["generated_text"]
148
  model_id = "Salesforce/codet5p-220m"
149
 
150
  except Exception as e: