Slaiwala committed on
Commit
d940479
·
verified ·
1 Parent(s): 964071f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -6
app.py CHANGED
@@ -5,6 +5,7 @@ import os, re, json, time, sys, csv, uuid, datetime
5
  from typing import List, Dict, Any, Optional
6
  from functools import lru_cache
7
  from xml.etree import ElementTree as ET
 
8
 
9
  import numpy as np
10
  import requests
@@ -15,6 +16,7 @@ ASSETS_DIR = os.environ.get("ASSETS_DIR", "assets")
15
  FAISS_PATH = os.environ.get("FAISS_PATH", f"{ASSETS_DIR}/index.faiss")
16
  META_PATH = os.environ.get("META_PATH", f"{ASSETS_DIR}/index_meta.filtered.jsonl")
17
  REL_CONFIG_PATH = os.environ.get("REL_CONFIG_PATH", f"{ASSETS_DIR}/relevance_config.json")
 
18
 
19
  # Models
20
  BASE_MODEL = os.environ.get("BASE_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
@@ -189,17 +191,35 @@ if HF_READ_TOKEN:
189
  if ADAPTER_REPO:
190
  ADAPTER_PATH = snapshot_download(repo_id=ADAPTER_REPO, allow_patterns=["*"])
191
 
 
192
  dlog("LLM", f"Loading base model: {BASE_MODEL}")
193
  tokenizer_lm = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)
194
- base_model = AutoModelForCausalLM.from_pretrained(
195
- BASE_MODEL, torch_dtype=dtype, device_map="auto"
196
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
197
 
198
  dlog("LLM", f"Loading LoRA adapter from: {ADAPTER_PATH}")
199
- model_lm = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
200
- model_lm.to(device)
201
  model_lm.eval()
202
 
 
203
  GEN_ARGS_GROUNDED = dict(
204
  max_new_tokens=MAX_NEW_TOKENS_GROUNDED,
205
  do_sample=False,
@@ -901,5 +921,8 @@ with gr.Blocks(theme="soft") as demo:
901
  outputs=[fb_status, feedback_grp],
902
  )
903
 
904
- demo.queue(max_size=32).launch()
 
 
 
905
 
 
5
  from typing import List, Dict, Any, Optional
6
  from functools import lru_cache
7
  from xml.etree import ElementTree as ET
8
+ from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
9
 
10
  import numpy as np
11
  import requests
 
16
  FAISS_PATH = os.environ.get("FAISS_PATH", f"{ASSETS_DIR}/index.faiss")
17
  META_PATH = os.environ.get("META_PATH", f"{ASSETS_DIR}/index_meta.filtered.jsonl")
18
  REL_CONFIG_PATH = os.environ.get("REL_CONFIG_PATH", f"{ASSETS_DIR}/relevance_config.json")
19
+ QUANTIZE = os.environ.get("QUANTIZE", "4bit") # "none" | "8bit" | "4bit"
20
 
21
  # Models
22
  BASE_MODEL = os.environ.get("BASE_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
 
191
  if ADAPTER_REPO:
192
  ADAPTER_PATH = snapshot_download(repo_id=ADAPTER_REPO, allow_patterns=["*"])
193
 
194
+ # --- LLM load (quantized optional) ---
195
  dlog("LLM", f"Loading base model: {BASE_MODEL}")
196
  tokenizer_lm = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)
197
+
198
+ if QUANTIZE in {"8bit", "4bit"}:
199
+ bnb_config = BitsAndBytesConfig(
200
+ load_in_8bit=(QUANTIZE == "8bit"),
201
+ load_in_4bit=(QUANTIZE == "4bit"),
202
+ bnb_4bit_quant_type="nf4",
203
+ bnb_4bit_use_double_quant=True,
204
+ bnb_4bit_compute_dtype=torch.float16,
205
+ )
206
+ base_model = AutoModelForCausalLM.from_pretrained(
207
+ BASE_MODEL,
208
+ device_map="auto",
209
+ quantization_config=bnb_config,
210
+ )
211
+ else:
212
+ base_model = AutoModelForCausalLM.from_pretrained(
213
+ BASE_MODEL,
214
+ torch_dtype=dtype,
215
+ device_map="auto",
216
+ )
217
 
218
  dlog("LLM", f"Loading LoRA adapter from: {ADAPTER_PATH}")
219
+ model_lm = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
 
220
  model_lm.eval()
221
 
222
+
223
  GEN_ARGS_GROUNDED = dict(
224
  max_new_tokens=MAX_NEW_TOKENS_GROUNDED,
225
  do_sample=False,
 
921
  outputs=[fb_status, feedback_grp],
922
  )
923
 
924
+ demo.queue(
925
+ concurrency_count=int(os.environ.get("CONCURRENCY", "2")),
926
+ max_size=64
927
+ ).launch()
928