| import os |
| import streamlit as st |
| from transformers import AutoModelForQuestionAnswering, AutoTokenizer |
| import torch |
|
|
| |
# Local directories for the saved model and the Hugging Face download cache,
# both rooted at the current working directory.
MODEL_DIR = os.path.join(os.getcwd(), "qa_model")
CACHE_DIR = os.path.join(os.getcwd(), "cache")

# Point every cache-location variable the HF stack consults at CACHE_DIR
# before any transformers call can read them.
for _env_var in ("TRANSFORMERS_CACHE", "HF_HOME", "XDG_CACHE_HOME"):
    os.environ[_env_var] = CACHE_DIR

# Make sure both directories exist before anything tries to write into them.
for _dir in (MODEL_DIR, CACHE_DIR):
    os.makedirs(_dir, exist_ok=True)
|
|
@st.cache_resource
def load_model():
    """Load the QA model and tokenizer, downloading them on the first run.

    Returns:
        tuple: ``(AutoModelForQuestionAnswering, AutoTokenizer)`` — loaded
        from the local ``MODEL_DIR`` when a saved copy exists, otherwise
        downloaded from the Hub, persisted to ``MODEL_DIR``, and returned.

    On any failure the error is shown in the Streamlit UI and the script
    run is halted via ``st.stop()``.
    """
    try:
        # Reuse the locally saved copy when present.  Newer transformers
        # versions save weights as model.safetensors rather than
        # pytorch_model.bin; accepting either avoids re-downloading on
        # every cold start after the first save_pretrained().
        weight_files = ("model.safetensors", "pytorch_model.bin")
        has_config = os.path.exists(os.path.join(MODEL_DIR, "config.json"))
        has_weights = any(
            os.path.exists(os.path.join(MODEL_DIR, f)) for f in weight_files
        )
        if has_config and has_weights:
            return (
                AutoModelForQuestionAnswering.from_pretrained(MODEL_DIR),
                AutoTokenizer.from_pretrained(MODEL_DIR),
            )

        model_name = "distilbert-base-cased-distilled-squad"
        # NOTE(review): spinner/error emoji below were mojibake in the
        # original ("π", "β"); restored to readable glyphs.
        with st.spinner("⏳ Downloading model (first run only)..."):
            model = AutoModelForQuestionAnswering.from_pretrained(
                model_name,
                cache_dir=CACHE_DIR,
            )
            tokenizer = AutoTokenizer.from_pretrained(
                model_name,
                cache_dir=CACHE_DIR,
            )

        # Persist locally so later runs skip the download entirely.
        model.save_pretrained(MODEL_DIR)
        tokenizer.save_pretrained(MODEL_DIR)

        return model, tokenizer

    except Exception as e:
        # Surface the failure in the UI and stop this script run.
        st.error(f"❌ Model loading failed: {e}")
        st.stop()
|
|
| |
| model, tokenizer = load_model() |
|
|
def get_answer(question, context):
    """Extract the model's best answer span for *question* from *context*.

    Args:
        question: Natural-language question string.
        context: Passage of text to search for the answer.

    Returns:
        The decoded answer span, stripped of surrounding whitespace, or an
        empty string when the model produces no coherent span.
    """
    # Truncation alone bounds the input; padding to max_length only adds
    # pad positions whose logits the argmax could (rarely) select, and
    # wastes compute on a single-example batch.
    inputs = tokenizer(
        question,
        context,
        max_length=384,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        outputs = model(**inputs)

    start_idx = int(torch.argmax(outputs.start_logits))
    end_idx = int(torch.argmax(outputs.end_logits)) + 1

    # An inverted span (end before start) means no coherent answer.
    if end_idx <= start_idx:
        return ""

    answer = tokenizer.decode(
        inputs["input_ids"][0][start_idx:end_idx],
        skip_special_tokens=True,
    )
    return answer.strip()
|
|
| |
# --- Streamlit UI -----------------------------------------------------------
# NOTE(review): the emoji in these labels were mojibake in the original
# ("π€", "β", ...); restored to readable glyphs.  The st.success call below
# was a single f-string broken across two lines (invalid as written) and the
# "no answer" case was shown via st.success; both repaired here.
st.title("🤖 QA System on Hugging Face")
context = st.text_area("📖 Enter context", height=200)
question = st.text_input("❓ Your question")

if st.button("🔍 Get Answer"):
    if context and question:
        with st.spinner("Analyzing..."):
            try:
                answer = get_answer(question, context)
                if answer:
                    st.success(f"✅ Answer: {answer}")
                else:
                    st.warning("⚠️ No clear answer found")
            except Exception as e:
                st.error(f"Error: {e}")
    else:
        st.warning("Please provide both context and question")