Spaces:
Sleeping
Sleeping
| import os | |
| import fitz | |
| import json | |
| import gradio as gr | |
| import pytesseract | |
| import chromadb | |
| import torch | |
| import nltk | |
| import traceback | |
| import docx2txt | |
| from PIL import Image | |
| from io import BytesIO | |
| from tqdm import tqdm | |
| from transformers import AutoProcessor, AutoModelForVision2Seq | |
| from sentence_transformers import SentenceTransformer, util | |
| from nltk.tokenize import sent_tokenize | |
# Ensure the NLTK sentence-tokenizer data is available. NLTK >= 3.8.2 loads
# "punkt_tab" for sent_tokenize, while older releases load "punkt"; fetch
# whichever is missing so split_sentences() works on either version.
# (nltk.download reports an unknown/unavailable package and returns False
# rather than raising, so requesting both is safe on old releases.)
for _nltk_resource in ("punkt", "punkt_tab"):
    try:
        nltk.data.find(f"tokenizers/{_nltk_resource}")
    except LookupError:
        nltk.download(_nltk_resource)

# Configuration
# Hugging Face access token from the environment (None when unset); passed as
# `token=` when downloading MODEL_ID.
HF_TOKEN = os.getenv("HF_TOKEN")
MANUALS_DIR = "Manuals"  # directory scanned for .pdf / .docx manuals
CHROMA_PATH = "chroma_store"  # on-disk path of the persistent Chroma database
COLLECTION_NAME = "manual_chunks"  # Chroma collection rebuilt on every start
CHUNK_SIZE = 750  # target max words per chunk (words = str.split() tokens)
CHUNK_OVERLAP = 100  # overlap carried between consecutive chunks (see split_chunks)
MAX_CONTEXT_CHUNKS = 3  # number of retrieved chunks passed to the model as context
MODEL_ID = "ibm-granite/granite-vision-3.2-2b"
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
# ---------------- Text Helpers ----------------
def clean(text):
    """Return *text* with every line stripped and blank lines removed.

    Lines are split with ``str.splitlines`` and re-joined with single
    newline characters.
    """
    stripped_lines = (raw.strip() for raw in text.splitlines())
    # filter(None, ...) drops the lines that became empty after stripping.
    return "\n".join(filter(None, stripped_lines))
def split_sentences(text):
    """Split *text* into a list of sentences.

    Uses NLTK's ``sent_tokenize``. If that fails (for example because the
    tokenizer data could not be downloaded), prints a notice and falls back
    to a regex split after ``.``, ``!`` or ``?`` followed by whitespace.

    Returns:
        List of sentence strings; empty/whitespace-only pieces are dropped,
        so blank input yields ``[]`` on the fallback path as well.
    """
    try:
        return sent_tokenize(text)
    except Exception:  # not a bare except: let KeyboardInterrupt/SystemExit through
        print("Tokenizer fallback: simple split.")
        import re

        # Split on any whitespace (including the newlines clean() leaves
        # between lines) that follows sentence-ending punctuation, keeping
        # the punctuation attached to its sentence.
        return [s for s in re.split(r"(?<=[.!?])\s+", text) if s.strip()]
def split_chunks(sentences, max_tokens=CHUNK_SIZE, overlap=CHUNK_OVERLAP):
    """Group *sentences* into overlapping chunks of roughly *max_tokens* words.

    Sentences are never split. A chunk is closed as soon as adding the next
    sentence would push it past *max_tokens* words (words are counted with
    ``str.split``). A single sentence longer than *max_tokens* becomes its
    own oversized chunk.

    Each new chunk begins with trailing sentences of the previous chunk
    totalling at most *overlap* words, so text spanning a boundary appears
    in both. The carried-over tail is additionally capped so that it plus
    the incoming sentence fits in *max_tokens*; this guarantees every new
    chunk drops some earlier text and the chunks cannot keep re-growing.

    Returns:
        List of chunk strings (sentences joined by single spaces); ``[]``
        when *sentences* is empty.
    """
    chunks = []
    current, length = [], 0
    for sent in sentences:
        n_words = len(sent.split())
        if current and length + n_words > max_tokens:
            chunks.append(" ".join(current))
            # Overlap is measured in words (like max_tokens), not sentences.
            budget = min(overlap, max_tokens - n_words)
            tail, tail_len = [], 0
            for prev in reversed(current):
                prev_len = len(prev.split())
                if tail_len + prev_len > budget:
                    break
                tail.append(prev)
                tail_len += prev_len
            tail.reverse()
            current, length = tail, tail_len
        current.append(sent)
        length += n_words
    if current:
        chunks.append(" ".join(current))
    return chunks
