File size: 7,213 Bytes
c884159
3fa0119
 
 
 
 
 
 
 
 
c884159
12e8c98
 
 
 
 
 
 
 
82fc3eb
 
 
 
 
 
 
 
 
12e8c98
9753dd0
3fa0119
 
 
551b5cc
 
 
3fa0119
aecf872
 
551b5cc
 
 
 
82fc3eb
12e8c98
 
 
 
 
a2c1655
 
 
 
82fc3eb
c884159
a2c1655
12e8c98
3fa0119
 
 
aecf872
 
 
 
551b5cc
82fc3eb
3fa0119
82fc3eb
 
 
 
83abf71
3fa0119
 
 
 
 
82fc3eb
 
3fa0119
82fc3eb
 
 
 
3fa0119
82fc3eb
 
 
 
 
3fa0119
82fc3eb
 
3fa0119
82fc3eb
 
 
 
 
 
 
 
 
 
 
 
 
9753dd0
12e8c98
3fa0119
 
 
aecf872
12e8c98
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
551b5cc
3fa0119
 
 
551b5cc
3fa0119
83abf71
12e8c98
551b5cc
82fc3eb
3fa0119
7aa47a6
 
 
3fa0119
 
 
551b5cc
 
82fc3eb
12e8c98
 
82fc3eb
aecf872
 
12e8c98
 
551b5cc
3fa0119
aecf872
12e8c98
9753dd0
 
aecf872
82fc3eb
3fa0119
 
7aa47a6
9753dd0
aecf872
83abf71
82fc3eb
 
 
aecf872
12e8c98
82fc3eb
 
aecf872
83abf71
82fc3eb
 
12e8c98
83abf71
82fc3eb
 
aecf872
3fa0119
a2c1655
 
aecf872
551b5cc
 
83abf71
82fc3eb
83abf71
82fc3eb
c884159
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import os
import time
from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from transformers import pipeline

# =========================
# Hard-force ALL caches to /tmp (writable on Spaces)
# =========================
os.environ["HOME"] = "/tmp"
BASE = "/tmp"

# Every cache-related env var points at its own /tmp subdirectory; each
# directory is created eagerly so later library code never hits a missing path.
_CACHE_ENV = {
    "HF_HOME": f"{BASE}/hf",
    "HF_HUB_CACHE": f"{BASE}/hf",
    "HUGGINGFACE_HUB_CACHE": f"{BASE}/hf",
    "TRANSFORMERS_CACHE": f"{BASE}/hf/transformers",
    "XDG_CACHE_HOME": f"{BASE}/xdg",
    "TORCH_HOME": f"{BASE}/torch",
    "SENTENCEPIECE_CACHE": f"{BASE}/sp",
}
for _var, _path in _CACHE_ENV.items():
    os.environ[_var] = _path
    os.makedirs(_path, exist_ok=True)

# =========================
# FastAPI app + CORS
# =========================
app = FastAPI(title="EduPrompt API")

# Wide-open CORS for development; tighten in prod.
_cors_config = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_config)

@app.get("/")
def health():
    """Liveness probe: reports service identity, whether /tmp is writable,
    and the cache-related env vars this process is actually using."""
    try:
        with open(f"{BASE}/eduprompt_write_test.txt", "w") as probe:
            probe.write("ok")
        writable = True
    except Exception:
        writable = False
    return {
        "ok": True,
        "service": "eduprompt-api",
        "tmpWritable": writable,
        "TRANSFORMERS_CACHE": os.environ["TRANSFORMERS_CACHE"],
        "HOME": os.environ["HOME"],
    }

# =========================
# Lazy singletons (loaded per task)
# =========================
# One transformers pipeline per task; None means "not loaded yet".
# Each slot is filled on first use so only the requested model is loaded.
_summarizer = None
_rewriter = None
_proofreader = None
_code_explainer = None

def _model_cache_dir(model_id: str) -> str:
    # each model gets its own directory to avoid lock fights
    p = os.path.join(os.environ["TRANSFORMERS_CACHE"], model_id.replace("/", "_"))
    os.makedirs(p, exist_ok=True)
    return p

def _build_pipeline(task: str, model_id: str, cache_dir: str):
    """Instantiate a CPU pipeline cached under cache_dir.

    Some pipelines reject 'cache_dir' as a model kwarg
    ("model_kwargs not used: ['cache_dir']") -> retry without it.
    """
    try:
        return pipeline(task, model=model_id, cache_dir=cache_dir,
                        trust_remote_code=True, device=-1)
    except ValueError as e:
        if "cache_dir" in str(e):
            print(f"[init] {model_id} rejects cache_dir, retrying without it")
            return pipeline(task, model=model_id, trust_remote_code=True, device=-1)
        raise


