# eduprompt-api / app.py
import os
import time

# =========================
# Hard-force ALL caches to /tmp (writable on Spaces).
# These env vars must be set BEFORE importing transformers,
# which reads TRANSFORMERS_CACHE at import time.
# =========================
os.environ["HOME"] = "/tmp"
BASE = "/tmp"
os.environ["HF_HOME"] = f"{BASE}/hf"
os.environ["HF_HUB_CACHE"] = f"{BASE}/hf"
os.environ["HUGGINGFACE_HUB_CACHE"] = f"{BASE}/hf"
os.environ["TRANSFORMERS_CACHE"] = f"{BASE}/hf/transformers"
os.environ["XDG_CACHE_HOME"] = f"{BASE}/xdg"
os.environ["TORCH_HOME"] = f"{BASE}/torch"
os.environ["SENTENCEPIECE_CACHE"] = f"{BASE}/sp"

for d in (
    os.environ["HF_HOME"],
    os.environ["HF_HUB_CACHE"],
    os.environ["HUGGINGFACE_HUB_CACHE"],
    os.environ["TRANSFORMERS_CACHE"],
    os.environ["XDG_CACHE_HOME"],
    os.environ["TORCH_HOME"],
    os.environ["SENTENCEPIECE_CACHE"],
):
    os.makedirs(d, exist_ok=True)

# Imported only after the cache env vars are set; transformers in
# particular reads TRANSFORMERS_CACHE when the module is first imported.
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline
# =========================
# FastAPI app + CORS
# =========================
app = FastAPI(title="EduPrompt API")
app.add_middleware(
CORSMiddleware,
    allow_origins=["*"],  # tighten in prod; see the pinned-origins sketch below
allow_methods=["*"],
allow_headers=["*"],
)
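
# For production you would likely pin the browser origins instead of "*".
# A minimal sketch; the domain below is a placeholder, not the real frontend:
#
# app.add_middleware(
#     CORSMiddleware,
#     allow_origins=["https://eduprompt.example.com"],
#     allow_methods=["GET", "POST"],
#     allow_headers=["*"],
# )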
@app.get("/")
def health():
writable = True
try:
with open(f"{BASE}/eduprompt_write_test.txt", "w") as f:
f.write("ok")
except Exception:
writable = False
return {
"ok": True,
"service": "eduprompt-api",
"tmpWritable": writable,
"TRANSFORMERS_CACHE": os.environ["TRANSFORMERS_CACHE"],
"HOME": os.environ["HOME"],
}
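
# Quick smoke test (the URL is a placeholder for your Space endpoint):
#   curl https://<your-space>.hf.space/
#   -> {"ok": true, "service": "eduprompt-api", "tmpWritable": true, ...}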
# =========================
# Lazy singletons (loaded per task)
# =========================
_summarizer = None
_rewriter = None
_proofreader = None
_code_explainer = None
def _model_cache_dir(model_id: str) -> str:
    # Each model gets its own directory to avoid file-lock contention,
    # e.g. "google/flan-t5-small" -> /tmp/hf/transformers/google_flan-t5-small
p = os.path.join(os.environ["TRANSFORMERS_CACHE"], model_id.replace("/", "_"))
os.makedirs(p, exist_ok=True)
return p
def safe_pipeline(task: str, model_id: str):
"""
Build a pipeline that caches to /tmp per model.
Some pipelines reject 'cache_dir' -> retry without it.
    Also handles rare permission/lock races with a single short retry.
"""
cache_dir = _model_cache_dir(model_id)
print(f"[init] task={task} model={model_id} cache={cache_dir}")
# Try with cache_dir
try:
return pipeline(task, model=model_id, cache_dir=cache_dir,
trust_remote_code=True, device=-1)
except ValueError as e:
# Some models complain: "model_kwargs not used: ['cache_dir']"
if "cache_dir" in str(e):
print(f"[init] {model_id} rejects cache_dir, retrying without it")
return pipeline(task, model=model_id, trust_remote_code=True, device=-1)
raise
except OSError as e:
# Permission/lock race — wait and retry once
print(f"[init] OSError on {model_id}: {e}; retrying once")
time.sleep(1.5)
# Re-assert env (some libs re-read)
os.environ["HF_HOME"] = f"{BASE}/hf"
os.environ["HF_HUB_CACHE"] = f"{BASE}/hf"
os.environ["TRANSFORMERS_CACHE"] = f"{BASE}/hf/transformers"
try:
return pipeline(task, model=model_id, cache_dir=cache_dir,
trust_remote_code=True, device=-1)
except ValueError as e2:
if "cache_dir" in str(e2):
print(f"[init] {model_id} rejects cache_dir on retry, fallback no cache_dir")
return pipeline(task, model=model_id, trust_remote_code=True, device=-1)
raise
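
# Usage sketch (the same models that get_model() below actually loads):
#   summarizer = safe_pipeline("summarization", "t5-small")
#   summarizer("Summarize: ...", max_length=60, min_length=10)[0]["summary_text"]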
def get_model(task: str):
"""
Load ONLY the model needed for this task.
"""
global _summarizer, _rewriter, _proofreader, _code_explainer
if task == "summarize":
if _summarizer is None:
_summarizer = safe_pipeline("summarization", "t5-small")
return _summarizer, "t5-small"
if task == "rewrite":
if _rewriter is None:
_rewriter = safe_pipeline("text2text-generation", "google/flan-t5-small")
return _rewriter, "google/flan-t5-small"
if task == "proofread":
if _proofreader is None:
_proofreader = safe_pipeline("text2text-generation", "google/flan-t5-small")
return _proofreader, "google/flan-t5-small"
if task == "explain_code":
if _code_explainer is None:
_code_explainer = safe_pipeline("text2text-generation", "Salesforce/codet5p-220m")
return _code_explainer, "Salesforce/codet5p-220m"
raise ValueError(f"Unsupported task '{task}'")
# =========================
# Request schema
# =========================
class InputData(BaseModel):
task: str # summarize | rewrite | proofread | explain_code
input: str
params: dict | None = None
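
# Example request body (illustrative; "params" is optional and may be null):
# {
#   "task": "summarize",
#   "input": "Photosynthesis converts light energy into chemical energy ...",
#   "params": null
# }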
def _clean_params(params: dict | None):
# Block params that some pipelines reject in generate/forward
forbidden = {"cache_dir"}
return {k: v for k, v in (params or {}).items() if k not in forbidden}
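
# e.g. _clean_params({"cache_dir": "/tmp/x", "num_beams": 2}) -> {"num_beams": 2}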
# =========================
# Core endpoint
# =========================
@app.post("/run")
async def run_task(data: InputData):
t0 = time.time()
task = (data.task or "").strip().lower()
text = (data.input or "").strip()
if not text:
return {"error": "Empty input text."}
if task not in {"summarize", "rewrite", "proofread", "explain_code"}:
return {"error": f"Unsupported task '{task}'."}
# load only what we need
try:
model, model_used = get_model(task)
except Exception as e:
return {"error": f"model_load_failed: {type(e).__name__}: {str(e)}"}
params = _clean_params(data.params)
    params.pop("cache_dir", None)  # defensive: _clean_params already strips it
print("Params passed to model:", params)
try:
if task == "summarize":
prompt = f"You are an expert explainer. Summarize clearly and concisely:\n{text}"
out = model(prompt, max_length=120, min_length=30,
truncation=True, do_sample=False, **params)[0]["summary_text"]
elif task == "rewrite":
prompt = f"You are a writing assistant. Rewrite this text for clarity and tone:\n{text}"
out = model(prompt, max_new_tokens=150, truncation=True, **params)[0]["generated_text"]
elif task == "proofread":
prompt = f"Correct and improve grammar and style:\n{text}"
out = model(prompt, max_new_tokens=150, truncation=True, **params)[0]["generated_text"]
else: # explain_code
prompt = f"Explain what this code does in simple language:\n{text}"
out = model(prompt, max_new_tokens=200, truncation=True, **params)[0]["generated_text"]
except Exception as e:
# print full stack to logs for debugging; return friendly message to client
import traceback
print(traceback.format_exc())
return {"error": f"inference_failed: {type(e).__name__}: {str(e)}"}
return {
"enhancedPrompt": prompt,
"output": out,
"model": model_used,
"latencyMs": round((time.time() - t0) * 1000, 2),
}
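
# Example call (the URL is a placeholder for your Space endpoint):
#   curl -X POST https://<your-space>.hf.space/run \
#     -H "Content-Type: application/json" \
#     -d '{"task": "proofread", "input": "She dont like it."}'
#
# Response shape (values illustrative):
#   {"enhancedPrompt": "...", "output": "...",
#    "model": "google/flan-t5-small", "latencyMs": 412.3}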