16pramodh committed on
Commit
33abc82
·
1 Parent(s): ebf3750

fixing cache issue

Browse files
Files changed (1) hide show
  1. app.py +6 -11
app.py CHANGED
@@ -1,12 +1,11 @@
1
  import os
 
 
2
  from fastapi import FastAPI
3
  from pydantic import BaseModel
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
5
  import uvicorn
6
 
7
- # Set writable cache directory
8
- os.environ["TRANSFORMERS_CACHE"] = "/tmp/cache"
9
-
10
  MODEL_NAME = "16pramodh/t2s_model"
11
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
12
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
@@ -18,13 +17,9 @@ class QueryRequest(BaseModel):
18
 
19
  @app.post("/predict")
20
  def predict(request: QueryRequest):
21
- try:
22
- inputs = tokenizer(request.text, return_tensors="pt")
23
- outputs = model.generate(**inputs, max_length=256)
24
- sql_query = tokenizer.decode(outputs[0], skip_special_tokens=True)
25
- return {"sql": sql_query}
26
- except Exception as e:
27
- return {"error": str(e)}
28
 
29
  if __name__ == "__main__":
30
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
  import os
2
+ os.environ["TRANSFORMERS_CACHE"] = "/tmp/cache" # MUST be before HF imports
3
+
4
  from fastapi import FastAPI
5
  from pydantic import BaseModel
6
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
7
  import uvicorn
8
 
 
 
 
9
  MODEL_NAME = "16pramodh/t2s_model"
10
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
11
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
 
17
 
18
@app.post("/predict")
def predict(request: QueryRequest):
    """Translate the request text into a SQL query with the seq2seq model.

    Returns a JSON object of the form {"sql": "<generated query>"}.
    """
    # Tokenize, generate (capped at 256 tokens), and decode in one pass.
    encoded = tokenizer(request.text, return_tensors="pt")
    generated = model.generate(**encoded, max_length=256)
    sql_text = tokenizer.decode(generated[0], skip_special_tokens=True)
    return {"sql": sql_text}
 
 
 
 
if __name__ == "__main__":
    # Serve on all interfaces at port 7860 (the Hugging Face Spaces default).
    uvicorn.run(app, host="0.0.0.0", port=7860)