# EduPrompt API — FastAPI inference service for a Hugging Face Space.
# (Header reconstructed: the original lines were web-page chrome — commit
# hashes and a line-number gutter — left over from scraping the HF file view.)
import os
import time
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline
# =========================
# Hard-force ALL caches to /tmp (writable on Spaces)
# =========================
# HF Spaces containers only guarantee /tmp is writable, so every cache
# location (hub, transformers, torch, sentencepiece, XDG) is pinned there
# and pre-created before any model library gets imported into action.
os.environ["HOME"] = "/tmp"
BASE = "/tmp"
_CACHE_ENV = {
    "HF_HOME": f"{BASE}/hf",
    "HF_HUB_CACHE": f"{BASE}/hf",
    "HUGGINGFACE_HUB_CACHE": f"{BASE}/hf",
    "TRANSFORMERS_CACHE": f"{BASE}/hf/transformers",
    "XDG_CACHE_HOME": f"{BASE}/xdg",
    "TORCH_HOME": f"{BASE}/torch",
    "SENTENCEPIECE_CACHE": f"{BASE}/sp",
}
for _env_key, _cache_path in _CACHE_ENV.items():
    os.environ[_env_key] = _cache_path
    os.makedirs(_cache_path, exist_ok=True)
# =========================
# FastAPI app + CORS
# =========================
app = FastAPI(title="EduPrompt API")

# Wide-open CORS so any browser frontend can reach the API.
# NOTE(review): tighten allow_origins before production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
@app.get("/")
def health():
    """Liveness/diagnostics endpoint.

    Probes whether /tmp is writable by writing a tiny file, and echoes the
    cache-related environment so misconfiguration is visible from a browser.
    Never raises: any probe failure just flips tmpWritable to False.

    Returns:
        dict with service identity, tmpWritable flag, and key env values.
    """
    probe_path = f"{BASE}/eduprompt_write_test.txt"
    writable = True
    try:
        with open(probe_path, "w") as f:
            f.write("ok")
    except Exception:
        # Deliberately broad: this is a best-effort probe, not real I/O.
        writable = False
    else:
        # Fix: the old version left the probe file behind on every call.
        try:
            os.remove(probe_path)
        except OSError:
            pass
    return {
        "ok": True,
        "service": "eduprompt-api",
        "tmpWritable": writable,
        "TRANSFORMERS_CACHE": os.environ["TRANSFORMERS_CACHE"],
        "HOME": os.environ["HOME"],
    }
# =========================
# Lazy singletons (loaded per task)
# =========================
# Each pipeline is built on first use by get_model() and cached for the
# lifetime of the process. None means "not loaded yet". Note rewrite and
# proofread both use flan-t5-small but keep separate slots.
_summarizer = None
_rewriter = None
_proofreader = None
_code_explainer = None
def _model_cache_dir(model_id: str) -> str:
# each model gets its own directory to avoid lock fights
p = os.path.join(os.environ["TRANSFORMERS_CACHE"], model_id.replace("/", "_"))
os.makedirs(p, exist_ok=True)
return p
def _build_pipeline(task: str, model_id: str, cache_dir: str):
    """One pipeline construction attempt, with the cache_dir fallback.

    Some pipelines reject cache_dir as a kwarg ("model_kwargs not used:
    ['cache_dir']") — in that case retry the identical build without it.
    Any other ValueError propagates.
    """
    try:
        return pipeline(task, model=model_id, cache_dir=cache_dir,
                        trust_remote_code=True, device=-1)
    except ValueError as e:
        if "cache_dir" in str(e):
            print(f"[init] {model_id} rejects cache_dir, retrying without it")
            return pipeline(task, model=model_id, trust_remote_code=True, device=-1)
        raise


def safe_pipeline(task: str, model_id: str):
    """
    Build a pipeline that caches to /tmp per model.

    Some pipelines reject 'cache_dir' -> retry without it (handled in
    _build_pipeline). Rare permission/lock races surface as OSError; those
    get one short retry after re-asserting the cache env vars, since some
    libraries re-read them at load time.

    Fixes vs. previous version: the duplicated try/except ladder is
    collapsed into _build_pipeline, and the dead
    `except Exception as e2: raise` clause (a no-op) is removed. The retry
    now also covers an OSError raised by the no-cache_dir fallback path.
    """
    cache_dir = _model_cache_dir(model_id)
    print(f"[init] task={task} model={model_id} cache={cache_dir}")
    try:
        return _build_pipeline(task, model_id, cache_dir)
    except OSError as e:
        # Permission/lock race — wait and retry once.
        print(f"[init] OSError on {model_id}: {e}; retrying once")
        time.sleep(1.5)
        # Re-assert env (some libs re-read these when loading).
        os.environ["HF_HOME"] = f"{BASE}/hf"
        os.environ["HF_HUB_CACHE"] = f"{BASE}/hf"
        os.environ["TRANSFORMERS_CACHE"] = f"{BASE}/hf/transformers"
        return _build_pipeline(task, model_id, cache_dir)
