Spaces:
Runtime error
Runtime error
File size: 1,688 Bytes
19c6b1f fa90325 19c6b1f 63d58a1 91dc642 fa90325 8a85f65 b4a183a 63d58a1 8a85f65 b4a183a 8a85f65 19c6b1f c616e72 19c6b1f c616e72 fa90325 c616e72 92e762c 8a85f65 fa90325 19c6b1f fa90325 19c6b1f fa90325 19c6b1f fa90325 63d58a1 fa90325 19c6b1f 63d58a1 fa90325 8a85f65 fa90325 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import torch
import os
# FastAPI service exposing a single text-generation endpoint (/generate)
# backed by a LoRA-finetuned Gemma model.
app = FastAPI()
# Hugging Face access token for gated model downloads; read from the
# environment so it is never hard-coded. May be None for public models.
HF_TOKEN = os.getenv("HF_TOKEN")
print("HF_TOKEN loaded:", HF_TOKEN is not None)
# Base model and the LoRA adapter that was fine-tuned on top of it.
BASE_MODEL = "google/gemma-2b-it"
LORA_MODEL = "varshithkumar/gemma-finetuned-sql"
# Prefer GPU when available; CPU generation is possible but very slow.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)
print("Loading base model with 4-bit quantization...")
# 4-bit quantization keeps the 2B-parameter base model small enough for
# modest GPUs, while float16 compute preserves generation quality.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # quantize weights to 4-bit
    bnb_4bit_compute_dtype=torch.float16,  # run matmuls in float16
    bnb_4bit_use_double_quant=True,        # double quantization: slightly better accuracy
)
base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",  # let accelerate place layers on the available device(s)
    # `use_auth_token` is deprecated and removed in recent transformers
    # releases; `token` is the supported parameter name.
    token=HF_TOKEN,
)
print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL,
    use_fast=True,  # Rust-backed tokenizer: much faster encode/decode
    # `use_auth_token` is deprecated and removed in recent transformers
    # releases; `token` is the supported parameter name.
    token=HF_TOKEN,
)
print("Applying LoRA adapter...")
model = PeftModel.from_pretrained(
    base_model,
    LORA_MODEL,
    # `use_auth_token` is deprecated in recent transformers/huggingface_hub;
    # `token` is the supported parameter name.
    token=HF_TOKEN,
    device_map="auto",  # keep adapter weights alongside the base model's layers
)
# NOTE: do NOT call model.to(device) here. bitsandbytes raises a ValueError
# when .to() is used on a 4-bit/8-bit quantized model, and device_map="auto"
# has already placed every layer. Just switch to eval mode for inference
# (disables dropout etc.).
model.eval()
print("Model loaded successfully!")
class InputData(BaseModel):
    """Request body for POST /generate."""
    prompt: str  # user prompt, fed verbatim to the tokenizer
    max_length: int = 256  # cap on total generated sequence length (prompt included)
@app.post("/generate")
def generate_text(data: InputData):
    """Generate text from ``data.prompt``.

    The sequence is capped at ``data.max_length`` total tokens (prompt
    included, per ``generate``'s ``max_length`` semantics). Returns
    ``{"response": <decoded text>}``; the decoded text includes the prompt
    because the whole output sequence is decoded.
    """
    inputs = tokenizer(data.prompt, return_tensors="pt").to(device)
    # inference_mode skips autograd bookkeeping entirely, reducing memory
    # use and latency during generation.
    with torch.inference_mode():
        outputs = model.generate(**inputs, max_length=data.max_length)
    text = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return {"response": text}
|