def safe_pipeline(task: str, model_id: str):
    """
    Build a pipeline that caches to /tmp per model.

    Handles two known failure modes:
      * pipelines that reject 'cache_dir' (retried without it, via
        _build_pipeline), and
      * rare permission/lock races (OSError) by one short delayed retry.
    Any other exception propagates to the caller.
    """
    cache_dir = _model_cache_dir(model_id)
    print(f"[init] task={task} model={model_id} cache={cache_dir}")
    try:
        return _build_pipeline(task, model_id, cache_dir)
    except OSError as e:
        # Permission/lock race — wait and retry once
        print(f"[init] OSError on {model_id}: {e}; retrying once")
        time.sleep(1.5)
        # Re-assert env (some libs re-read these when building)
        os.environ["HF_HOME"] = f"{BASE}/hf"
        os.environ["HF_HUB_CACHE"] = f"{BASE}/hf"
        os.environ["TRANSFORMERS_CACHE"] = f"{BASE}/hf/transformers"
        return _build_pipeline(task, model_id, cache_dir)

def get_model(task: str):
    """
    Return (pipeline, model_id) for *task*, loading the pipeline lazily.

    Only the model needed for the requested task is ever instantiated;
    once built, it is kept in its module-level singleton slot.
    Raises ValueError for an unknown task.
    """
    global _summarizer, _rewriter, _proofreader, _code_explainer
    if task == "summarize":
        if _summarizer is None:
            _summarizer = safe_pipeline("summarization", "t5-small")
        model, model_id = _summarizer, "t5-small"
    elif task == "rewrite":
        if _rewriter is None:
            _rewriter = safe_pipeline("text2text-generation", "google/flan-t5-small")
        model, model_id = _rewriter, "google/flan-t5-small"
    elif task == "proofread":
        if _proofreader is None:
            _proofreader = safe_pipeline("text2text-generation", "google/flan-t5-small")
        model, model_id = _proofreader, "google/flan-t5-small"
    elif task == "explain_code":
        if _code_explainer is None:
            _code_explainer = safe_pipeline("text2text-generation", "Salesforce/codet5p-220m")
        model, model_id = _code_explainer, "Salesforce/codet5p-220m"
    else:
        raise ValueError(f"Unsupported task '{task}'")
    return model, model_id

# =========================
# Request schema
# =========================
class InputData(BaseModel):
    """Request body for POST /run."""
    task: str                  # summarize | rewrite | proofread | explain_code
    input: str                 # raw text (or code) to process
    params: dict | None = None # optional extra kwargs forwarded to the pipeline call

def _clean_params(params: dict | None):
    # Block params that some pipelines reject in generate/forward
    forbidden = {"cache_dir"}
    return {k: v for k, v in (params or {}).items() if k not in forbidden}

# =========================
# Core endpoint
# =========================
@app.post("/run")
async def run_task(data: InputData):
    """
    Run one task ("summarize" | "rewrite" | "proofread" | "explain_code")
    over data.input and return the model output plus timing metadata.

    NOTE(review): errors are reported as {"error": ...} bodies with HTTP 200;
    clients must inspect the body. Consider HTTPException if that should change.
    """
    t0 = time.time()
    task = (data.task or "").strip().lower()
    text = (data.input or "").strip()

    if not text:
        return {"error": "Empty input text."}
    if task not in {"summarize", "rewrite", "proofread", "explain_code"}:
        return {"error": f"Unsupported task '{task}'."}

    # load only what we need
    try:
        model, model_used = get_model(task)
    except Exception as e:
        return {"error": f"model_load_failed: {type(e).__name__}: {str(e)}"}

    # _clean_params already strips 'cache_dir'; no second pop needed.
    params = _clean_params(data.params)
    print("Params passed to model:", params)

    try:
        if task == "summarize":
            prompt = f"You are an expert explainer. Summarize clearly and concisely:\n{text}"
            out = model(prompt, max_length=120, min_length=30,
                        truncation=True, do_sample=False, **params)[0]["summary_text"]

        elif task == "rewrite":
            prompt = f"You are a writing assistant. Rewrite this text for clarity and tone:\n{text}"
            out = model(prompt, max_new_tokens=150, truncation=True, **params)[0]["generated_text"]

        elif task == "proofread":
            prompt = f"Correct and improve grammar and style:\n{text}"
            out = model(prompt, max_new_tokens=150, truncation=True, **params)[0]["generated_text"]

        else:  # explain_code
            prompt = f"Explain what this code does in simple language:\n{text}"
            out = model(prompt, max_new_tokens=200, truncation=True, **params)[0]["generated_text"]

    except Exception as e:
        # print full stack to logs for debugging; return friendly message to client
        import traceback
        print(traceback.format_exc())
        return {"error": f"inference_failed: {type(e).__name__}: {str(e)}"}

    return {
        "enhancedPrompt": prompt,
        "output": out,
        "model": model_used,
        "latencyMs": round((time.time() - t0) * 1000, 2),
    }