# ThesisPlease / src/streamlit_app.py
# NotRev — "Update src/streamlit_app.py" (commit cffdf8b, verified)
import json, re, ast, streamlit as st
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
import torch
import os
# Model setup: microsoft/phi-2 — a public checkpoint whose tokenizer does not
# require the sentencepiece package, and which needs no HF access token.
model_id = "microsoft/phi-2"

tok = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

# Prefer bfloat16 (wider exponent range, safer default for inference); fall
# back to float16 on hardware/builds where the bfloat16 load fails.
try:
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True,
    )
except Exception:
    # Broad catch is deliberate best-effort: any failure in the bfloat16 path
    # (unsupported dtype, kernel/device issues) falls through to float16.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )

# Greedy decoding (do_sample=False) keeps extraction deterministic;
# return_full_text=False strips the prompt from the generated text.
gen = pipeline("text-generation", model=model, tokenizer=tok,
               max_new_tokens=256, do_sample=False, return_full_text=False)
# Literal JSON braces are doubled ({{ }}) so str.format substitutes only
# {text}. With single braces, .format() raised KeyError: '"SKILL"' before
# the model was ever called — that was the source of the observed KeyError,
# not corrupted model output.
prompt = """Extract skills and knowledge from the text.
Return JSON: {{"SKILL":[...], "KNOWLEDGE":[...]}}.
Text: {text}
JSON:"""


def extract(text: str) -> dict:
    """Run the generation pipeline over *text* and parse its JSON answer.

    Returns a dict with "SKILL" and "KNOWLEDGE" lists. If the model output
    contains no parseable JSON object, returns a placeholder "SKILL" entry
    plus the raw model output under "DEBUG_RAW_OUTPUT" for inspection.
    """
    out = gen(prompt.format(text=text))
    # Pipeline output key varies across transformers versions; fall back to
    # a stringified record so `raw` is always something searchable.
    raw = out[0].get("generated_text") or out[0].get("text") or str(out[0])

    # Grab the widest {...} span — models often wrap the JSON in extra prose.
    m = re.search(r"(\{[\s\S]*\})", raw)
    data = {}
    if m:
        blob = m.group(0).strip()
        # Try strict JSON first, then Python-literal syntax (single quotes,
        # trailing commas the model sometimes emits).
        for parser in (json.loads, ast.literal_eval):
            try:
                parsed = parser(blob)
            except Exception:
                continue
            if isinstance(parsed, list) and parsed:
                data = parsed[0]
            elif isinstance(parsed, dict):
                data = parsed
            break

    if not isinstance(data, dict):
        # The model returned something that isn't a JSON object — surface
        # the raw text so the failure can be debugged from the UI.
        return {
            "SKILL": ["(Error: Invalid/Corrupted Model Output)"],
            "KNOWLEDGE": [],
            "DEBUG_RAW_OUTPUT": raw,
        }
    return {
        "SKILL": data.get("SKILL", []),
        "KNOWLEDGE": data.get("KNOWLEDGE", []),
    }
# Minimal Streamlit front-end: one text box, one button, JSON result view.
st.title("Skill/Knowledge Extractor")
text = st.text_area("Paste text")
clicked = st.button("Extract")
# Run extraction only on an explicit click with non-blank input.
if clicked and text.strip():
    st.json(extract(text))