NotRev committed on
Commit
9fbf203
·
verified ·
1 Parent(s): 2f6e556

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +24 -4
src/streamlit_app.py CHANGED
@@ -1,9 +1,25 @@
1
  import json, re, ast, streamlit as st
2
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
 
3
 
4
  model_id = "mistralai/Mistral-7B-Instruct-v0.3"
 
 
 
 
 
 
 
 
5
  tok = AutoTokenizer.from_pretrained(model_id)
6
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype="auto", device_map="auto")
 
 
 
 
 
 
 
7
  gen = pipeline("text-generation", model=model, tokenizer=tok,
8
  max_new_tokens=256, do_sample=False, return_full_text=False)
9
 
@@ -15,13 +31,16 @@ JSON:"""
15
  def extract(text: str):
16
  out = gen(prompt.format(text=text))
17
  raw = out[0].get("generated_text") or out[0].get("text") or str(out[0])
 
18
  m = re.search(r"(\{[\s\S]*\})", raw)
19
  data = {}
 
20
  if m:
21
  blob = m.group(0).strip()
22
  for parser in (json.loads, ast.literal_eval):
23
  try:
24
  parsed_data = parser(blob)
 
25
  if isinstance(parsed_data, list) and parsed_data:
26
  data = parsed_data[0]
27
  elif isinstance(parsed_data, dict):
@@ -30,6 +49,7 @@ def extract(text: str):
30
  except Exception:
31
  continue
32
 
 
33
  if not isinstance(data, dict):
34
  return {
35
  "SKILL": ["(Error: Invalid/Corrupted Model Output)"],
@@ -37,6 +57,7 @@ def extract(text: str):
37
  "DEBUG_RAW_OUTPUT": raw
38
  }
39
 
 
40
  return {
41
  "SKILL": data.get("SKILL", []),
42
  "KNOWLEDGE": data.get("KNOWLEDGE", [])
@@ -46,5 +67,4 @@ st.title("Skill/Knowledge Extractor")
46
  text = st.text_area("Paste text")
47
 
48
  if st.button("Extract") and text.strip():
49
- st.json(extract(text))
50
-
 
import json
import re
import ast

import streamlit as st
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig

# Hugging Face Hub checkpoint to load.
model_id = "mistralai/Mistral-7B-Instruct-v0.3"

# 4-bit Quantization Configuration to reduce memory usage (VRAM/RAM)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)


@st.cache_resource(show_spinner="Loading model...")
def _load_generator(mid: str):
    """Load the tokenizer, the 4-bit quantized model, and a deterministic
    text-generation pipeline exactly once per process.

    Streamlit re-executes the entire script on every widget interaction;
    without st.cache_resource the 7B model would be re-downloaded/reloaded
    on each rerun.
    """
    tokenizer = AutoTokenizer.from_pretrained(mid)
    # Load the model with the 4-bit quantization configuration;
    # device_map="auto" lets accelerate place weights on available devices.
    quantized_model = AutoModelForCausalLM.from_pretrained(
        mid,
        quantization_config=bnb_config,
        device_map="auto",
    )
    generator = pipeline(
        "text-generation",
        model=quantized_model,
        tokenizer=tokenizer,
        max_new_tokens=256,
        do_sample=False,          # greedy decoding -> deterministic output
        return_full_text=False,   # return only the completion, not the prompt
    )
    return tokenizer, quantized_model, generator


# Preserve the original module-level names so the rest of the script
# (extract() uses `gen`) keeps working unchanged.
tok, model, gen = _load_generator(model_id)
 
 
31
  def extract(text: str):
32
  out = gen(prompt.format(text=text))
33
  raw = out[0].get("generated_text") or out[0].get("text") or str(out[0])
34
+ # Relaxed regex to find JSON object anywhere in the output
35
  m = re.search(r"(\{[\s\S]*\})", raw)
36
  data = {}
37
+
38
  if m:
39
  blob = m.group(0).strip()
40
  for parser in (json.loads, ast.literal_eval):
41
  try:
42
  parsed_data = parser(blob)
43
+ # Handle common case where model returns a list of dictionaries
44
  if isinstance(parsed_data, list) and parsed_data:
45
  data = parsed_data[0]
46
  elif isinstance(parsed_data, dict):
 
49
  except Exception:
50
  continue
51
 
52
+ # Error handling check: If parsing failed completely, return a structured error dictionary
53
  if not isinstance(data, dict):
54
  return {
55
  "SKILL": ["(Error: Invalid/Corrupted Model Output)"],
 
57
  "DEBUG_RAW_OUTPUT": raw
58
  }
59
 
60
+ # Successful return: Uses .get() which prevents KeyError even if keys are missing
61
  return {
62
  "SKILL": data.get("SKILL", []),
63
  "KNOWLEDGE": data.get("KNOWLEDGE", [])
 
# Minimal UI: a text box plus an explicit trigger button.
text = st.text_area("Paste text")
clicked = st.button("Extract")

# Ignore clicks while the box is empty or whitespace-only.
if clicked and text.strip():
    st.json(extract(text))