Update app.py

app.py CHANGED
@@ -5,6 +5,7 @@ import os, re, json, time, sys, csv, uuid, datetime
 from typing import List, Dict, Any, Optional
 from functools import lru_cache
 from xml.etree import ElementTree as ET
+from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
 
 import numpy as np
 import requests
@@ -15,6 +16,7 @@ ASSETS_DIR = os.environ.get("ASSETS_DIR", "assets")
 FAISS_PATH = os.environ.get("FAISS_PATH", f"{ASSETS_DIR}/index.faiss")
 META_PATH = os.environ.get("META_PATH", f"{ASSETS_DIR}/index_meta.filtered.jsonl")
 REL_CONFIG_PATH = os.environ.get("REL_CONFIG_PATH", f"{ASSETS_DIR}/relevance_config.json")
+QUANTIZE = os.environ.get("QUANTIZE", "4bit")  # "none" | "8bit" | "4bit"
 
 # Models
 BASE_MODEL = os.environ.get("BASE_MODEL", "mistralai/Mistral-7B-Instruct-v0.2")
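The new QUANTIZE switch drives the model-loading branch in the next hunk. Note that any unrecognized value ("4-bit", "int4", ...) silently falls through to the full-precision branch; a hypothetical guard, not part of this commit, could fail fast instead:

import os

# Hypothetical hardening (not in the commit): reject unknown QUANTIZE values
# instead of silently loading the model unquantized.
QUANTIZE = os.environ.get("QUANTIZE", "4bit")
if QUANTIZE not in {"none", "8bit", "4bit"}:
    raise ValueError(f"QUANTIZE must be 'none', '8bit' or '4bit', got {QUANTIZE!r}")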
@@ -189,17 +191,35 @@ if HF_READ_TOKEN:
 if ADAPTER_REPO:
     ADAPTER_PATH = snapshot_download(repo_id=ADAPTER_REPO, allow_patterns=["*"])
 
+# --- LLM load (quantized optional) ---
 dlog("LLM", f"Loading base model: {BASE_MODEL}")
 tokenizer_lm = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)
-
-
-
+
+if QUANTIZE in {"8bit", "4bit"}:
+    bnb_config = BitsAndBytesConfig(
+        load_in_8bit=(QUANTIZE == "8bit"),
+        load_in_4bit=(QUANTIZE == "4bit"),
+        bnb_4bit_quant_type="nf4",
+        bnb_4bit_use_double_quant=True,
+        bnb_4bit_compute_dtype=torch.float16,
+    )
+    base_model = AutoModelForCausalLM.from_pretrained(
+        BASE_MODEL,
+        device_map="auto",
+        quantization_config=bnb_config,
+    )
+else:
+    base_model = AutoModelForCausalLM.from_pretrained(
+        BASE_MODEL,
+        torch_dtype=dtype,
+        device_map="auto",
+    )
 
 dlog("LLM", f"Loading LoRA adapter from: {ADAPTER_PATH}")
-model_lm
-model_lm.to(device)
+model_lm = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
 model_lm.eval()
 
+
 GEN_ARGS_GROUNDED = dict(
     max_new_tokens=MAX_NEW_TOKENS_GROUNDED,
     do_sample=False,
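For reference, here is the load path this hunk adds, condensed into a standalone sketch. It assumes the same names as the diff (BASE_MODEL, ADAPTER_PATH, QUANTIZE), substitutes torch.float16 for the `dtype` variable defined elsewhere in app.py, and requires a CUDA device plus the bitsandbytes package for the quantized branches:

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import PeftModel

def load_lm(base_model: str, adapter_path: str, quantize: str = "4bit"):
    # Load the base LM, optionally quantized, then attach the LoRA adapter.
    kwargs = {"device_map": "auto"}
    if quantize in {"8bit", "4bit"}:
        kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_8bit=(quantize == "8bit"),
            load_in_4bit=(quantize == "4bit"),
            bnb_4bit_quant_type="nf4",        # NormalFloat4 weight format
            bnb_4bit_use_double_quant=True,   # also quantize the quantization constants
            bnb_4bit_compute_dtype=torch.float16,
        )
    else:
        kwargs["torch_dtype"] = torch.float16  # the diff uses a `dtype` defined elsewhere
    base = AutoModelForCausalLM.from_pretrained(base_model, **kwargs)
    model = PeftModel.from_pretrained(base, adapter_path)  # wrap with the LoRA weights
    model.eval()  # inference only; disables dropout
    return model

With device_map="auto", accelerate handles weight placement, which is why the diff drops the old explicit model_lm.to(device) call; quantized bitsandbytes models should not be moved with .to() anyway.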
@@ -901,5 +921,8 @@ with gr.Blocks(theme="soft") as demo:
         outputs=[fb_status, feedback_grp],
     )
 
-demo.queue(
+demo.queue(
+    concurrency_count=int(os.environ.get("CONCURRENCY", "2")),
+    max_size=64
+).launch()
 
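One compatibility note on the queue settings: concurrency_count is the Gradio 3.x parameter name; Gradio 4 removed it in favor of default_concurrency_limit. If the Space is later bumped to Gradio 4, the equivalent call would look roughly like this (a sketch, assuming gradio>=4.0):

demo.queue(
    default_concurrency_limit=int(os.environ.get("CONCURRENCY", "2")),
    max_size=64
).launch()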