# Leo / app.py
# (Hugging Face Space file — last commit 5f22750 by amritn8: "Update app.py")
import os

# Hugging Face Spaces configuration: keep the model weights and the HF
# download cache inside the working directory (the writable location on
# Spaces' read-only container filesystem).
MODEL_DIR = os.path.join(os.getcwd(), "qa_model")
CACHE_DIR = os.path.join(os.getcwd(), "cache")

# Point every Hugging Face cache location at the writable directory.
# These must be set BEFORE importing `transformers`, which reads
# TRANSFORMERS_CACHE / HF_HOME at import time to compute its default
# cache paths — setting them after the import is partially ineffective.
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_HOME"] = CACHE_DIR
os.environ["XDG_CACHE_HOME"] = CACHE_DIR

# Create directories up front so later downloads/saves never hit a
# missing path.
os.makedirs(MODEL_DIR, exist_ok=True)
os.makedirs(CACHE_DIR, exist_ok=True)

import streamlit as st
import torch
from transformers import AutoModelForQuestionAnswering, AutoTokenizer
@st.cache_resource
def load_model():
    """Load the QA model and tokenizer, downloading them on the first run.

    Returns:
        tuple: ``(model, tokenizer)`` ready for inference.

    On any failure an error is shown in the UI and the script run is
    stopped via ``st.stop()``.
    """
    try:
        # Reuse a previously saved copy when its files are present.
        # Newer `transformers` versions save weights as `model.safetensors`
        # instead of `pytorch_model.bin`, so accept either format —
        # checking only for the .bin file would force a fresh download on
        # every cold start even though the model was already saved.
        weight_files = ("pytorch_model.bin", "model.safetensors")
        has_config = os.path.exists(os.path.join(MODEL_DIR, "config.json"))
        has_weights = any(
            os.path.exists(os.path.join(MODEL_DIR, f)) for f in weight_files
        )
        if has_config and has_weights:
            return (
                AutoModelForQuestionAnswering.from_pretrained(MODEL_DIR),
                AutoTokenizer.from_pretrained(MODEL_DIR),
            )

        # First run: download a compact distilled model that fits within
        # the free Spaces memory limits.
        MODEL_NAME = "distilbert-base-cased-distilled-squad"  # Smaller model for Spaces
        with st.spinner("🚀 Downloading model (first run only)..."):
            model = AutoModelForQuestionAnswering.from_pretrained(
                MODEL_NAME,
                cache_dir=CACHE_DIR,
            )
            tokenizer = AutoTokenizer.from_pretrained(
                MODEL_NAME,
                cache_dir=CACHE_DIR,
            )
            # Persist locally so future runs skip the download entirely.
            model.save_pretrained(MODEL_DIR)
            tokenizer.save_pretrained(MODEL_DIR)
            return model, tokenizer
    except Exception as e:
        st.error(f"❌ Model loading failed: {str(e)}")
        st.stop()
# Load model once at startup; @st.cache_resource memoizes it across reruns.
model, tokenizer = load_model()
def get_answer(question, context):
    """Extract the answer span for ``question`` from ``context``.

    Args:
        question (str): The user's question.
        context (str): The passage to search for the answer.

    Returns:
        str: The decoded answer span, stripped of surrounding whitespace;
        ``""`` when the model finds no coherent span.
    """
    # No `padding="max_length"` here: padding to 384 tokens wastes compute
    # for short inputs and lets the argmax below land on a pad position.
    inputs = tokenizer(
        question, context,
        max_length=384,  # Reduced for Spaces memory limits
        truncation=True,
        return_tensors="pt",
    )
    # Inference only — no gradients needed.
    with torch.no_grad():
        outputs = model(**inputs)
    # Most likely start/end token positions (end made exclusive for slicing).
    answer_start = torch.argmax(outputs.start_logits)
    answer_end = torch.argmax(outputs.end_logits) + 1
    # An end at or before the start means the model predicted an inverted
    # span, i.e. no coherent answer — return "" so the UI reports it.
    if answer_end <= answer_start:
        return ""
    answer = tokenizer.decode(
        inputs["input_ids"][0][answer_start:answer_end],
        skip_special_tokens=True,
    )
    return answer.strip()
# ---- UI ----
st.title("🤖 QA System on Hugging Face")
context = st.text_area("📝 Enter context", height=200)
question = st.text_input("❓ Your question")

if st.button("🔍 Get Answer"):
    # Guard clause: require both fields before running inference.
    if not (context and question):
        st.warning("Please provide both context and question")
    else:
        with st.spinner("Analyzing..."):
            try:
                result = get_answer(question, context)
                if result:
                    st.success(f"✅ Answer: {result}")
                else:
                    st.success("⚠️ No clear answer found")
            except Exception as err:
                st.error(f"Error: {str(err)}")