def get_model(task: str):
    """
    Load ONLY the model needed for this task.

    The loaded pipeline is memoized in a module-level singleton so repeat
    requests reuse it. Returns (pipeline, model_id).

    Raises:
        ValueError: if `task` is not one of the supported task names.
    """
    global _summarizer, _rewriter, _proofreader, _code_explainer
    # task -> (singleton global name, pipeline task, hub model id)
    registry = {
        "summarize": ("_summarizer", "summarization", "t5-small"),
        "rewrite": ("_rewriter", "text2text-generation", "google/flan-t5-small"),
        "proofread": ("_proofreader", "text2text-generation", "google/flan-t5-small"),
        "explain_code": ("_code_explainer", "text2text-generation", "Salesforce/codet5p-220m"),
    }
    try:
        slot, pipe_task, model_id = registry[task]
    except KeyError:
        raise ValueError(f"Unsupported task '{task}'") from None
    module_state = globals()
    if module_state[slot] is None:
        module_state[slot] = safe_pipeline(pipe_task, model_id)
    return module_state[slot], model_id
# =========================
# Request schema
# =========================
class InputData(BaseModel):
    """Request body for POST /run."""

    task: str  # summarize | rewrite | proofread | explain_code
    input: str  # raw text (or code, for explain_code) to process
    params: dict | None = None  # optional extra pipeline kwargs; sanitized before use
def _clean_params(params: dict | None):
# Block params that some pipelines reject in generate/forward
forbidden = {"cache_dir"}
return {k: v for k, v in (params or {}).items() if k not in forbidden}
# =========================
# Core endpoint
# =========================
# Per-task prompt prefix, generation kwargs, and the key under which the
# pipeline returns its text ("summary_text" for summarization, otherwise
# "generated_text"). Built once at import time instead of per request.
_RUN_SPECS = {
    "summarize": (
        "You are an expert explainer. Summarize clearly and concisely:",
        {"max_length": 120, "min_length": 30, "do_sample": False},
        "summary_text",
    ),
    "rewrite": (
        "You are a writing assistant. Rewrite this text for clarity and tone:",
        {"max_new_tokens": 150},
        "generated_text",
    ),
    "proofread": (
        "Correct and improve grammar and style:",
        {"max_new_tokens": 150},
        "generated_text",
    ),
    "explain_code": (
        "Explain what this code does in simple language:",
        {"max_new_tokens": 200},
        "generated_text",
    ),
}


@app.post("/run")
async def run_task(data: InputData):
    """Run one NLP task over the submitted text.

    Validates the request, lazily loads the model for the task, builds the
    task-specific prompt, and returns the model output plus timing info.
    All failures are reported as {"error": ...} JSON rather than raised,
    so the client always receives a well-formed body.
    """
    t0 = time.time()
    task = (data.task or "").strip().lower()
    text = (data.input or "").strip()
    if not text:
        return {"error": "Empty input text."}
    if task not in _RUN_SPECS:
        return {"error": f"Unsupported task '{task}'."}

    # Load only the model this task needs (lazy singleton).
    try:
        model, model_used = get_model(task)
    except Exception as e:
        return {"error": f"model_load_failed: {type(e).__name__}: {str(e)}"}

    # _clean_params already strips forbidden keys such as 'cache_dir', so no
    # second pop() is needed here.
    params = _clean_params(data.params)
    print("Params passed to model:", params)

    prefix, gen_kwargs, out_key = _RUN_SPECS[task]
    prompt = f"{prefix}\n{text}"
    try:
        out = model(prompt, truncation=True, **gen_kwargs, **params)[0][out_key]
    except Exception as e:
        # Full stack to the server logs; friendly message to the client.
        import traceback
        print(traceback.format_exc())
        return {"error": f"inference_failed: {type(e).__name__}: {str(e)}"}

    return {
        "enhancedPrompt": prompt,
        "output": out,
        "model": model_used,
        "latencyMs": round((time.time() - t0) * 1000, 2),
    }