Spaces:
Sleeping
Sleeping
| # app.py | |
| # SmartManuals-AI: Hugging Face Space version | |
| import os, json, fitz, nltk, chromadb, io | |
| import torch | |
| from tqdm import tqdm | |
| from PIL import Image | |
| from docx import Document | |
| from sentence_transformers import SentenceTransformer, util | |
| from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM | |
| from nltk.tokenize import sent_tokenize | |
| import pytesseract | |
| import gradio as gr | |
| # ---------------------- | |
| # Configuration | |
| # ---------------------- | |
# Input folder and on-disk artifacts.
MANUALS_FOLDER = "./Manuals"        # listed by embed_all() to find documents
CHROMA_PATH = "./chroma_store"      # directory of the persistent Chroma client
COLLECTION_NAME = "manual_chunks"   # collection recreated by embed_all()
CHUNKS_JSONL = "chunks.jsonl"       # not referenced by the code visible in this file

# Defaults for chunk_sentences(), which measures sentences in
# whitespace-separated words.
CHUNK_SIZE = 750
CHUNK_OVERLAP = 100

# Generation model and the access token passed to from_pretrained().
MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
HF_TOKEN = os.environ.get("HF_TOKEN")
# ----------------------
# Ensure sentence-tokenizer data is downloaded
# ----------------------
# sent_tokenize() needs the Punkt models. Older NLTK releases load the "punkt"
# resource, while recent releases look for "punkt_tab" instead, so fetch both;
# a resource unknown to the installed NLTK is reported and skipped.
for _resource in ("punkt", "punkt_tab"):
    nltk.download(_resource)
| # ---------------------- | |
| # Utilities | |
| # ---------------------- | |
def extract_text_from_pdf(path):
    """Return the text of every page of the PDF at *path*.

    A page whose text layer is empty or whitespace-only is rendered to a
    300 dpi PNG and passed through Tesseract OCR instead. Each page's text is
    followed by a newline in the returned string.
    """
    pages = []
    # The context manager closes the document even if extraction or OCR raises.
    with fitz.open(path) as doc:
        for page in doc:
            page_text = page.get_text()
            if not page_text.strip():
                # No usable text layer: fall back to OCR on a rendered image.
                pix = page.get_pixmap(dpi=300)
                img = Image.open(io.BytesIO(pix.tobytes("png")))
                page_text = pytesseract.image_to_string(img)
            pages.append(page_text + "\n")
    # Join once instead of growing a string with += on every page.
    return "".join(pages)
def extract_text_from_docx(path):
    """Return the non-blank paragraphs of the .docx at *path*, newline-joined."""
    kept = []
    for paragraph in Document(path).paragraphs:
        content = paragraph.text
        # Drop paragraphs that are empty or whitespace-only.
        if content.strip():
            kept.append(content)
    return "\n".join(kept)
def clean(text):
    """Strip every line of *text* and drop the lines that end up empty."""
    stripped_lines = (raw.strip() for raw in text.splitlines())
    return "\n".join(kept for kept in stripped_lines if kept)
def split_sentences(text):
    """Split *text* into sentences with NLTK's ``sent_tokenize``."""
    sentences = sent_tokenize(text)
    return sentences
def _overlap_tail(sentences, overlap):
    """Return ``(tail, word_count)``: trailing sentences worth at most *overlap* words.

    The first sentence of *sentences* is never carried over, so the next chunk
    always starts later than the one just emitted.
    """
    if overlap <= 0:
        return [], 0
    tail, total = [], 0
    for sentence in reversed(sentences[1:]):
        n_words = len(sentence.split())
        if total + n_words > overlap:
            break
        tail.append(sentence)
        total += n_words
    tail.reverse()
    return tail, total


