# llmrag / src/streamlit_app.py
# (Hugging Face Spaces page header preserved from the scrape:
#  uploaded by lol040604lol — "Update src/streamlit_app.py" — commit dfc3404, verified)
import streamlit as st
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import faiss
import numpy as np
import os
st.set_page_config(page_title="🛕 Tamil RAG Expert", layout="wide")
@st.cache_resource
def load_model():
cache_dir = "./hf_model_cache"
os.makedirs(cache_dir, exist_ok=True)
model = AutoModelForCausalLM.from_pretrained("flax-community/gpt-2-tamil", cache_dir=cache_dir)
tokenizer = AutoTokenizer.from_pretrained("flax-community/gpt-2-tamil", cache_dir=cache_dir)
return tokenizer, model
tokenizer, model = load_model()
# ... (same rest of app)
def embed_text(texts):
inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True, max_length=128)
with torch.no_grad():
outputs = model.transformer.wte(inputs['input_ids']) # token embeddings
avg_embeddings = outputs.mean(dim=1).cpu().numpy()
return avg_embeddings
def build_faiss_index(docs):
vectors = embed_text(docs)
dim = vectors.shape[1]
index = faiss.IndexFlatL2(dim)
index.add(vectors)
return index, docs
st.title("📜 Tamil Ancient Text Expert (LLM + RAG)")
uploaded_file = st.file_uploader("Upload your knowledge base (context.txt)", type="txt")
if uploaded_file:
raw_text = uploaded_file.read().decode("utf-8")
docs = [line.strip() for line in raw_text.split("\n") if line.strip()]
index, db_docs = build_faiss_index(docs)
st.success("Context loaded! You can now ask questions.")
query = st.text_area("Enter corrupted Tamil text or question:", height=200)
if st.button("🧠 Generate Response"):
query_vec = embed_text([query])
D, I = index.search(query_vec, k=3)
retrieved = "\n".join([db_docs[i] for i in I[0]])
prompt = f"உரையின் சுந்தரம் மற்றும் பொருளுடன் பின்வரும் உள்ளடக்கம் மற்றும் வினாவை பயன்படுத்தி பதில் அளிக்கவும்:\n\n{retrieved}\n\nவினா:\n{query}\n\nபதில்:\n"
inputs = tokenizer(prompt, return_tensors="pt", truncation=True, max_length=256)
outputs = model.generate(inputs["input_ids"], max_length=300, do_sample=True, top_p=0.95, top_k=40)
response = tokenizer.decode(outputs[0], skip_special_tokens=True)
st.markdown("### ✅ Tamil Response:")
st.write(response)
else:
st.info("Upload a .txt file containing ancient Tamil lines (one per line).")