# vivekutty / app.py
# Hugging Face Space application file (author: Vivek16, commit 3028027, "Update app.py").
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
# FastAPI application instance; the route decorators below register endpoints on it.
app = FastAPI()
# ------------------- Model Setup -------------------
# Everything in this section runs once at import time, so the model is fully
# loaded before the server starts serving requests.
MODEL_PATH = "vivekutty" # relative path to your uploaded model folder
MAX_SEQ_LENGTH = 2048  # NOTE(review): declared but never referenced in this file — confirm it is needed
LOAD_IN_4BIT = True # True if your model is 4-bit LoRA
# Configure 4-bit quantization if needed
bnb_config = BitsAndBytesConfig(
load_in_4bit=LOAD_IN_4BIT,
bnb_4bit_compute_dtype=torch.float16,  # compute in fp16 after dequantization
bnb_4bit_use_double_quant=True,  # also quantize the quantization constants
bnb_4bit_quant_type="nf4",  # NormalFloat4 quantization
)
# Load tokenizer and model from local folder.
# local_files_only=True forbids any Hub download attempt at startup.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, local_files_only=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_PATH,
quantization_config=bnb_config if LOAD_IN_4BIT else None,
device_map="auto",  # let accelerate place weights on the available device(s)
local_files_only=True
)
# ------------------- Inference Helper -------------------
def generate_response(instruction: str, input_text: str, max_new_tokens: int = 128) -> str:
    """Generate a completion for an instruction/input pair.

    Builds an Alpaca-style prompt, runs sampled generation, and returns
    only the newly generated text.

    Args:
        instruction: Task description for the model (may be empty).
        input_text: User-provided input for the instruction.
        max_new_tokens: Upper bound on the number of generated tokens.

    Returns:
        The decoded model response as a string.
    """
    chat_prompt = f"""### Instruction:
{instruction}
### Input:
{input_text}
### Response:
"""
    # Tokenize the prompt and move the tensors to the model's device.
    inputs = tokenizer(chat_prompt, return_tensors="pt").to(model.device)
    # Generate output; no_grad avoids building an autograd graph for inference.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=0.7,
        )
    # BUG FIX: generate() returns the prompt tokens followed by the new tokens.
    # Decode only the tokens past the prompt so the API response does not echo
    # the full "### Instruction / ### Input" prompt back to the caller.
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
    return response.strip()
# ------------------- API Input Schema -------------------
class ChatRequest(BaseModel):
    """Request body accepted by the /chat endpoint."""

    # Optional task description; defaults to an empty instruction.
    instruction: str = ""
    # Required user input passed to the model.
    input_text: str
# ------------------- API Endpoints -------------------
@app.post("/chat")
async def chat(req: ChatRequest):
response = generate_response(req.instruction, req.input_text)
return {"response": response}
@app.get("/")
async def root():
return {"message": "Model API is running!"}