# Spaces:
# Sleeping
# Sleeping
# =========================
# Hard-force ALL caches to /tmp (writable on Spaces)
# =========================
# NOTE: these env vars MUST be set BEFORE `transformers` is imported — the
# library resolves HF_HOME / TRANSFORMERS_CACHE when it is first imported,
# so the original order (import transformers, then set the env) could leave
# the default, read-only cache paths in effect on Spaces.
import os

os.environ["HOME"] = "/tmp"
BASE = "/tmp"
os.environ["HF_HOME"] = f"{BASE}/hf"
os.environ["HF_HUB_CACHE"] = f"{BASE}/hf"
os.environ["HUGGINGFACE_HUB_CACHE"] = f"{BASE}/hf"
os.environ["TRANSFORMERS_CACHE"] = f"{BASE}/hf/transformers"
os.environ["XDG_CACHE_HOME"] = f"{BASE}/xdg"
os.environ["TORCH_HOME"] = f"{BASE}/torch"
os.environ["SENTENCEPIECE_CACHE"] = f"{BASE}/sp"

# Pre-create every cache directory (duplicate paths are harmless: exist_ok=True).
for _cache_dir in (
    os.environ["HF_HOME"],
    os.environ["HF_HUB_CACHE"],
    os.environ["HUGGINGFACE_HUB_CACHE"],
    os.environ["TRANSFORMERS_CACHE"],
    os.environ["XDG_CACHE_HOME"],
    os.environ["TORCH_HOME"],
    os.environ["SENTENCEPIECE_CACHE"],
):
    os.makedirs(_cache_dir, exist_ok=True)

import time

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
# Imported last so it sees the cache env vars configured above.
from transformers import pipeline
# =========================
# FastAPI app + CORS
# =========================
app = FastAPI(title="EduPrompt API")

# Wide-open CORS so a separately hosted frontend can call the API during
# development.  NOTE(review): "*" origins should be restricted to the real
# frontend origin(s) before production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # tighten in prod
    allow_methods=["*"],
    allow_headers=["*"],
)
def health():
    """Liveness probe: report service identity and whether /tmp is writable.

    NOTE(review): no route decorator is visible in this chunk — confirm the
    function is actually registered (e.g. ``@app.get("/health")``) elsewhere,
    or the decorator was lost.
    """
    # Probe writability by actually writing a tiny file; any failure
    # (permissions, read-only fs, ...) just flips the flag.
    try:
        with open(f"{BASE}/eduprompt_write_test.txt", "w") as probe:
            probe.write("ok")
        writable = True
    except Exception:
        writable = False
    return {
        "ok": True,
        "service": "eduprompt-api",
        "tmpWritable": writable,
        "TRANSFORMERS_CACHE": os.environ["TRANSFORMERS_CACHE"],
        "HOME": os.environ["HOME"],
    }
# =========================
# Lazy singletons (loaded per task)
# =========================
# Each pipeline is created on first use by get_model() and cached here, so a
# cold-started Space only pays the load cost for the model(s) it actually serves.
_summarizer = None
_rewriter = None
_proofreader = None
_code_explainer = None
| def _model_cache_dir(model_id: str) -> str: | |
| # each model gets its own directory to avoid lock fights | |
| p = os.path.join(os.environ["TRANSFORMERS_CACHE"], model_id.replace("/", "_")) | |
| os.makedirs(p, exist_ok=True) | |
| return p | |
def safe_pipeline(task: str, model_id: str):
    """
    Build a transformers pipeline that caches to /tmp per model.

    Some pipelines reject 'cache_dir' -> retry without it.
    Also handles rare permission/lock races with one short retry.
    CPU only (device=-1); trust_remote_code is required by some hub models.

    Fix vs original: dropped the trailing ``except Exception as e2: raise``
    clause — an unconditional re-raise is the default behavior, so the clause
    was dead code.
    """
    cache_dir = _model_cache_dir(model_id)
    print(f"[init] task={task} model={model_id} cache={cache_dir}")
    # First attempt: pass an explicit per-model cache_dir.
    try:
        return pipeline(task, model=model_id, cache_dir=cache_dir,
                        trust_remote_code=True, device=-1)
    except ValueError as e:
        # Some models complain: "model_kwargs not used: ['cache_dir']"
        if "cache_dir" in str(e):
            print(f"[init] {model_id} rejects cache_dir, retrying without it")
            return pipeline(task, model=model_id, trust_remote_code=True, device=-1)
        raise
    except OSError as e:
        # Permission/lock race — wait and retry once.
        print(f"[init] OSError on {model_id}: {e}; retrying once")
        time.sleep(1.5)
        # Re-assert cache env (some libs re-read it lazily).
        os.environ["HF_HOME"] = f"{BASE}/hf"
        os.environ["HF_HUB_CACHE"] = f"{BASE}/hf"
        os.environ["TRANSFORMERS_CACHE"] = f"{BASE}/hf/transformers"
        try:
            return pipeline(task, model=model_id, cache_dir=cache_dir,
                            trust_remote_code=True, device=-1)
        except ValueError as e2:
            if "cache_dir" in str(e2):
                print(f"[init] {model_id} rejects cache_dir on retry, fallback no cache_dir")
                return pipeline(task, model=model_id, trust_remote_code=True, device=-1)
            raise
