adityabalaji commited on
Commit
aecf872
·
verified ·
1 Parent(s): c026026

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -28
app.py CHANGED
@@ -1,65 +1,98 @@
1
- from fastapi import FastAPI, Request
 
 
2
  from pydantic import BaseModel
3
  from transformers import pipeline
4
  import time
 
 
 
 
 
 
5
 
 
6
  app = FastAPI(title="EduPrompt API")
7
 
8
  from fastapi.middleware.cors import CORSMiddleware
9
-
10
  app.add_middleware(
11
  CORSMiddleware,
12
- allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]
 
 
13
  )
14
 
15
  @app.get("/")
16
  def health():
17
  return {"ok": True, "service": "eduprompt-api"}
18
 
 
 
 
 
 
19
 
20
- # Models
21
- summarizer = pipeline("summarization", model="t5-small")
22
- rewriter = pipeline("text2text-generation", model="google/flan-t5-small")
23
- proofreader = pipeline("text2text-generation", model="google/flan-t5-small")
24
- code_explainer = pipeline("text2text-generation", model="Salesforce/codet5p-220m")
 
 
 
 
 
 
25
 
26
- # Request schema
27
  class InputData(BaseModel):
28
  task: str
29
  input: str
30
  params: dict | None = None
31
 
 
32
  @app.post("/run")
33
  async def run_task(data: InputData):
34
  start = time.time()
35
- text = data.input.strip()
36
- task = data.task.lower()
 
 
 
 
 
37
 
38
- # Build enhanced prompt (simple version)
39
  enhanced = ""
40
- if task == "summarize":
41
- enhanced = f"You are an expert explainer. Summarize clearly and concisely:\n{text}"
42
- result = summarizer(enhanced, max_length=120, min_length=30, do_sample=False)[0]["summary_text"]
 
 
 
 
43
 
44
- elif task == "rewrite":
45
- enhanced = f"You are a writing assistant. Rewrite this text for clarity and tone:\n{text}"
46
- result = rewriter(enhanced, max_new_tokens=150)[0]["generated_text"]
 
47
 
48
- elif task == "proofread":
49
- enhanced = f"You are a grammar and style editor. Correct and improve this text:\n{text}"
50
- result = proofreader(enhanced, max_new_tokens=150)[0]["generated_text"]
 
51
 
52
- elif task == "explain_code":
53
- enhanced = f"You are a programming tutor. Explain what this code does in simple language:\n{text}"
54
- result = code_explainer(enhanced, max_new_tokens=200)[0]["generated_text"]
 
55
 
56
- else:
57
- return {"error": "Unsupported task"}
58
 
59
  latency = round((time.time() - start) * 1000, 2)
60
  return {
61
  "enhancedPrompt": enhanced,
62
- "output": result,
63
- "model": task,
64
  "latencyMs": latency
65
  }
 
1
+ # app.py EduPrompt FastAPI backend
2
+
3
+ from fastapi import FastAPI
4
  from pydantic import BaseModel
5
  from transformers import pipeline
6
  import time
7
+ import os
8
+
9
+ # --- Hugging Face cache fix for Spaces (permission-safe) ---
10
+ os.environ["HF_HOME"] = "/tmp"
11
+ os.environ["HF_HUB_CACHE"] = "/tmp"
12
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp"
13
 
14
+ # --- FastAPI app + CORS ---
15
  app = FastAPI(title="EduPrompt API")
16
 
17
  from fastapi.middleware.cors import CORSMiddleware
 
18
  app.add_middleware(
19
  CORSMiddleware,
20
+ allow_origins=["*"],
21
+ allow_methods=["*"],
22
+ allow_headers=["*"],
23
  )
24
 
25
  @app.get("/")
26
  def health():
27
  return {"ok": True, "service": "eduprompt-api"}
28
 
29
+ # --- Lazy model loaders ---
30
+ _summarizer = None
31
+ _rewriter = None
32
+ _proofreader = None
33
+ _code_explainer = None
34
 
35
+ def get_models():
36
+ global _summarizer, _rewriter, _proofreader, _code_explainer
37
+ if _summarizer is None:
38
+ _summarizer = pipeline("summarization", model="t5-small")
39
+ if _rewriter is None:
40
+ _rewriter = pipeline("text2text-generation", model="google/flan-t5-small")
41
+ if _proofreader is None:
42
+ _proofreader = pipeline("text2text-generation", model="google/flan-t5-small")
43
+ if _code_explainer is None:
44
+ _code_explainer = pipeline("text2text-generation", model="Salesforce/codet5p-220m")
45
+ return _summarizer, _rewriter, _proofreader, _code_explainer
46
 
47
+ # --- Request schema ---
48
  class InputData(BaseModel):
49
  task: str
50
  input: str
51
  params: dict | None = None
52
 
53
+ # --- Core endpoint ---
54
  @app.post("/run")
55
  async def run_task(data: InputData):
56
  start = time.time()
57
+ text = (data.input or "").strip()
58
+ task = (data.task or "").strip().lower()
59
+
60
+ if not text:
61
+ return {"error": "Empty input text."}
62
+ if task not in {"summarize", "rewrite", "proofread", "explain_code"}:
63
+ return {"error": f"Unsupported task '{task}'."}
64
 
 
65
  enhanced = ""
66
+ try:
67
+ summarizer, rewriter, proofreader, code_explainer = get_models()
68
+
69
+ if task == "summarize":
70
+ enhanced = f"You are an expert explainer. Summarize clearly and concisely:\n{text}"
71
+ out = summarizer(enhanced, max_length=120, min_length=30, do_sample=False, truncation=True)[0]["summary_text"]
72
+ model_id = "t5-small"
73
 
74
+ elif task == "rewrite":
75
+ enhanced = f"You are a writing assistant. Rewrite this text for clarity and tone:\n{text}"
76
+ out = rewriter(enhanced, max_new_tokens=150, truncation=True)[0]["generated_text"]
77
+ model_id = "google/flan-t5-small"
78
 
79
+ elif task == "proofread":
80
+ enhanced = f"You are a grammar and style editor. Correct and improve this text:\n{text}"
81
+ out = proofreader(enhanced, max_new_tokens=150, truncation=True)[0]["generated_text"]
82
+ model_id = "google/flan-t5-small"
83
 
84
+ else:
85
+ enhanced = f"You are a programming tutor. Explain what this code does in simple language:\n{text}"
86
+ out = code_explainer(enhanced, max_new_tokens=200, truncation=True)[0]["generated_text"]
87
+ model_id = "Salesforce/codet5p-220m"
88
 
89
+ except Exception as e:
90
+ return {"error": f"inference_failed: {type(e).__name__}: {str(e)}"}
91
 
92
  latency = round((time.time() - start) * 1000, 2)
93
  return {
94
  "enhancedPrompt": enhanced,
95
+ "output": out,
96
+ "model": model_id,
97
  "latencyMs": latency
98
  }