import os
import time
import traceback

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import pipeline

# =========================
# Hard-force ALL caches to /tmp (writable on Spaces)
# =========================
os.environ["HOME"] = "/tmp"
BASE = "/tmp"
os.environ["HF_HOME"] = f"{BASE}/hf"
os.environ["HF_HUB_CACHE"] = f"{BASE}/hf"
os.environ["HUGGINGFACE_HUB_CACHE"] = f"{BASE}/hf"
os.environ["TRANSFORMERS_CACHE"] = f"{BASE}/hf/transformers"
os.environ["XDG_CACHE_HOME"] = f"{BASE}/xdg"
os.environ["TORCH_HOME"] = f"{BASE}/torch"
os.environ["SENTENCEPIECE_CACHE"] = f"{BASE}/sp"

for d in (
    os.environ["HF_HOME"],
    os.environ["HF_HUB_CACHE"],
    os.environ["HUGGINGFACE_HUB_CACHE"],
    os.environ["TRANSFORMERS_CACHE"],
    os.environ["XDG_CACHE_HOME"],
    os.environ["TORCH_HOME"],
    os.environ["SENTENCEPIECE_CACHE"],
):
    os.makedirs(d, exist_ok=True)

# =========================
# FastAPI app + CORS
# =========================
app = FastAPI(title="EduPrompt API")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # tighten in prod
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
def health():
    writable = True
    try:
        with open(f"{BASE}/eduprompt_write_test.txt", "w") as f:
            f.write("ok")
    except Exception:
        writable = False
    return {
        "ok": True,
        "service": "eduprompt-api",
        "tmpWritable": writable,
        "TRANSFORMERS_CACHE": os.environ["TRANSFORMERS_CACHE"],
        "HOME": os.environ["HOME"],
    }


# =========================
# Lazy singletons (loaded per task)
# =========================
_summarizer = None
_rewriter = None
_proofreader = None
_code_explainer = None


def _model_cache_dir(model_id: str) -> str:
    # Each model gets its own directory to avoid lock fights.
    p = os.path.join(os.environ["TRANSFORMERS_CACHE"], model_id.replace("/", "_"))
    os.makedirs(p, exist_ok=True)
    return p


def safe_pipeline(task: str, model_id: str):
    """
    Build a pipeline that caches to /tmp per model.
    Some pipelines reject 'cache_dir' -> retry without it.
    Also handles rare permission/lock races with a short retry.
    """
    cache_dir = _model_cache_dir(model_id)
    print(f"[init] task={task} model={model_id} cache={cache_dir}")

    # First attempt: pass cache_dir explicitly.
    try:
        return pipeline(task, model=model_id, cache_dir=cache_dir, trust_remote_code=True, device=-1)
    except ValueError as e:
        # Some models complain: "model_kwargs not used: ['cache_dir']"
        if "cache_dir" in str(e):
            print(f"[init] {model_id} rejects cache_dir, retrying without it")
            return pipeline(task, model=model_id, trust_remote_code=True, device=-1)
        raise
    except OSError as e:
        # Permission/lock race: wait and retry once.
        print(f"[init] OSError on {model_id}: {e}; retrying once")
        time.sleep(1.5)
        # Re-assert env vars (some libs re-read them).
        os.environ["HF_HOME"] = f"{BASE}/hf"
        os.environ["HF_HUB_CACHE"] = f"{BASE}/hf"
        os.environ["TRANSFORMERS_CACHE"] = f"{BASE}/hf/transformers"
        try:
            return pipeline(task, model=model_id, cache_dir=cache_dir, trust_remote_code=True, device=-1)
        except ValueError as e2:
            if "cache_dir" in str(e2):
                print(f"[init] {model_id} rejects cache_dir on retry, falling back without cache_dir")
                return pipeline(task, model=model_id, trust_remote_code=True, device=-1)
            raise
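
# Note (an aside, not part of the service's flow): transformers pipelines also
# accept the cache location via the documented `model_kwargs` parameter, which
# is forwarded to from_pretrained() and sidesteps the ValueError handled above.
# A minimal sketch of that alternative:
#
#   pipe = pipeline(task, model=model_id,
#                   model_kwargs={"cache_dir": cache_dir}, device=-1)
#
# safe_pipeline()'s retry logic is what this service actually relies on; the
# sketch is only an alternative worth knowing about.
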
""" global _summarizer, _rewriter, _proofreader, _code_explainer if task == "summarize": if _summarizer is None: _summarizer = safe_pipeline("summarization", "t5-small") return _summarizer, "t5-small" if task == "rewrite": if _rewriter is None: _rewriter = safe_pipeline("text2text-generation", "google/flan-t5-small") return _rewriter, "google/flan-t5-small" if task == "proofread": if _proofreader is None: _proofreader = safe_pipeline("text2text-generation", "google/flan-t5-small") return _proofreader, "google/flan-t5-small" if task == "explain_code": if _code_explainer is None: _code_explainer = safe_pipeline("text2text-generation", "Salesforce/codet5p-220m") return _code_explainer, "Salesforce/codet5p-220m" raise ValueError(f"Unsupported task '{task}'") # ========================= # Request schema # ========================= class InputData(BaseModel): task: str # summarize | rewrite | proofread | explain_code input: str params: dict | None = None def _clean_params(params: dict | None): # Block params that some pipelines reject in generate/forward forbidden = {"cache_dir"} return {k: v for k, v in (params or {}).items() if k not in forbidden} # ========================= # Core endpoint # ========================= @app.post("/run") async def run_task(data: InputData): t0 = time.time() task = (data.task or "").strip().lower() text = (data.input or "").strip() if not text: return {"error": "Empty input text."} if task not in {"summarize", "rewrite", "proofread", "explain_code"}: return {"error": f"Unsupported task '{task}'."} # load only what we need try: model, model_used = get_model(task) except Exception as e: return {"error": f"model_load_failed: {type(e).__name__}: {str(e)}"} params = _clean_params(data.params) params.pop("cache_dir", None) # <-- This line guarantees it's gone print("Params passed to model:", params) try: if task == "summarize": prompt = f"You are an expert explainer. Summarize clearly and concisely:\n{text}" out = model(prompt, max_length=120, min_length=30, truncation=True, do_sample=False, **params)[0]["summary_text"] elif task == "rewrite": prompt = f"You are a writing assistant. Rewrite this text for clarity and tone:\n{text}" out = model(prompt, max_new_tokens=150, truncation=True, **params)[0]["generated_text"] elif task == "proofread": prompt = f"Correct and improve grammar and style:\n{text}" out = model(prompt, max_new_tokens=150, truncation=True, **params)[0]["generated_text"] else: # explain_code prompt = f"Explain what this code does in simple language:\n{text}" out = model(prompt, max_new_tokens=200, truncation=True, **params)[0]["generated_text"] except Exception as e: # print full stack to logs for debugging; return friendly message to client import traceback print(traceback.format_exc()) return {"error": f"inference_failed: {type(e).__name__}: {str(e)}"} return { "enhancedPrompt": prompt, "output": out, "model": model_used, "latencyMs": round((time.time() - t0) * 1000, 2), }