adityabalaji committed on
Commit
83abf71
·
verified ·
1 Parent(s): 9753dd0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -98
app.py CHANGED
@@ -1,26 +1,22 @@
1
- # app.py — EduPrompt FastAPI backend (HF Spaces-safe)
2
 
3
  from fastapi import FastAPI
4
  from pydantic import BaseModel
5
  from transformers import pipeline
6
- import time
7
- import os
8
- import asyncio
9
-
10
- # ---- Hugging Face cache: force writable dir on Spaces ----
11
- os.environ["HF_HOME"] = "/tmp"
12
- os.environ["HF_HUB_CACHE"] = "/tmp"
13
- os.environ["TRANSFORMERS_CACHE"] = "/tmp"
14
 
15
- CACHE_DIR = "/tmp" # single source of truth
 
 
 
 
16
 
17
- # ---- FastAPI + CORS ----
18
  app = FastAPI(title="EduPrompt API")
19
 
20
- from fastapi.middleware.cors import CORSMiddleware
21
  app.add_middleware(
22
  CORSMiddleware,
23
- allow_origins=["*"], # set your domain(s) in production
24
  allow_methods=["*"],
25
  allow_headers=["*"],
26
  )
@@ -29,131 +25,84 @@ app.add_middleware(
29
  def health():
30
  return {"ok": True, "service": "eduprompt-api"}
31
 
32
- # ---- Lazy models (loaded on first use), with cache_dir enforced ----
33
  _summarizer = None
34
  _rewriter = None
35
  _proofreader = None
36
  _code_explainer = None
37
 
38
- def _make_pipeline(task: str, model_id: str):
39
- """
40
- Create a HF pipeline that always caches to /tmp (writable on Spaces).
41
- Retries once on cache-related OSError.
42
- """
43
  try:
44
  return pipeline(task, model=model_id, cache_dir=CACHE_DIR)
45
- except OSError as e:
46
- # Rare HF cache race; wait briefly and retry once
47
- if "/.cache" in str(e) or "PermissionError" in str(e):
48
- time.sleep(1.5)
49
- os.environ["HF_HOME"] = CACHE_DIR
50
- os.environ["HF_HUB_CACHE"] = CACHE_DIR
51
- os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
52
- return pipeline(task, model=model_id, cache_dir=CACHE_DIR)
53
- raise
54
 
55
  def get_models():
56
  global _summarizer, _rewriter, _proofreader, _code_explainer
57
  if _summarizer is None:
58
- _summarizer = _make_pipeline("summarization", "t5-small") # CPU-friendly
59
  if _rewriter is None:
60
- _rewriter = _make_pipeline("text2text-generation", "google/flan-t5-small")
61
  if _proofreader is None:
62
- _proofreader = _make_pipeline("text2text-generation", "google/flan-t5-small")
63
  if _code_explainer is None:
64
- _code_explainer = _make_pipeline("text2text-generation", "Salesforce/codet5p-220m")
65
  return _summarizer, _rewriter, _proofreader, _code_explainer
66
 
67
- # (Optional) tiny warmup to trigger downloads after first request to /run
68
- @app.on_event("startup")
69
- async def _post_start_note():
70
- # We don't download at startup to keep boot fast; models load on first call.
71
- # Leaving this here in case you ever want to warm them:
72
- # asyncio.create_task(_warm_once())
73
- pass
74
-
75
- async def _warm_once():
76
- try:
77
- s, r, p, c = get_models()
78
- _ = s("warm up", max_length=10, min_length=5, do_sample=False)
79
- _ = r("rewrite: warm up", max_new_tokens=8)
80
- _ = p("proofread: warm up", max_new_tokens=8)
81
- _ = c("explain: print(1)", max_new_tokens=12)
82
- except Exception:
83
- # Ignore warm errors; real requests will still retry/load.
84
- pass
85
-
86
- # ---- Request schema ----
87
  class InputData(BaseModel):
88
- task: str # summarize | rewrite | proofread | explain_code
89
- input: str # user text / code
90
- params: dict | None = None
91
 
92
- # ---- Core endpoint ----
93
  @app.post("/run")
94
  async def run_task(data: InputData):
95
  start = time.time()
96
- text = (data.input or "").strip()
97
- task = (data.task or "").strip().lower()
98
-
99
  if not text:
100
  return {"error": "Empty input text."}
101
- if task not in {"summarize", "rewrite", "proofread", "explain_code"}:
102
- return {"error": f"Unsupported task '{task}'."}
103
 
104
- # Load models (lazy, cached to /tmp)
105
  try:
106
  summarizer, rewriter, proofreader, code_explainer = get_models()
107
  except Exception as e:
108
  return {"error": f"model_load_failed: {type(e).__name__}: {str(e)}"}
109
 
110
- enhanced = ""
111
  try:
112
  if task == "summarize":
113
- enhanced = f"You are an expert explainer. Summarize clearly and concisely:\n{text}"
114
- out = summarizer(
115
- enhanced,
116
- max_length=120,
117
- min_length=30,
118
- do_sample=False,
119
- truncation=True
120
- )[0]["summary_text"]
121
- model_id = "t5-small"
122
 
123
  elif task == "rewrite":
124
- enhanced = f"You are a writing assistant. Rewrite this text for clarity and tone:\n{text}"
125
- out = rewriter(
126
- enhanced,
127
- max_new_tokens=150,
128
- truncation=True
129
- )[0]["generated_text"]
130
- model_id = "google/flan-t5-small"
131
 
132
  elif task == "proofread":
133
- enhanced = f"You are a grammar and style editor. Correct and improve this text:\n{text}"
134
- out = proofreader(
135
- enhanced,
136
- max_new_tokens=150,
137
- truncation=True
138
- )[0]["generated_text"]
139
- model_id = "google/flan-t5-small"
140
-
141
- else: # explain_code
142
- enhanced = f"You are a programming tutor. Explain what this code does in simple language:\n{text}"
143
- out = code_explainer(
144
- enhanced,
145
- max_new_tokens=200,
146
- truncation=True
147
- )[0]["generated_text"]
148
- model_id = "Salesforce/codet5p-220m"
149
 
150
  except Exception as e:
151
  return {"error": f"inference_failed: {type(e).__name__}: {str(e)}"}
152
 
153
  latency = round((time.time() - start) * 1000, 2)
154
  return {
155
- "enhancedPrompt": enhanced,
156
- "output": out,
157
- "model": model_id,
158
  "latencyMs": latency
159
  }
 
1
# app.py — EduPrompt FastAPI backend (final HF Spaces fix)

import os
import time

# ---- Fix for Hugging Face Spaces cache permissions ----
# These environment variables MUST be set before `transformers` /
# `huggingface_hub` are imported: the libraries read them at import time,
# so setting them after the import (as the previous revision did) has no
# effect on where model files are cached. /tmp is the writable directory
# on Spaces.
CACHE_DIR = "/tmp"  # single source of truth for all model caching
os.environ["HF_HOME"] = CACHE_DIR
os.environ["HF_HUB_CACHE"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR

# Imports below intentionally come after the env-var setup (see note above).
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import pipeline
 
 
15
app = FastAPI(title="EduPrompt API")

# ---- CORS: wide open for the demo frontend ----
# NOTE(review): allow_origins=["*"] accepts any origin — set your real
# domain(s) here before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
 
25
def health():
    """Return a static liveness payload identifying the service."""
    return {"ok": True, "service": "eduprompt-api"}
27
 
28
# Lazy models (loaded on first use).
# Pipeline singletons, populated by get_models() on the first /run request
# so the Space boots fast and downloads happen on demand.
_summarizer = None
_rewriter = None
_proofreader = None
_code_explainer = None
33
 
34
def safe_pipeline(task: str, model_id: str):
    """Build a HF pipeline whose downloads always cache inside /tmp (Spaces-safe).

    ``cache_dir`` is not a top-level argument of the ``pipeline()`` factory;
    the documented way to reach the underlying ``from_pretrained`` calls is
    ``model_kwargs={"cache_dir": ...}`` — passed directly it is silently
    dropped and only the HF_* environment variables take effect.

    Retries once after a short pause on OSError (rare HF cache race on
    Spaces); a second failure propagates to the caller, which surfaces it
    as ``model_load_failed``.
    """
    try:
        return pipeline(task, model=model_id, model_kwargs={"cache_dir": CACHE_DIR})
    except OSError:
        # Re-assert the writable cache location and retry once.
        # NOTE(review): these env vars are read by transformers at import
        # time, so re-setting them here is best-effort only — the real fix
        # is setting them before the import at the top of the file.
        time.sleep(1)
        os.environ["HF_HOME"] = CACHE_DIR
        os.environ["HF_HUB_CACHE"] = CACHE_DIR
        os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
        return pipeline(task, model=model_id, model_kwargs={"cache_dir": CACHE_DIR})
 
 
45
 
46
def get_models():
    """Lazily build and memoize the task pipelines; first call downloads them.

    Returns:
        The 4-tuple ``(summarizer, rewriter, proofreader, code_explainer)``.

    The rewriter and proofreader use the identical task and checkpoint
    (``text2text-generation`` / ``google/flan-t5-small``), so a single
    pipeline instance is shared between them instead of loading a second
    copy — halving that model's download time and memory footprint.
    Pipelines are stateless per call, so sharing is safe.
    """
    global _summarizer, _rewriter, _proofreader, _code_explainer
    if _summarizer is None:
        _summarizer = safe_pipeline("summarization", "t5-small")
    if _rewriter is None:
        _rewriter = safe_pipeline("text2text-generation", "google/flan-t5-small")
    if _proofreader is None:
        # Same task + model id as the rewriter — reuse the loaded instance.
        _proofreader = _rewriter
    if _code_explainer is None:
        _code_explainer = safe_pipeline("text2text-generation", "Salesforce/codet5p-220m")
    return _summarizer, _rewriter, _proofreader, _code_explainer
57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
class InputData(BaseModel):
    """Request body for POST /run."""

    # One of: summarize | rewrite | proofread | explain_code
    task: str
    # User-supplied text (or source code, for explain_code) to process.
    input: str
 
61
 
 
62
  @app.post("/run")
63
  async def run_task(data: InputData):
64
  start = time.time()
65
+ task = data.task.strip().lower()
66
+ text = data.input.strip()
 
67
  if not text:
68
  return {"error": "Empty input text."}
 
 
69
 
 
70
  try:
71
  summarizer, rewriter, proofreader, code_explainer = get_models()
72
  except Exception as e:
73
  return {"error": f"model_load_failed: {type(e).__name__}: {str(e)}"}
74
 
 
75
  try:
76
  if task == "summarize":
77
+ prompt = f"You are an expert explainer. Summarize clearly and concisely:\n{text}"
78
+ result = summarizer(prompt, max_length=120, min_length=30, truncation=True)[0]["summary_text"]
79
+ model_used = "t5-small"
 
 
 
 
 
 
80
 
81
  elif task == "rewrite":
82
+ prompt = f"Rewrite this text for clarity and tone:\n{text}"
83
+ result = rewriter(prompt, max_new_tokens=150, truncation=True)[0]["generated_text"]
84
+ model_used = "google/flan-t5-small"
 
 
 
 
85
 
86
  elif task == "proofread":
87
+ prompt = f"Correct and improve grammar and style:\n{text}"
88
+ result = proofreader(prompt, max_new_tokens=150, truncation=True)[0]["generated_text"]
89
+ model_used = "google/flan-t5-small"
90
+
91
+ elif task == "explain_code":
92
+ prompt = f"Explain what this code does in simple language:\n{text}"
93
+ result = code_explainer(prompt, max_new_tokens=200, truncation=True)[0]["generated_text"]
94
+ model_used = "Salesforce/codet5p-220m"
95
+
96
+ else:
97
+ return {"error": f"Unsupported task '{task}'."}
 
 
 
 
 
98
 
99
  except Exception as e:
100
  return {"error": f"inference_failed: {type(e).__name__}: {str(e)}"}
101
 
102
  latency = round((time.time() - start) * 1000, 2)
103
  return {
104
+ "enhancedPrompt": prompt,
105
+ "output": result,
106
+ "model": model_used,
107
  "latencyMs": latency
108
  }