| def get_model(task: str): | |
| """ | |
| Load ONLY the model needed for this task. | |
| """ | |
| global _summarizer, _rewriter, _proofreader, _code_explainer | |
| if task == "summarize": | |
| if _summarizer is None: | |
| _summarizer = safe_pipeline("summarization", "t5-small") | |
| return _summarizer, "t5-small" | |
| if task == "rewrite": | |
| if _rewriter is None: | |
| _rewriter = safe_pipeline("text2text-generation", "google/flan-t5-small") | |
| return _rewriter, "google/flan-t5-small" | |
| if task == "proofread": | |
| if _proofreader is None: | |
| _proofreader = safe_pipeline("text2text-generation", "google/flan-t5-small") | |
| return _proofreader, "google/flan-t5-small" | |
| if task == "explain_code": | |
| if _code_explainer is None: | |
| _code_explainer = safe_pipeline("text2text-generation", "Salesforce/codet5p-220m") | |
| return _code_explainer, "Salesforce/codet5p-220m" | |
| raise ValueError(f"Unsupported task '{task}'") | |
# =========================
# Request schema
# =========================
class InputData(BaseModel):
    # Request body for the task endpoint.
    task: str  # summarize | rewrite | proofread | explain_code
    input: str  # raw user text (or source code, for explain_code)
    params: dict | None = None  # optional generation kwargs forwarded to the pipeline
| def _clean_params(params: dict | None): | |
| # Block params that some pipelines reject in generate/forward | |
| forbidden = {"cache_dir"} | |
| return {k: v for k, v in (params or {}).items() if k not in forbidden} | |
# =========================
# Core endpoint
# =========================
async def run_task(data: InputData):
    """Run one NLP task: summarize | rewrite | proofread | explain_code.

    Returns a dict with the enhanced prompt, model output, model id and
    latency in ms.  Errors are reported as ``{"error": ...}`` (with HTTP 200;
    NOTE(review): consider HTTPException/status codes for real clients).

    NOTE(review): no route decorator is visible in this chunk — confirm the
    coroutine is registered (e.g. ``@app.post("/run")``) elsewhere.

    Fix vs original: removed the redundant ``params.pop("cache_dir", None)``
    — ``_clean_params`` already strips that key.
    """
    t0 = time.time()
    task = (data.task or "").strip().lower()
    text = (data.input or "").strip()
    if not text:
        return {"error": "Empty input text."}
    if task not in {"summarize", "rewrite", "proofread", "explain_code"}:
        return {"error": f"Unsupported task '{task}'."}
    # Load only the model this task needs (lazy singleton).
    try:
        model, model_used = get_model(task)
    except Exception as e:
        return {"error": f"model_load_failed: {type(e).__name__}: {str(e)}"}
    params = _clean_params(data.params)
    print("Params passed to model:", params)
    try:
        if task == "summarize":
            prompt = f"You are an expert explainer. Summarize clearly and concisely:\n{text}"
            out = model(prompt, max_length=120, min_length=30,
                        truncation=True, do_sample=False, **params)[0]["summary_text"]
        elif task == "rewrite":
            prompt = f"You are a writing assistant. Rewrite this text for clarity and tone:\n{text}"
            out = model(prompt, max_new_tokens=150, truncation=True, **params)[0]["generated_text"]
        elif task == "proofread":
            prompt = f"Correct and improve grammar and style:\n{text}"
            out = model(prompt, max_new_tokens=150, truncation=True, **params)[0]["generated_text"]
        else:  # explain_code
            prompt = f"Explain what this code does in simple language:\n{text}"
            out = model(prompt, max_new_tokens=200, truncation=True, **params)[0]["generated_text"]
    except Exception as e:
        # Print full stack to logs for debugging; return friendly message to client.
        import traceback
        print(traceback.format_exc())
        return {"error": f"inference_failed: {type(e).__name__}: {str(e)}"}
    return {
        "enhancedPrompt": prompt,
        "output": out,
        "model": model_used,
        "latencyMs": round((time.time() - t0) * 1000, 2),
    }