NotRev committed on
Commit
0d6e70e
·
verified ·
1 Parent(s): 3c59090

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +20 -22
src/streamlit_app.py CHANGED
@@ -1,24 +1,27 @@
1
  import json, re, ast, streamlit as st
2
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline, BitsAndBytesConfig
3
- import torch
4
 
5
- model_id = "mistralai/Mistral-7B-Instruct-v0.3"
6
-
7
- # 4-bit Quantization Configuration to reduce memory usage (VRAM/RAM)
8
- bnb_config = BitsAndBytesConfig(
9
- load_in_4bit=True,
10
- bnb_4bit_quant_type="nf4",
11
- bnb_4bit_compute_dtype=torch.float16
12
- )
13
 
14
  tok = AutoTokenizer.from_pretrained(model_id)
15
 
16
- # Load the model with the 4-bit quantization configuration
17
- model = AutoModelForCausalLM.from_pretrained(
18
- model_id,
19
- quantization_config=bnb_config,
20
- device_map="auto"
21
- )
 
 
 
 
 
 
 
 
 
22
 
23
  gen = pipeline("text-generation", model=model, tokenizer=tok,
24
  max_new_tokens=256, do_sample=False, return_full_text=False)
@@ -31,7 +34,6 @@ JSON:"""
31
  def extract(text: str):
32
  out = gen(prompt.format(text=text))
33
  raw = out[0].get("generated_text") or out[0].get("text") or str(out[0])
34
- # Relaxed regex to find JSON object anywhere in the output
35
  m = re.search(r"(\{[\s\S]*\})", raw)
36
  data = {}
37
 
@@ -40,7 +42,6 @@ def extract(text: str):
40
  for parser in (json.loads, ast.literal_eval):
41
  try:
42
  parsed_data = parser(blob)
43
- # Handle common case where model returns a list of dictionaries
44
  if isinstance(parsed_data, list) and parsed_data:
45
  data = parsed_data[0]
46
  elif isinstance(parsed_data, dict):
@@ -49,7 +50,6 @@ def extract(text: str):
49
  except Exception:
50
  continue
51
 
52
- # Error handling check: If parsing failed completely, return a structured error dictionary
53
  if not isinstance(data, dict):
54
  return {
55
  "SKILL": ["(Error: Invalid/Corrupted Model Output)"],
@@ -57,7 +57,6 @@ def extract(text: str):
57
  "DEBUG_RAW_OUTPUT": raw
58
  }
59
 
60
- # Successful return: Uses .get() which prevents KeyError even if keys are missing
61
  return {
62
  "SKILL": data.get("SKILL", []),
63
  "KNOWLEDGE": data.get("KNOWLEDGE", [])
@@ -67,5 +66,4 @@ st.title("Skill/Knowledge Extractor")
67
  text = st.text_area("Paste text")
68
 
69
  if st.button("Extract") and text.strip():
70
- st.json(extract(text))
71
-
 
# --- Dependencies -----------------------------------------------------------
# Standard library
import ast
import json
import re

# Third-party
import streamlit as st
import torch  # needed below for torch_dtype selection
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# SWITCHED MODEL: from Mistral-7B to the much smaller Gemma-2B-Instruct, so
# the app can load without 4-bit quantization (bitsandbytes) on modest hardware.
model_id = "google/gemma-2b-it"
7
 
# Tokenizer for the selected checkpoint (downloaded/cached on first run).
tok = AutoTokenizer.from_pretrained(model_id)

# Model loading without BitsAndBytesConfig: the 2B checkpoint is assumed small
# enough to fit unquantized — TODO confirm on the target deployment hardware.
try:
    # Prefer bfloat16 (wider exponent range than float16); not every
    # CPU/GPU supports it, in which case loading raises and we fall back.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto"
    )
except Exception:
    # Fallback to float16 if bfloat16 causes issues.
    # NOTE(review): this broad `except Exception` also masks non-dtype
    # failures (auth errors, network timeouts) and then retries a second
    # full load that will likely fail the same way — consider narrowing.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto"
    )
25
 
# Deterministic text-generation pipeline: greedy decoding (no sampling),
# capped at 256 new tokens, returning only the completion without the prompt.
_generation_options = dict(
    max_new_tokens=256,
    do_sample=False,
    return_full_text=False,
)
gen = pipeline("text-generation", model=model, tokenizer=tok, **_generation_options)
 
34
  def extract(text: str):
35
  out = gen(prompt.format(text=text))
36
  raw = out[0].get("generated_text") or out[0].get("text") or str(out[0])
 
37
  m = re.search(r"(\{[\s\S]*\})", raw)
38
  data = {}
39
 
 
42
  for parser in (json.loads, ast.literal_eval):
43
  try:
44
  parsed_data = parser(blob)
 
45
  if isinstance(parsed_data, list) and parsed_data:
46
  data = parsed_data[0]
47
  elif isinstance(parsed_data, dict):
 
50
  except Exception:
51
  continue
52
 
 
53
  if not isinstance(data, dict):
54
  return {
55
  "SKILL": ["(Error: Invalid/Corrupted Model Output)"],
 
57
  "DEBUG_RAW_OUTPUT": raw
58
  }
59
 
 
60
  return {
61
  "SKILL": data.get("SKILL", []),
62
  "KNOWLEDGE": data.get("KNOWLEDGE", [])
 
# Streamlit front-end: collect pasted text and render the extraction result.
pasted_text = st.text_area("Paste text")

# Run only on an explicit button click, and skip whitespace-only input.
if st.button("Extract") and pasted_text.strip():
    st.json(extract(pasted_text))