# ---------------- File Readers ----------------
def extract_pdf_text(path):
    """Extract per-page text from the PDF at *path*.

    Pages whose text layer is empty are rendered at 300 dpi and OCR'd with
    pytesseract. Every page's text is normalised with clean().

    Returns:
        A list of ``(path, page_number, text)`` tuples with 1-based page
        numbers. If the document cannot be opened or read, the error is
        printed and the pages collected so far (possibly none) are returned.
    """
    pages = []
    try:
        # The context manager closes the document (and its file handle)
        # even when a page raises.
        with fitz.open(path) as doc:
            for page_no, page in enumerate(doc, start=1):
                text = page.get_text().strip()
                if not text:
                    # No extractable text on this page: rasterize and OCR it.
                    # An OCR failure only costs this page, not the rest.
                    try:
                        png = page.get_pixmap(dpi=300).tobytes("png")
                        text = pytesseract.image_to_string(Image.open(BytesIO(png)))
                    except Exception as e:
                        print("OCR error:", path, page_no, e)
                        text = ""
                pages.append((path, page_no, clean(text)))
    except Exception as e:
        print("PDF read error:", path, e)
    return pages
def extract_docx_text(path):
    """Read the .docx file at *path* as a single page.

    Returns:
        ``[(path, 1, cleaned_text)]`` on success. On any failure (reading or
        cleaning), prints the error and returns ``[]``.
    """
    try:
        body = clean(docx2txt.process(path))
    except Exception as err:
        print("DOCX read error:", path, err)
        return []
    return [(path, 1, body)]
# ---------------- Embedding ----------------
def _add_batch(collection, embedder, docs, ids, metas):
    """Embed *docs* with *embedder*, add them to *collection*, return how many were added."""
    embeddings = embedder.encode(docs).tolist()
    collection.add(documents=docs, ids=ids, metadatas=metas, embeddings=embeddings)
    return len(docs)


def embed_all(batch_size=32):
    """Rebuild the Chroma collection from every PDF/DOCX in MANUALS_DIR.

    Any existing COLLECTION_NAME collection is deleted first. Each file is
    read page by page (extract_pdf_text / extract_docx_text), split into
    sentence-based chunks, embedded with the ``all-MiniLM-L6-v2``
    SentenceTransformer and added to the collection in batches of
    *batch_size*. Chunk ids have the form ``"<file>::<page>::<index>"``.

    Returns:
        ``(collection, embedder)`` on success, or ``(None, None)`` after
        printing the error if anything fails.
    """
    try:
        embedder = SentenceTransformer("all-MiniLM-L6-v2")
        embedder.eval()
        client = chromadb.PersistentClient(path=CHROMA_PATH)
        try:
            client.delete_collection(COLLECTION_NAME)
        except Exception:
            # Expected on first run (nothing to delete); the exception type
            # for a missing collection differs between chromadb versions.
            pass
        collection = client.get_or_create_collection(COLLECTION_NAME)

        docs, ids, metas = [], [], []
        total = 0  # running count; the batch lists are reset after each add
        print("Processing manuals...")
        # Sorted so ingestion order is deterministic across runs/platforms.
        for fname in sorted(os.listdir(MANUALS_DIR)):
            fpath = os.path.join(MANUALS_DIR, fname)
            lower = fname.lower()
            if lower.endswith(".pdf"):
                pages = extract_pdf_text(fpath)
            elif lower.endswith(".docx"):
                pages = extract_docx_text(fpath)
            else:
                continue
            for _, page, text in pages:
                for i, chunk in enumerate(split_chunks(split_sentences(text))):
                    docs.append(chunk)
                    ids.append(f"{fname}::{page}::{i}")
                    metas.append({"source": fname, "page": page})
                    if len(docs) >= batch_size:
                        total += _add_batch(collection, embedder, docs, ids, metas)
                        docs, ids, metas = [], [], []
        if docs:
            total += _add_batch(collection, embedder, docs, ids, metas)
        print(f"Embedded {total} chunks.")
        return collection, embedder
    except Exception as e:
        print("Embedding startup failed:", e)
        return None, None
