# nl2sql-bench / mini_server.py
# Uploaded by ritvik360 via huggingface_hub (commit a39d8ef).
# Minimal OpenAI-style /v1/chat/completions server for a local merged Qwen NL2SQL model.
import asyncio
import os
import threading
import time
import uuid
from typing import List

import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForCausalLM, AutoTokenizer
# CRITICAL: pin this process to physical GPU 7; inside the process that card is
# the only visible device and shows up as cuda:0.
# NOTE(review): CUDA reads CUDA_VISIBLE_DEVICES when it is first initialised, so
# this only works if nothing has touched CUDA yet. Placing it above `import torch`
# would be the safest spot — TODO confirm no import above initialises CUDA.
os.environ["CUDA_VISIBLE_DEVICES"] = "7"

app = FastAPI()

# Local directory holding the merged model checkpoint that this server hosts.
MODEL_PATH = "./qwen-7b-nl2sql-merged"

print("🚀 Loading Local Model for Inference API... (Takes a minute)")
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_PATH,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    # PyTorch's built-in scaled-dot-product attention; the original author chose
    # it for stability over a vLLM-based setup.
    attn_implementation="sdpa",
)
# Port must match the uvicorn.run(...) call in the __main__ guard (8001).
print("✅ Server Ready! Acting as OpenAI on Port 8001.")
# OpenAI request schemas: the subset of the Chat Completions request body that
# the /v1/chat/completions endpoint below actually reads.
class Message(BaseModel):
    # One chat turn. `role` is forwarded unchanged to the tokenizer's chat
    # template (typically "system" / "user" / "assistant"; not validated here).
    role: str
    # Plain-text body of the turn; this schema accepts only string content.
    content: str
class ChatRequest(BaseModel):
    # Echoed back in the response's "model" field only; it does not select the
    # weights — the server always serves the model loaded from MODEL_PATH.
    model: str
    # Conversation history, rendered through the model's chat template per request.
    messages: List[Message]
    # Sampling temperature; the endpoint uses greedy decoding when this is <= 0.
    temperature: float = 0.2
    # Upper bound on newly generated tokens (passed to generate as max_new_tokens).
    max_tokens: int = 512
# Serialises access to the shared tokenizer and the single GPU-resident model.
# Requests are offloaded to worker threads (see `chat`), so without this lock
# several generations could run at once. The original handler got the same
# serialisation implicitly by blocking the event loop.
_GENERATION_LOCK = threading.Lock()


def _complete(request: ChatRequest) -> dict:
    """Run one blocking chat completion and build an OpenAI-style response dict.

    Intended to run in a worker thread so the asyncio event loop stays free.
    """
    # Convert OpenAI messages to the plain dicts the chat template expects.
    messages = [{"role": m.role, "content": m.content} for m in request.messages]

    gen_kwargs = {
        "max_new_tokens": request.max_tokens,
        "pad_token_id": tokenizer.eos_token_id,
    }
    if request.temperature > 0:
        gen_kwargs["do_sample"] = True
        gen_kwargs["temperature"] = request.temperature
    else:
        # Greedy decoding: a temperature would be ignored and only make
        # transformers emit a "temperature is set but do_sample=False" warning.
        gen_kwargs["do_sample"] = False

    # torch grad mode is thread-local, so no_grad must be entered in this thread.
    with _GENERATION_LOCK, torch.no_grad():
        prompt = tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_tokens = int(inputs.input_ids.shape[1])

        outputs = model.generate(**inputs, **gen_kwargs)

        # Keep only the newly generated tokens (generate echoes the prompt).
        new_tokens = outputs[0][prompt_tokens:]
        completion_tokens = int(new_tokens.shape[0])
        response_text = tokenizer.decode(new_tokens, skip_special_tokens=True)

        # Cut off by max_tokens (rather than ending on EOS) => OpenAI "length".
        truncated = (
            completion_tokens > 0
            and completion_tokens >= request.max_tokens
            and int(new_tokens[-1]) != tokenizer.eos_token_id
        )

    # OpenAI Chat Completions response structure.
    return {
        "id": f"chatcmpl-local-{uuid.uuid4().hex}",
        "object": "chat.completion",
        "created": int(time.time()),
        "model": request.model,
        "choices": [{
            "index": 0,
            "message": {"role": "assistant", "content": response_text},
            "finish_reason": "length" if truncated else "stop",
        }],
        "usage": {
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
            "total_tokens": prompt_tokens + completion_tokens,
        },
    }


@app.post("/v1/chat/completions")
async def chat(request: ChatRequest):
    """OpenAI-compatible chat completions endpoint backed by the local model.

    Generation is blocking GPU work, so it runs in the default thread-pool
    executor instead of directly on the event loop.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, _complete, request)
if __name__ == "__main__":
    # Listen on every interface so other hosts/containers can reach the API.
    bind_host, bind_port = "0.0.0.0", 8001
    uvicorn.run(app, host=bind_host, port=bind_port)