import torch
import json
import os
import faiss
import numpy as np
from pptx import Presentation
from fastapi import FastAPI, UploadFile, File
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
from sentence_transformers import SentenceTransformer
from io import BytesIO
import gradio as gr
from gradio import mount_gradio_app

# ---------------------------- #
# CONFIGURATION
# ---------------------------- #
MODEL_NAME = "./models/facebook-opt-1.3b"
SUMMARIZATION_MODEL = "./models/bart-large-cnn"
EMBEDDING_MODEL = "./models/all-MiniLM-L6-v2"
DATA_DIRECTORY = "./dataset/"

# ---------------------------- #
# FUNCTION TO LOAD JSON FILES
# ---------------------------- #
def load_text_from_json(directory):
    text_data = set()  # Use a set to remove duplicates
    for filename in os.listdir(directory):
        if filename.endswith(".json"):
            with open(os.path.join(directory, filename), "r", encoding="utf-8") as file:
                data = json.load(file)
                for entry in data.get("data", []):
                    question = entry.get("question", "").strip()
                    answer = entry.get("answer", "").strip()
                    if question and answer:
                        text_data.add(f"Q: {question} A: {answer}")
    return list(text_data)
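
# Expected JSON shape (inferred from the parser above):
# {"data": [{"question": "...", "answer": "..."}, ...]}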

# ---------------------------- #
# FUNCTION TO LOAD POWERPOINT FILES
# ---------------------------- #
def extract_text_from_pptx(file_path):
    prs = Presentation(file_path)
    text_data = []
    for slide in prs.slides:
        for shape in slide.shapes:
            if hasattr(shape, "text"):
                text_data.append(shape.text.strip())
    return " ".join(text_data)

def load_text_from_pptx(directory):
    text_data = set()
    for filename in os.listdir(directory):
        if filename.endswith(".pptx"):
            pptx_text = extract_text_from_pptx(os.path.join(directory, filename))
            text_data.add(pptx_text)
    return list(text_data)

# ---------------------------- #
# LOAD ALL TEXT DATA
# ---------------------------- #
all_text = load_text_from_json(DATA_DIRECTORY) + load_text_from_pptx(DATA_DIRECTORY)

# ---------------------------- #
# CHUNK DATA PROPERLY
# ---------------------------- #
CHUNK_SIZE = 500
chunks = set()
for text in all_text:
    sentences = text.split(". ")
    temp_chunk = ""
    for sentence in sentences:
        if len(temp_chunk) + len(sentence) < CHUNK_SIZE:
            temp_chunk += sentence + ". "
        else:
            chunks.add(temp_chunk.strip())  # Store chunk
            temp_chunk = sentence + ". "
    if temp_chunk:
        chunks.add(temp_chunk.strip())  # Store last chunk
chunks = list(chunks)  # Convert to list after deduplication
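
# Note: set-based deduplication discards chunk order; that is acceptable
# here because FAISS retrieval does not depend on insertion order.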

# ---------------------------- #
# EMBEDDING MODEL & FAISS VECTOR SEARCH
# ---------------------------- #
embedder = SentenceTransformer(EMBEDDING_MODEL, local_files_only=True)
chunk_embeddings = embedder.encode(chunks, convert_to_numpy=True)

# FAISS index
index = faiss.IndexFlatL2(chunk_embeddings.shape[1])
index.add(chunk_embeddings)
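
# Optional persistence sketch (not part of the original app): the index can
# be saved and restored across restarts; the file path is an assumption.
# faiss.write_index(index, "./chunks.index")
# index = faiss.read_index("./chunks.index")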

# ---------------------------- #
# LOAD LLM MODEL
# ---------------------------- #
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, trust_remote_code=True, torch_dtype=torch.float32, device_map="cpu"
)

# Summarization pipeline
summarizer = pipeline("summarization", model=SUMMARIZATION_MODEL)

# ---------------------------- #
# FASTAPI SETUP
# ---------------------------- #
app = FastAPI()

def retrieve_relevant_text(question, top_k=3):
    question_embedding = embedder.encode([question], convert_to_numpy=True)
    _, idxs = index.search(question_embedding, top_k)
    retrieved_texts = [chunks[idx] for idx in idxs[0]]
    # Filter out chunks that contain the question itself
    filtered_chunks = [text for text in retrieved_texts if question.lower() not in text.lower()]
    unique_texts = list(set(filtered_chunks))
    context_text = " ".join(unique_texts)
    if len(context_text) > 1000:
        context_text = summarizer(context_text, max_length=150, min_length=50, do_sample=False)[0]["summary_text"]
    return context_text
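
# Illustrative call: retrieve_relevant_text("How do I reset my password?")
# returns the top-k matching chunks joined into one context string,
# summarized by BART whenever it grows past 1,000 characters.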

@app.post("/upload/")  # registered so the Gradio uploader below can reach it
async def upload_file(file: UploadFile = File(...)):
    global chunks, index, chunk_embeddings
    filename = file.filename
    content = await file.read()
    new_texts = []
    try:
        # -------------------- #
        # Process .json files
        # -------------------- #
        if filename.endswith(".json"):
            data = json.loads(content)
            for entry in data.get("data", []):
                question = entry.get("question", "").strip()
                answer = entry.get("answer", "").strip()
                if question and answer:
                    new_texts.append(f"Q: {question} A: {answer}")
        # -------------------- #
        # Process .pptx files
        # -------------------- #
        elif filename.endswith(".pptx"):
            prs = Presentation(BytesIO(content))
            ppt_text = []
            for slide in prs.slides:
                for shape in slide.shapes:
                    if hasattr(shape, "text"):
                        ppt_text.append(shape.text.strip())
            new_texts.append(" ".join(ppt_text))
        else:
            return {"error": "Unsupported file type. Use .json or .pptx"}
        # -------------------- #
        # Chunk and embed
        # -------------------- #
        new_chunks = set()
        for text in new_texts:
            sentences = text.split(". ")
            temp = ""
            for s in sentences:
                if len(temp) + len(s) < CHUNK_SIZE:
                    temp += s + ". "
                else:
                    new_chunks.add(temp.strip())
                    temp = s + ". "
            if temp:
                new_chunks.add(temp.strip())
        # Drop chunks that already exist in the index (dedup)
        new_chunks = list(new_chunks - set(chunks))
        if not new_chunks:
            return {"message": "No new unique chunks to add."}
        # Encode and update FAISS
        new_embeddings = embedder.encode(new_chunks, convert_to_numpy=True)
        index.add(new_embeddings)
        chunks.extend(new_chunks)
        return {
            "status": "success",
            "new_chunks_added": len(new_chunks),
            "total_chunks": len(chunks),
        }
    except Exception as e:
        return {"error": str(e)}

@app.get("/faq/")  # assumed route path; the handler is otherwise unreachable
def faq(question: str):
    """Answer user queries using retrieved knowledge."""
    retrieved_text = retrieve_relevant_text(question)
    prompt = (
        f"{retrieved_text.strip()}\n\n"
        f"Answer the following question based only on the above context:\n"
        f"{question.strip()}\n\n"
        f"Answer:"
    )
    inputs = tokenizer(prompt, return_tensors="pt").to("cpu")
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=256,  # bound the answer itself, not prompt + answer
            repetition_penalty=1.3,
            no_repeat_ngram_size=4,
            do_sample=False,  # greedy decoding; sampling knobs would be ignored
            pad_token_id=tokenizer.eos_token_id,
        )
    raw_answer = tokenizer.decode(output[0], skip_special_tokens=True)
    # ---------------------------- #
    # POST-PROCESSING CLEANUP
    # ---------------------------- #
    cleaned_answer = raw_answer
    # Remove the prompt (everything before the final 'Answer:' keyword)
    if "Answer:" in cleaned_answer:
        cleaned_answer = cleaned_answer.split("Answer:")[-1]
    # Remove a repeated question (case-insensitive)
    question_lower = question.strip().lower()
    cleaned_answer = cleaned_answer.strip()
    if cleaned_answer.lower().startswith(question_lower):
        cleaned_answer = cleaned_answer[len(question_lower):].strip()
    # Final touch: remove context/prompt tokens if they leaked
    for token in ["Context:", "Question:", "Answer:"]:
        cleaned_answer = cleaned_answer.replace(token, "").strip()
    return {"answer": cleaned_answer}

# --------- Gradio UI --------- #
def gradio_upload(file):
    if file is None:
        return "No file selected."
    try:
        import requests

        base_url = os.getenv("HF_SPACE_URL", "http://localhost:7860")
        # `file` is a Gradio NamedString; reopen it by its temp-file name
        with open(file.name, "rb") as f:
            files = {"file": (os.path.basename(file.name), f)}
            response = requests.post(f"{base_url}/upload/", files=files)
        if response.status_code == 200:
            return "✅ Data successfully uploaded and indexed!"
        else:
            return f"❌ Failed: {response.text}"
    except Exception as e:
        return f"❌ Error: {str(e)}"

gr_app = gr.Interface(
    fn=gradio_upload,
    inputs=gr.File(label="Upload .json or .pptx file"),  # matches the types /upload/ accepts
    outputs="text",
    title="Upload Knowledge",
)

# Mount Gradio at /ui
app = mount_gradio_app(app, gr_app, path="/ui")
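
# Local development entry point. This is an assumption: on Hugging Face
# Spaces the server is usually started externally (e.g. `uvicorn app:app
# --host 0.0.0.0 --port 7860`), so this block only matters when the file
# is run directly.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)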