# ---------------- Model Setup ----------------
def load_model():
    """Load MODEL_ID's processor and model, authenticating with HF_TOKEN.

    Returns:
        ``(model, processor)`` with the model moved to ``device``, or
        ``(None, None)`` after printing the error if loading fails.
    """
    try:
        auth = {"token": HF_TOKEN}
        processor = AutoProcessor.from_pretrained(MODEL_ID, **auth)
        model = AutoModelForVision2Seq.from_pretrained(MODEL_ID, **auth)
        model = model.to(device)
    except Exception as err:
        print("Model loading failed:", err)
        return None, None
    return model, processor
def ask_model(question, context, model, processor, max_new_tokens=256):
    """Generate an answer to *question* grounded in *context*.

    Args:
        question: The user's question text.
        context: Retrieved manual excerpts inserted into the prompt.
        model: The generation model returned by load_model().
        processor: The matching processor returned by load_model().
        max_new_tokens: Upper bound on generated tokens. Passed explicitly
            so answers are not cut off by a short default generation length.

    Returns:
        The decoded answer text (prompt excluded), stripped of surrounding
        whitespace.
    """
    prompt = f"""Use only the following context to answer. If uncertain, say "I don't know."
<context>
{context}
</context>
Q: {question}
A:"""
    # Pass the prompt by keyword: vision-language processors take `images`
    # as their first positional argument and `text` second.
    inputs = processor(text=prompt, return_tensors="pt").to(device)
    output = model.generate(**inputs, max_new_tokens=max_new_tokens)
    generated = output[0]
    # Decoder-only models return prompt + continuation; keep only the new
    # tokens so the answer does not echo the whole prompt/context back.
    if not getattr(model.config, "is_encoder_decoder", False):
        generated = generated[inputs["input_ids"].shape[-1]:]
    return processor.decode(generated, skip_special_tokens=True).strip()
# ---------------- Query ----------------
def get_answer(question):
    """Answer *question* from the indexed manuals (Gradio click handler).

    Embeds the question with the same SentenceTransformer used to index the
    chunks, retrieves the MAX_CONTEXT_CHUNKS nearest chunks from the Chroma
    collection, and asks the model to answer from them.

    Returns:
        The model's answer, or a user-facing message when the input is
        empty, the system is not initialised, or the query fails.
    """
    if not question or not question.strip():
        return "Please enter a question."
    # Explicit None checks: SentenceTransformer is an nn.Sequential with
    # __len__, so plain truthiness is not a reliable "loaded" test.
    if embedder is None or db is None or model is None or processor is None:
        return "System not ready. Try again after initialization."
    try:
        # Query with our own embedding so retrieval uses exactly the model
        # the chunks were embedded with, not Chroma's default function.
        query_emb = embedder.encode(question).tolist()
        results = db.query(query_embeddings=[query_emb], n_results=MAX_CONTEXT_CHUNKS)
        context = "\n\n".join(results["documents"][0])
        return ask_model(question, context, model, processor)
    except Exception as e:
        print("Query error:", e)
        return f"Error: {e}"
# ---------------- UI ----------------
# Layout: title, a row with the question box and the Ask button, then the
# answer box. Clicking Ask passes the question text to get_answer and shows
# its return value in the answer box.
with gr.Blocks() as demo:
    gr.Markdown(value="## SmartManuals-AI (Granite 3.2-2B)")
    with gr.Row():
        question = gr.Textbox(label="Ask your question")
        ask = gr.Button(value="Ask")
    answer = gr.Textbox(label="Answer", lines=8)
    ask.click(get_answer, question, answer)
# Startup Initialization
# Bind every global that get_answer() reads before any loader runs, so a
# failed load leaves it as None rather than undefined.
db = None
embedder = None
model = None
processor = None
try:
    db, embedder = embed_all()
except Exception as e:
    print("❌ Embedding failed:", e)
try:
    model, processor = load_model()
except Exception as e:
    print("❌ Model load failed:", e)

# Launch
if __name__ == "__main__":
    demo.launch(share=False)