File size: 1,688 Bytes
19c6b1f
fa90325
19c6b1f
63d58a1
 
 
91dc642
fa90325
 
8a85f65
b4a183a
 
63d58a1
 
 
8a85f65
b4a183a
8a85f65
19c6b1f
c616e72
 
19c6b1f
 
 
c616e72
 
fa90325
 
c616e72
92e762c
 
 
8a85f65
 
fa90325
 
 
19c6b1f
fa90325
 
 
 
 
 
19c6b1f
 
fa90325
 
19c6b1f
fa90325
63d58a1
fa90325
 
19c6b1f
63d58a1
fa90325
 
8a85f65
fa90325
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from fastapi import FastAPI 
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel
import torch
import os

app = FastAPI()

# Hugging Face access token for gated repos (Gemma requires accepting a license).
# May legitimately be None for fully public models.
HF_TOKEN = os.getenv("HF_TOKEN")
print("HF_TOKEN loaded:", HF_TOKEN is not None)

BASE_MODEL = "google/gemma-2b-it"
LORA_MODEL = "varshithkumar/gemma-finetuned-sql"

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

print("Loading base model with 4-bit quantization...")

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,                     # quantize weights to 4-bit
    bnb_4bit_compute_dtype=torch.float16,  # run matmuls in float16
    bnb_4bit_use_double_quant=True,        # also quantize the quant constants (saves memory)
)

base_model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    quantization_config=bnb_config,
    device_map="auto",  # let accelerate place layers on the available device(s)
    token=HF_TOKEN,     # `use_auth_token` is deprecated in recent transformers; use `token`
)

print("Loading tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    BASE_MODEL,
    use_fast=True,
    token=HF_TOKEN,
)

print("Applying LoRA adapter...")
model = PeftModel.from_pretrained(
    base_model,
    LORA_MODEL,
    token=HF_TOKEN,
    device_map="auto",  # keep the adapter on the same device(s) as the base model
)

# NOTE: do NOT call model.to(device) here. The 4-bit bitsandbytes model was already
# placed by device_map="auto", and calling .to() on a quantized model raises
# ValueError in transformers.
print("Model loaded successfully!")

class InputData(BaseModel):
    """Request body schema for the /generate endpoint."""

    # Raw text prompt fed directly to the tokenizer/model.
    prompt: str
    # Cap passed as `max_length` to model.generate(); per transformers semantics
    # this bounds prompt + generated tokens combined.
    max_length: int = 256  # default max length if not provided

@app.post("/generate")
def generate_text(data: InputData):
    """Run the LoRA-adapted model on the request prompt and return the decoded text."""
    encoded = tokenizer(data.prompt, return_tensors="pt").to(device)
    generated_ids = model.generate(**encoded, max_length=data.max_length)
    decoded = tokenizer.decode(generated_ids[0], skip_special_tokens=True)
    return {"response": decoded}