# NOTE: "Spaces: Sleeping" below was Hugging Face Spaces page-status residue
# captured during copy-paste; it is not part of the application code.
# Standard library
import os

# Third-party: UI, PDF parsing, tensors, ONNX Runtime
import gradio as gr
import fitz  # PyMuPDF
import torch
import onnxruntime as ort
from onnxruntime import SessionOptions, GraphOptimizationLevel

# LangChain / RAG components
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.embeddings import Embeddings

# ONNX model loading
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForFeatureExtraction, ORTModelForCausalLM
from huggingface_hub import snapshot_download
# Pin every ONNX Runtime session to the CPU execution provider.
PROVIDERS = ["CPUExecutionProvider"]
print(f"⚡ Running on: {PROVIDERS}")
| # --------------------------------------------------------- | |
| # 1. OPTIMIZED EMBEDDINGS (BGE-SMALL) | |
| # --------------------------------------------------------- | |
class OnnxBgeEmbeddings(Embeddings):
    """LangChain-compatible embeddings backed by an ONNX export of BGE-small.

    Produces L2-normalized CLS-token vectors; queries are prefixed with the
    BGE retrieval instruction before encoding.
    """

    def __init__(self):
        model_name = "Xenova/bge-small-en-v1.5"
        print(f"🔄 Loading Embeddings: {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        # export=False: the repo already ships an ONNX graph, no conversion needed.
        self.model = ORTModelForFeatureExtraction.from_pretrained(
            model_name,
            export=False,
            provider=PROVIDERS[0],
        )

    def _process_batch(self, texts):
        """Tokenize a batch, run the ONNX model, return unit-length CLS vectors."""
        encoded = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        with torch.no_grad():
            model_out = self.model(**encoded)
        # CLS pooling (first token) followed by L2 normalization so cosine
        # similarity reduces to a dot product in the vector store.
        cls_vectors = model_out.last_hidden_state[:, 0]
        cls_vectors = torch.nn.functional.normalize(cls_vectors, p=2, dim=1)
        return cls_vectors.numpy().tolist()

    def embed_documents(self, texts):
        """Embed a list of passages (no instruction prefix)."""
        return self._process_batch(texts)

    def embed_query(self, text):
        """Embed a single search query with the BGE retrieval instruction prefix."""
        prefixed = "Represent this sentence for searching relevant passages: " + text
        return self._process_batch([prefixed])[0]
| # --------------------------------------------------------- | |
| # 2. OPTIMIZED LLM (Qwen 2.5 - 0.5B) - STRICT GRADING | |
| # --------------------------------------------------------- | |
class LLMEvaluator:
    """Strict, context-grounded grader using Qwen2.5-0.5B-Instruct (FP16 ONNX, CPU).

    Downloads only the files required for inference on first run, then loads
    the model through optimum's ORTModelForCausalLM.
    """

    def __init__(self):
        self.repo_id = "onnx-community/Qwen2.5-0.5B-Instruct"
        self.local_dir = "onnx_qwen_local"
        print(f"🔄 Preparing CPU LLM: {self.repo_id}...")
        if not os.path.exists(self.local_dir):
            print(f"📥 Downloading FP16 model to {self.local_dir}...")
            # Fetch only tokenizer/config files and the FP16 ONNX graph.
            snapshot_download(
                repo_id=self.repo_id,
                local_dir=self.local_dir,
                allow_patterns=[
                    "config.json",
                    "generation_config.json",
                    "tokenizer*",
                    "special_tokens_map.json",
                    "*.jinja",
                    "onnx/model_fp16.onnx*",
                ],
            )
            print("✅ Download complete.")
        self.tokenizer = AutoTokenizer.from_pretrained(self.local_dir)
        sess_options = SessionOptions()
        # NOTE(review): graph optimizations are fully disabled — presumably to
        # avoid FP16 fusion issues on CPU; confirm this is intentional, as it
        # trades away inference speed.
        sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
        self.model = ORTModelForCausalLM.from_pretrained(
            self.local_dir,
            subfolder="onnx",
            file_name="model_fp16.onnx",
            use_cache=True,
            use_io_binding=False,
            provider=PROVIDERS[0],
            session_options=sess_options,
        )

    def evaluate(self, context, question, student_answer, max_marks):
        """Grade `student_answer` against `context` only; return the raw model reply.

        The prompt instructs the model to ignore outside knowledge and award
        marks strictly for claims supported by the reference text.
        """
        # Prompt tuned for small (0.5B) instruct models: explicit steps, strict rubric.
        messages = [
            {"role": "system", "content": "You are the strictest, literal academic grader. You ONLY grade based on the provided text. You DO NOT use outside knowledge."},
            {"role": "user", "content": f"""
Task: Grade the student answer based ONLY on the Reference Text.
REFERENCE TEXT:
{context}
QUESTION:
{question}
STUDENT ANSWER:
{student_answer}
-----------------------------
GRADING LOGIC:
1. READ the Reference Text. What does it actually say about the Question?
2. COMPARE it to the Student Answer.
3. START with 0 marks and IF the answer lines up with the reference text in a meaningful way, then add marks proportionally. ONLY GIVE MARKS FOR CORRECT STATEMENTS STRICTLY BASED ON THE REFERENCE TEXT AND NOTHING ELSE.
4. IF the Student Answer claims things not found in the text, it is incorrect and HALLUCINATING. Do not give marks for that statement/phrase.
5. IF the Student Answer contradicts the text (e.g., Text says "hide personality" but Student says "show personality"), do not give marks for that statement/phrase.
VERDICT:
- If wrong: 0/{max_marks}
- If correct: {max_marks}/{max_marks}
OUTPUT FORMAT:
Score: [X]/{max_marks}
Feedback: [Brief explanation citing the text]
"""}
        ]
        input_text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self.tokenizer(input_text, return_tensors="pt")
        with torch.no_grad():
            # Greedy decoding (do_sample=False). The original also passed
            # temperature=0.05, which generate() ignores when sampling is off
            # and which triggers a transformers warning — removed.
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=100,
                do_sample=False,
                repetition_penalty=1.2,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        input_length = inputs['input_ids'].shape[1]
        response = self.tokenizer.decode(outputs[0][input_length:], skip_special_tokens=True)
        return response
| # --------------------------------------------------------- | |
| # 3. Main Application Logic | |
| # --------------------------------------------------------- | |
class VectorSystem:
    """Coordinates document ingestion, FAISS retrieval, and LLM grading."""

    def __init__(self):
        self.vector_store = None          # FAISS index, built by process_file
        self.embeddings = OnnxBgeEmbeddings()
        self.llm = LLMEvaluator()
        self.all_chunks = []              # chunk texts, indexed by position
        self.total_chunks = 0

    def process_file(self, file_obj):
        """Extract text from an uploaded .pdf/.txt, chunk it, and build the index.

        Returns a human-readable status string for the UI.
        """
        if file_obj is None:
            return "No file uploaded."
        try:
            name = file_obj.name
            text = ""
            # Case-insensitive extension check so ".PDF"/".TXT" uploads work
            # (the original comparison was case-sensitive and rejected them).
            if name.lower().endswith('.pdf'):
                # Context manager guarantees the PyMuPDF document is closed
                # (the original leaked the open document handle).
                with fitz.open(name) as doc:
                    for page in doc:
                        text += page.get_text()
            elif name.lower().endswith('.txt'):
                with open(name, 'r', encoding='utf-8') as f:
                    text = f.read()
            else:
                return "❌ Error: Only .pdf and .txt supported."
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
            self.all_chunks = text_splitter.split_text(text)
            self.total_chunks = len(self.all_chunks)
            if not self.all_chunks:
                return "File empty."
            # Record each chunk's position so neighbors can be stitched back
            # together at query time.
            metadatas = [{"id": i} for i in range(self.total_chunks)]
            self.vector_store = FAISS.from_texts(self.all_chunks, self.embeddings, metadatas=metadatas)
            return f"✅ Indexed {self.total_chunks} chunks."
        except Exception as e:
            return f"Error: {str(e)}"

    def process_query(self, question, student_answer, max_marks):
        """Retrieve context for `question` and, if an answer is given, grade it.

        Returns (evidence_markdown, grading_feedback).
        """
        if not self.vector_store:
            return "⚠️ Please upload a file first.", ""
        if not question:
            return "⚠️ Enter a question.", ""
        results = self.vector_store.similarity_search_with_score(question, k=1)
        top_doc, _score = results[0]
        center_id = top_doc.metadata['id']
        # Expand to the immediate neighbors to recover context that chunking
        # may have split across boundaries.
        start_id = max(0, center_id - 1)
        end_id = min(self.total_chunks - 1, center_id + 1)
        expanded_context = ""
        for i in range(start_id, end_id + 1):
            expanded_context += self.all_chunks[i] + "\n"
        evidence_display = f"### 📚 Expanded Context (Chunks {start_id} to {end_id}):\n"
        evidence_display += f"> ... {expanded_context} ..."
        llm_feedback = "Please enter a student answer to grade."
        if student_answer:
            llm_feedback = self.llm.evaluate(expanded_context, question, student_answer, max_marks)
        return evidence_display, llm_feedback
# --- Application wiring: build the backend once, then lay out the UI ---
system = VectorSystem()

with gr.Blocks(title="EduGenius AI Grader") as demo:
    gr.Markdown("# ⚡ EduGenius: CPU Optimized RAG")
    gr.Markdown("Powered by **Qwen-2.5-0.5B** and **BGE-Small** (ONNX Optimized)")

    with gr.Row():
        # Left column: document upload and indexing controls.
        with gr.Column(scale=1):
            pdf_input = gr.File(label="1. Upload Chapter")
            upload_btn = gr.Button("Index Content", variant="primary")
            status_msg = gr.Textbox(label="Status", interactive=False)
        # Right column: question, marks, and the student answer to grade.
        with gr.Column(scale=2):
            with gr.Row():
                q_input = gr.Textbox(label="Question", scale=2)
                max_marks = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Max Marks")
            a_input = gr.TextArea(label="Student Answer")
            run_btn = gr.Button("Retrieve & Grade", variant="secondary")

    # Output panes: retrieved context and the grading verdict.
    with gr.Row():
        evidence_box = gr.Markdown(label="Context Used")
        grade_box = gr.Markdown(label="Grading Result")

    # Event wiring.
    upload_btn.click(system.process_file, inputs=[pdf_input], outputs=[status_msg])
    run_btn.click(
        system.process_query,
        inputs=[q_input, a_input, max_marks],
        outputs=[evidence_box, grade_box],
    )

if __name__ == "__main__":
    demo.launch()