adityabalaji commited on
Commit
7aa47a6
·
verified ·
1 Parent(s): a2c1655

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -5
app.py CHANGED
@@ -1,4 +1,4 @@
1
- # app.py — EduPrompt API (per-task lazy load + cache-safe on Spaces, debug prints)
2
 
3
  import os
4
 
@@ -87,6 +87,11 @@ class InputData(BaseModel):
87
  input: str
88
  params: dict | None = None
89
 
 
 
 
 
 
90
  @app.post("/run")
91
  async def run_task(data: InputData):
92
  start = time.time()
@@ -105,19 +110,22 @@ async def run_task(data: InputData):
105
  print(traceback.format_exc())
106
  return {"error": f"model_load_failed: {type(e).__name__}: {str(e)}"}
107
 
 
 
 
108
  try:
109
  if task == "summarize":
110
  prompt = f"You are an expert explainer. Summarize clearly and concisely:\n{text}"
111
- output = model(prompt, max_length=120, min_length=30, truncation=True, do_sample=False)[0]["summary_text"]
112
  elif task == "rewrite":
113
  prompt = f"You are a writing assistant. Rewrite this text for clarity and tone:\n{text}"
114
- output = model(prompt, max_new_tokens=150, truncation=True)[0]["generated_text"]
115
  elif task == "proofread":
116
  prompt = f"Correct and improve grammar and style:\n{text}"
117
- output = model(prompt, max_new_tokens=150, truncation=True)[0]["generated_text"]
118
  else: # explain_code
119
  prompt = f"Explain what this code does in simple language:\n{text}"
120
- output = model(prompt, max_new_tokens=200, truncation=True)[0]["generated_text"]
121
  except Exception as e:
122
  import traceback
123
  print(traceback.format_exc())
 
1
+ # app.py — EduPrompt API (per-task lazy load, cache-safe, no cache_dir in inference)
2
 
3
  import os
4
 
 
87
  input: str
88
  params: dict | None = None
89
 
90
+ def filter_model_kwargs(params):
91
+ # Remove keys not accepted by model.__call__()
92
+ forbidden = {"cache_dir"}
93
+ return {k: v for k, v in (params or {}).items() if k not in forbidden}
94
+
95
  @app.post("/run")
96
  async def run_task(data: InputData):
97
  start = time.time()
 
110
  print(traceback.format_exc())
111
  return {"error": f"model_load_failed: {type(e).__name__}: {str(e)}"}
112
 
113
+ # Filter out forbidden kwargs
114
+ params = filter_model_kwargs(data.params)
115
+
116
  try:
117
  if task == "summarize":
118
  prompt = f"You are an expert explainer. Summarize clearly and concisely:\n{text}"
119
+ output = model(prompt, max_length=120, min_length=30, truncation=True, do_sample=False, **params)[0]["summary_text"]
120
  elif task == "rewrite":
121
  prompt = f"You are a writing assistant. Rewrite this text for clarity and tone:\n{text}"
122
+ output = model(prompt, max_new_tokens=150, truncation=True, **params)[0]["generated_text"]
123
  elif task == "proofread":
124
  prompt = f"Correct and improve grammar and style:\n{text}"
125
+ output = model(prompt, max_new_tokens=150, truncation=True, **params)[0]["generated_text"]
126
  else: # explain_code
127
  prompt = f"Explain what this code does in simple language:\n{text}"
128
+ output = model(prompt, max_new_tokens=200, truncation=True, **params)[0]["generated_text"]
129
  except Exception as e:
130
  import traceback
131
  print(traceback.format_exc())