varshithkumar commited on
Commit
19c6b1f
·
1 Parent(s): c616e72

Added app.py and requirements.txt

Browse files
Files changed (1) hide show
  1. app.py +11 -9
app.py CHANGED
@@ -1,7 +1,6 @@
1
- from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from transformers import AutoTokenizer, AutoModelForCausalLM
4
- from transformers import BitsAndBytesConfig
5
  from peft import PeftModel
6
  import torch
7
  import os
@@ -17,11 +16,12 @@ LORA_MODEL = "varshithkumar/gemma-finetuned-sql"
17
  device = "cuda" if torch.cuda.is_available() else "cpu"
18
  print("Using device:", device)
19
 
20
- print("Loading base model...")
21
 
22
  bnb_config = BitsAndBytesConfig(
23
- load_in_8bit=True,
24
- llm_int8_enable_fp32_cpu_offload=True
 
25
  )
26
 
27
  base_model = AutoModelForCausalLM.from_pretrained(
@@ -31,25 +31,27 @@ base_model = AutoModelForCausalLM.from_pretrained(
31
  use_auth_token=HF_TOKEN
32
  )
33
 
34
-
35
  print("Loading tokenizer...")
36
  tokenizer = AutoTokenizer.from_pretrained(
37
  BASE_MODEL,
38
  use_fast=True,
39
- token=HF_TOKEN
40
  )
41
 
42
  print("Applying LoRA adapter...")
43
  model = PeftModel.from_pretrained(
44
  base_model,
45
  LORA_MODEL,
46
- token=HF_TOKEN
 
47
  )
48
 
 
49
  print("Model loaded successfully!")
50
 
51
  class InputData(BaseModel):
52
  prompt: str
 
53
 
54
  @app.post("/generate")
55
  def generate_text(data: InputData):
 
1
+ from fastapi import FastAPI
2
  from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 
4
  from peft import PeftModel
5
  import torch
6
  import os
 
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
  print("Using device:", device)
18
 
19
+ print("Loading base model with 4-bit quantization...")
20
 
21
  bnb_config = BitsAndBytesConfig(
22
+ load_in_4bit=True, # Use 4-bit
23
+ bnb_4bit_compute_dtype=torch.float16, # Compute in float16
24
+ bnb_4bit_use_double_quant=True # Optional, better accuracy
25
  )
26
 
27
  base_model = AutoModelForCausalLM.from_pretrained(
 
31
  use_auth_token=HF_TOKEN
32
  )
33
 
 
34
  print("Loading tokenizer...")
35
  tokenizer = AutoTokenizer.from_pretrained(
36
  BASE_MODEL,
37
  use_fast=True,
38
+ use_auth_token=HF_TOKEN
39
  )
40
 
41
  print("Applying LoRA adapter...")
42
  model = PeftModel.from_pretrained(
43
  base_model,
44
  LORA_MODEL,
45
+ use_auth_token=HF_TOKEN,
46
+ device_map="auto" # ensure LoRA is loaded on the right device
47
  )
48
 
49
+ model.to(device)
50
  print("Model loaded successfully!")
51
 
52
  class InputData(BaseModel):
53
  prompt: str
54
+ max_length: int = 256 # default max length if not provided
55
 
56
  @app.post("/generate")
57
  def generate_text(data: InputData):