vijkid001 commited on
Commit
36a3527
·
verified ·
1 Parent(s): 7646e52

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -0
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers'
3
+
4
+ from fastapi import FastAPI, Header, HTTPException
5
+ from fastapi.middleware.cors import CORSMiddleware
6
+ from pydantic import BaseModel, Field
7
+ from transformers import AutoTokenizer, AutoModelForCausalLM
8
+ from typing import Optional, Dict, Annotated
9
+ import logging
10
+ import torch
11
+ import os
12
+
13
+ # Initialize logging
14
+ logging.basicConfig(level=logging.INFO)
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # Load model
18
+ MODEL_NAME = "defog/sqlcoder-7b-2"
19
+ logger.info(f"Loading model: {MODEL_NAME}")
20
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
21
+ model = AutoModelForCausalLM.from_pretrained(
22
+ MODEL_NAME,
23
+ device_map="auto",
24
+ torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
25
+ )
26
+
27
+ # FastAPI init
28
+ app = FastAPI(title="Text to SQL API")
29
+
30
+ # CORS for Hugging Face Space
31
+ app.add_middleware(
32
+ CORSMiddleware,
33
+ allow_origins=["*"],
34
+ allow_methods=["*"],
35
+ allow_headers=["*"],
36
+ )
37
+
38
+ # Request Model
39
+ class RequestModel(BaseModel):
40
+ entity_urn: str
41
+ prompt: str
42
+
43
+ # Response Model
44
+ class ResponseModel(BaseModel):
45
+ message: str
46
+ result: str
47
+ action_type: str
48
+ entity_urn: str
49
+ metadata: Optional[Dict] = None
50
+
51
+ @app.get("/")
52
+ async def root():
53
+ return {
54
+ "message": "Text-to-SQL API running",
55
+ "docs": "/docs",
56
+ "health": "/health"
57
+ }
58
+
59
+ @app.get("/health")
60
+ async def health():
61
+ return {"status": "healthy"}
62
+
63
+ @app.post("/generate", response_model=ResponseModel)
64
+ async def generate_sql(
65
+ request: RequestModel,
66
+ x_api_key: Annotated[str, Header()] # Optional token check
67
+ ):
68
+ try:
69
+ if not request.prompt.strip():
70
+ return ResponseModel(
71
+ message="failure",
72
+ result="Empty prompt",
73
+ action_type="text_to_sql",
74
+ entity_urn=request.entity_urn
75
+ )
76
+
77
+ inputs = tokenizer(request.prompt, return_tensors="pt").to(model.device)
78
+ outputs = model.generate(**inputs, max_length=512)
79
+ sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
80
+
81
+ return ResponseModel(
82
+ message="success",
83
+ result=sql.strip(),
84
+ action_type="text_to_sql",
85
+ entity_urn=request.entity_urn,
86
+ metadata={"tokens": len(inputs["input_ids"][0])}
87
+ )
88
+
89
+ except Exception as e:
90
+ logger.error(f"Error: {str(e)}")
91
+ return ResponseModel(
92
+ message="failure",
93
+ result=f"Error: {str(e)}",
94
+ action_type="text_to_sql",
95
+ entity_urn=request.entity_urn
96
+ )