def chunk_sentences(sentences, max_tokens=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
    """Group *sentences* into chunks of roughly *max_tokens* words each.

    "Tokens" here are whitespace-separated words. Sentences are never split,
    so a single sentence longer than *max_tokens* becomes an oversized chunk.
    Consecutive chunks share trailing sentences totalling at most *overlap*
    words.

    Args:
        sentences: Iterable of sentence strings, in document order.
        max_tokens: Word budget per chunk.
        overlap: Maximum number of words repeated at the start of the next
            chunk; ``0`` or less disables overlap.

    Returns:
        A list of chunk strings, each made of sentences joined by one space.
    """
    chunks, current, count = [], [], 0
    for sentence in sentences:
        n_words = len(sentence.split())
        # Only flush a non-empty chunk, so an over-long first sentence does
        # not produce an empty "" chunk.
        if current and count + n_words > max_tokens:
            chunks.append(" ".join(current))
            # Overlap is measured in words, like the budget. Slicing by
            # sentence count would usually keep the whole chunk, so it would
            # never shrink.
            current, count = _overlap_tail(current, overlap)
        current.append(sentence)
        count += n_words
    if current:
        chunks.append(" ".join(current))
    return chunks
def get_metadata(filename):
    """Derive chunk metadata from *filename* by case-insensitive substring checks.

    ``doc_type`` is "service manual" when the name contains "sm", otherwise
    "owner manual" when it contains "om", otherwise "unknown". ``model`` is
    "se3hd" when the name contains it, otherwise "unknown".
    """
    lowered = filename.lower()

    if "sm" in lowered:
        doc_type = "service manual"
    elif "om" in lowered:
        doc_type = "owner manual"
    else:
        doc_type = "unknown"

    if "se3hd" in lowered:
        model = "se3hd"
    else:
        model = "unknown"

    metadata = {}
    metadata["source_file"] = filename
    metadata["doc_type"] = doc_type
    metadata["model"] = model
    return metadata
| # ---------------------- | |
| # Embedding | |
| # ---------------------- | |
def _add_batch(collection, embedder, chunks, ids, metadatas):
    """Embed *chunks* with *embedder* and add them to *collection*."""
    embeddings = embedder.encode(chunks).tolist()
    collection.add(documents=chunks, ids=ids, metadatas=metadatas, embeddings=embeddings)


def embed_all():
    """Rebuild the Chroma collection from the files in ``MANUALS_FOLDER``.

    Any existing collection named ``COLLECTION_NAME`` is dropped and recreated.
    Each ``.pdf`` / ``.docx`` file is extracted, cleaned, split into sentences
    and chunked; chunks are embedded with ``all-MiniLM-L6-v2`` and added in
    batches of 16. Files with any other extension are skipped.

    Returns:
        ``(collection, embedder)``: the populated collection and the
        SentenceTransformer used to embed it.
    """
    batch_size = 16
    embedder = SentenceTransformer("all-MiniLM-L6-v2")
    client = chromadb.PersistentClient(path=CHROMA_PATH)
    try:
        client.delete_collection(COLLECTION_NAME)
    except Exception:
        # Best effort: the collection may simply not exist yet. Catching
        # Exception rather than a bare except lets KeyboardInterrupt and
        # SystemExit propagate.
        pass
    collection = client.create_collection(COLLECTION_NAME)

    chunks, metadatas, ids = [], [], []
    for file in tqdm(sorted(os.listdir(MANUALS_FOLDER))):
        path = os.path.join(MANUALS_FOLDER, file)
        suffix = file.lower()
        if suffix.endswith(".pdf"):
            text = extract_text_from_pdf(path)
        elif suffix.endswith(".docx"):
            text = extract_text_from_docx(path)
        else:
            # Not a supported manual format; the docx parser would fail on it.
            continue

        meta = get_metadata(file)
        sents = split_sentences(clean(text))
        for i, chunk in enumerate(chunk_sentences(sents)):
            if not chunk.strip():
                continue  # nothing to embed
            chunks.append(chunk)
            ids.append(f"{file}::chunk_{i}")
            metadatas.append(meta)
            if len(chunks) >= batch_size:
                _add_batch(collection, embedder, chunks, ids, metadatas)
                chunks, ids, metadatas = [], [], []

    # Flush whatever is left after the last full batch.
    if chunks:
        _add_batch(collection, embedder, chunks, ids, metadatas)
    return collection, embedder
| # ---------------------- | |
| # Model setup | |
| # ---------------------- | |
def load_model():
    """Build a text-generation pipeline for ``MODEL_NAME``.

    The tokenizer and model are fetched with ``HF_TOKEN``. The pipeline runs on
    CUDA device 0 when available, otherwise on CPU (-1), and generates at most
    512 new tokens per call.
    """
    if torch.cuda.is_available():
        device = 0
    else:
        device = -1

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=HF_TOKEN)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, token=HF_TOKEN)

    generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        device=device,
        max_new_tokens=512,
    )
    return generator
| # ---------------------- | |
| # RAG Pipeline | |
| # ---------------------- | |
def answer_query(question):
    """Answer *question* from the five most similar manual chunks.

    Relies on the module-level ``db`` (Chroma collection), ``embedder``
    (SentenceTransformer) and ``llm`` (text-generation pipeline) created at
    startup.

    Args:
        question: The user's question as plain text.

    Returns:
        The model's answer with surrounding whitespace removed, or a short
        prompt to enter a question when *question* is blank.
    """
    if not question or not question.strip():
        return "Please enter a question."

    # Embed the query with the same model that embedded the documents rather
    # than relying on the collection's default embedding function.
    query_embedding = embedder.encode([question]).tolist()
    results = db.query(query_embeddings=query_embedding, n_results=5)
    context = "\n\n".join(results["documents"][0])

    # Llama 3 chat format: every turn ends with <|eot_id|>, and each header is
    # followed by a blank line.
    prompt = (
        "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n"
        "You are a helpful assistant. Use the provided context to answer questions. "
        "If you don't know, say 'I don't know.'\n"
        "<context>\n"
        f"{context}\n"
        "</context><|eot_id|>"
        "<|start_header_id|>user<|end_header_id|>\n\n"
        f"{question}<|eot_id|>"
        "<|start_header_id|>assistant<|end_header_id|>\n\n"
    )

    # Ask for the completion only, so the answer does not have to be cut out
    # of the echoed prompt by string splitting.
    output = llm(prompt, return_full_text=False)
    return output[0]["generated_text"].strip()
# ----------------------
# Startup
# ----------------------
# Build the vector store and load the model before the UI is constructed, so
# the status box can show the real state. The page is only served once
# demo.launch() runs, which is after both steps have finished.
status_text = "Embedding manuals and loading model..."
db, embedder = embed_all()
llm = load_model()
status_text = "Ready!"

# ----------------------
# UI
# ----------------------
with gr.Blocks() as demo:
    # Show the startup result instead of a hard-coded "please wait" message
    # that was never updated.
    status = gr.Textbox(label="Status", value=status_text, interactive=False)
    question = gr.Textbox(label="Ask a Question")
    submit = gr.Button("🔍 Ask")
    answer = gr.Textbox(label="Answer", lines=8)

    def handle_query(q):
        """Click handler: run the RAG pipeline on the question text."""
        return answer_query(q)

    submit.click(fn=handle_query, inputs=question, outputs=answer)

demo.launch()