Hugging Face Space (status: sleeping)
Commit: "Update src/streamlit_app.py" (+20 / −22 lines)
File changed: src/streamlit_app.py
|
@@ -1,24 +1,27 @@
|
|
| 1 |
import json, re, ast, streamlit as st
|
| 2 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
| 3 |
-
import torch
|
| 4 |
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
# 4-bit Quantization Configuration to reduce memory usage (VRAM/RAM)
|
| 8 |
-
bnb_config = BitsAndBytesConfig(
|
| 9 |
-
load_in_4bit=True,
|
| 10 |
-
bnb_4bit_quant_type="nf4",
|
| 11 |
-
bnb_4bit_compute_dtype=torch.float16
|
| 12 |
-
)
|
| 13 |
|
| 14 |
tok = AutoTokenizer.from_pretrained(model_id)
|
| 15 |
|
| 16 |
-
#
|
| 17 |
-
model
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
|
| 23 |
gen = pipeline("text-generation", model=model, tokenizer=tok,
|
| 24 |
max_new_tokens=256, do_sample=False, return_full_text=False)
|
|
@@ -31,7 +34,6 @@ JSON:"""
|
|
| 31 |
def extract(text: str):
|
| 32 |
out = gen(prompt.format(text=text))
|
| 33 |
raw = out[0].get("generated_text") or out[0].get("text") or str(out[0])
|
| 34 |
-
# Relaxed regex to find JSON object anywhere in the output
|
| 35 |
m = re.search(r"(\{[\s\S]*\})", raw)
|
| 36 |
data = {}
|
| 37 |
|
|
@@ -40,7 +42,6 @@ def extract(text: str):
|
|
| 40 |
for parser in (json.loads, ast.literal_eval):
|
| 41 |
try:
|
| 42 |
parsed_data = parser(blob)
|
| 43 |
-
# Handle common case where model returns a list of dictionaries
|
| 44 |
if isinstance(parsed_data, list) and parsed_data:
|
| 45 |
data = parsed_data[0]
|
| 46 |
elif isinstance(parsed_data, dict):
|
|
@@ -49,7 +50,6 @@ def extract(text: str):
|
|
| 49 |
except Exception:
|
| 50 |
continue
|
| 51 |
|
| 52 |
-
# Error handling check: If parsing failed completely, return a structured error dictionary
|
| 53 |
if not isinstance(data, dict):
|
| 54 |
return {
|
| 55 |
"SKILL": ["(Error: Invalid/Corrupted Model Output)"],
|
|
@@ -57,7 +57,6 @@ def extract(text: str):
|
|
| 57 |
"DEBUG_RAW_OUTPUT": raw
|
| 58 |
}
|
| 59 |
|
| 60 |
-
# Successful return: Uses .get() which prevents KeyError even if keys are missing
|
| 61 |
return {
|
| 62 |
"SKILL": data.get("SKILL", []),
|
| 63 |
"KNOWLEDGE": data.get("KNOWLEDGE", [])
|
|
@@ -67,5 +66,4 @@ st.title("Skill/Knowledge Extractor")
|
|
| 67 |
text = st.text_area("Paste text")
|
| 68 |
|
| 69 |
if st.button("Extract") and text.strip():
|
| 70 |
-
st.json(extract(text))
|
| 71 |
-
|
|
|
|
| 1 |
import json, re, ast, streamlit as st
|
| 2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
|
| 3 |
+
import torch # Still needed for torch_dtype="auto" and device_map
|
| 4 |
|
| 5 |
+
# Lightweight instruction-tuned model (Gemma-2B) in place of a 7B-class model —
# small enough to load without 4-bit quantization and its extra dependencies.
model_id = "google/gemma-2b-it"

tok = AutoTokenizer.from_pretrained(model_id)


def _load_model(dtype):
    # Single loader so the dtype fallback below stays readable.
    return AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto",
    )


# Prefer bfloat16 for numerical stability; on any loading failure
# (e.g. hardware without bfloat16 support) retry with float16.
try:
    model = _load_model(torch.bfloat16)
except Exception:
    model = _load_model(torch.float16)

# Greedy decoding (do_sample=False) so extraction output is deterministic;
# return_full_text=False keeps the prompt out of the generated text.
gen = pipeline(
    "text-generation",
    model=model,
    tokenizer=tok,
    max_new_tokens=256,
    do_sample=False,
    return_full_text=False,
)
|
|
|
|
| 34 |
def extract(text: str):
|
| 35 |
out = gen(prompt.format(text=text))
|
| 36 |
raw = out[0].get("generated_text") or out[0].get("text") or str(out[0])
|
|
|
|
| 37 |
m = re.search(r"(\{[\s\S]*\})", raw)
|
| 38 |
data = {}
|
| 39 |
|
|
|
|
| 42 |
for parser in (json.loads, ast.literal_eval):
|
| 43 |
try:
|
| 44 |
parsed_data = parser(blob)
|
|
|
|
| 45 |
if isinstance(parsed_data, list) and parsed_data:
|
| 46 |
data = parsed_data[0]
|
| 47 |
elif isinstance(parsed_data, dict):
|
|
|
|
| 50 |
except Exception:
|
| 51 |
continue
|
| 52 |
|
|
|
|
| 53 |
if not isinstance(data, dict):
|
| 54 |
return {
|
| 55 |
"SKILL": ["(Error: Invalid/Corrupted Model Output)"],
|
|
|
|
| 57 |
"DEBUG_RAW_OUTPUT": raw
|
| 58 |
}
|
| 59 |
|
|
|
|
| 60 |
return {
|
| 61 |
"SKILL": data.get("SKILL", []),
|
| 62 |
"KNOWLEDGE": data.get("KNOWLEDGE", [])
|
|
|
|
| 66 |
# Collect the user's text; run extraction only on a click with non-blank input.
text = st.text_area("Paste text")

clicked = st.button("Extract")
if clicked and text.strip():
    st.json(extract(text))
|
|
|