from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import torch
import os
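# The commit message mentions an accompanying requirements.txt (not shown here).
# Based on the imports and features used below, it would plausibly need at least
# the following packages -- this list is inferred, not copied from the repo:
#   fastapi
#   uvicorn
#   torch
#   transformers
#   peft
#   bitsandbytes   # needed for 4-bit quantization via BitsAndBytesConfig
#   accelerate     # needed for device_map="auto"
#   pydantic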
app = FastAPI()
HF_TOKEN = os.getenv("HF_TOKEN")
print("HF_TOKEN loaded:", HF_TOKEN is not None)
BASE_MODEL = "google/gemma-2b-it"
LORA_MODEL = "varshithkumar/gemma-finetuned-sql"
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)
print("Loading base model with 4-bit quantization...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # load weights in 4-bit precision
    bnb_4bit_compute_dtype=torch.float16,  # run compute in float16
    bnb_4bit_use_double_quant=True,        # nested quantization, saves additional memory
)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",  # let accelerate place layers on the available device(s)
    token=HF_TOKEN,     # `use_auth_token` is deprecated in recent transformers
)
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL,
    use_fast=True,
    token=HF_TOKEN,
)
print("Applying LoRA adapter...")
model = PeftModel.from_pretrained(
    base_model,
    LORA_MODEL,
    token=HF_TOKEN,
    device_map="auto",  # keep the adapter on the same device(s) as the base model
)
model.eval()  # inference only; device_map="auto" already placed the weights, and .to() is unsupported for 4-bit bitsandbytes models
print("Model loaded successfully!")
class InputData(BaseModel):
    prompt: str
    max_new_tokens: int = 256  # default cap on newly generated tokens
@app.post("/generate")
def generate_text(data: InputData):
    inputs = tokenizer(data.prompt, return_tensors="pt").to(device)
    with torch.no_grad():  # inference only, skip gradient tracking
        outputs = model.generate(**inputs, max_new_tokens=data.max_new_tokens)
    # note: the decoded text includes the original prompt followed by the completion
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"response": text}