from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import torch
import os

app = FastAPI()

HF_TOKEN = os.getenv("HF_TOKEN")
print("HF_TOKEN loaded:", HF_TOKEN is not None)

BASE_MODEL = "google/gemma-2b-it"
LORA_MODEL = "varshithkumar/gemma-finetuned-sql"

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

print("Loading base model with 4-bit quantization...")
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                      # quantize weights to 4-bit
    bnb_4bit_compute_dtype=torch.float16,   # run matmuls in float16
    bnb_4bit_use_double_quant=True          # second quantization pass, slightly better memory/accuracy
)

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",
    token=HF_TOKEN,   # `use_auth_token` is deprecated in recent transformers; `token` is the current name
)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL,
    use_fast=True,
    token=HF_TOKEN,
)

print("Applying LoRA adapter...")
model = PeftModel.from_pretrained(
    base_model,
    LORA_MODEL,
    token=HF_TOKEN,
    device_map="auto",   # keep the adapter on the same devices as the base model
)
# Note: do NOT call model.to(device) here. bitsandbytes 4-bit models raise an
# error on .to(), and device_map="auto" has already placed the weights.
model.eval()
print("Model loaded successfully!")


class InputData(BaseModel):
    prompt: str
    max_length: int = 256   # default cap on total tokens (prompt + generation) if not provided


@app.post("/generate")
def generate_text(data: InputData):
    # model.device points at the first device chosen by device_map="auto"
    inputs = tokenizer(data.prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        # max_length counts the prompt too; use max_new_tokens instead if you
        # want to bound only the generated continuation.
        outputs = model.generate(**inputs, max_length=data.max_length)
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"response": text}
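
# --- Usage sketch ---
# Assumptions (not part of the original script): the file is saved as main.py,
# uvicorn is installed, and port 8000 is free.
#
# Start the server:
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Then query the /generate endpoint, e.g. with curl:
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Write a SQL query that counts rows in users.", "max_length": 256}'
#
# The response has the shape {"response": "..."} as returned by generate_text.

if __name__ == "__main__":
    # Convenience runner so `python main.py` also works (assumes uvicorn installed).
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)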