vijkid001 committed on
Commit
13c7baf
·
verified ·
1 Parent(s): 36a3527

Delete app.py

Browse files
Files changed (1) hide show
  1. app.py +0 -96
app.py DELETED
@@ -1,96 +0,0 @@
1
- import os
2
- os.environ['TRANSFORMERS_CACHE'] = '/tmp/transformers'
3
-
4
- from fastapi import FastAPI, Header, HTTPException
5
- from fastapi.middleware.cors import CORSMiddleware
6
- from pydantic import BaseModel, Field
7
- from transformers import AutoTokenizer, AutoModelForCausalLM
8
- from typing import Optional, Dict, Annotated
9
- import logging
10
- import torch
11
- import os
12
-
13
- # Initialize logging
14
- logging.basicConfig(level=logging.INFO)
15
- logger = logging.getLogger(__name__)
16
-
17
- # Load model
18
- MODEL_NAME = "defog/sqlcoder-7b-2"
19
- logger.info(f"Loading model: {MODEL_NAME}")
20
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
21
- model = AutoModelForCausalLM.from_pretrained(
22
- MODEL_NAME,
23
- device_map="auto",
24
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
25
- )
26
-
27
- # FastAPI init
28
- app = FastAPI(title="Text to SQL API")
29
-
30
- # CORS for Hugging Face Space
31
- app.add_middleware(
32
- CORSMiddleware,
33
- allow_origins=["*"],
34
- allow_methods=["*"],
35
- allow_headers=["*"],
36
- )
37
-
38
- # Request Model
39
- class RequestModel(BaseModel):
40
- entity_urn: str
41
- prompt: str
42
-
43
- # Response Model
44
- class ResponseModel(BaseModel):
45
- message: str
46
- result: str
47
- action_type: str
48
- entity_urn: str
49
- metadata: Optional[Dict] = None
50
-
51
- @app.get("/")
52
- async def root():
53
- return {
54
- "message": "Text-to-SQL API running",
55
- "docs": "/docs",
56
- "health": "/health"
57
- }
58
-
59
- @app.get("/health")
60
- async def health():
61
- return {"status": "healthy"}
62
-
63
- @app.post("/generate", response_model=ResponseModel)
64
- async def generate_sql(
65
- request: RequestModel,
66
- x_api_key: Annotated[str, Header()] # Optional token check
67
- ):
68
- try:
69
- if not request.prompt.strip():
70
- return ResponseModel(
71
- message="failure",
72
- result="Empty prompt",
73
- action_type="text_to_sql",
74
- entity_urn=request.entity_urn
75
- )
76
-
77
- inputs = tokenizer(request.prompt, return_tensors="pt").to(model.device)
78
- outputs = model.generate(**inputs, max_length=512)
79
- sql = tokenizer.decode(outputs[0], skip_special_tokens=True)
80
-
81
- return ResponseModel(
82
- message="success",
83
- result=sql.strip(),
84
- action_type="text_to_sql",
85
- entity_urn=request.entity_urn,
86
- metadata={"tokens": len(inputs["input_ids"][0])}
87
- )
88
-
89
- except Exception as e:
90
- logger.error(f"Error: {str(e)}")
91
- return ResponseModel(
92
- message="failure",
93
- result=f"Error: {str(e)}",
94
- action_type="text_to_sql",
95
- entity_urn=request.entity_urn
96
- )