# app.py — "try_answer" Hugging Face Space (EduGenius AI grader).
# RAG pipeline: ONNX BGE-small embeddings + FAISS retrieval + Qwen2.5-0.5B grading.
import gradio as gr
import fitz # PyMuPDF
import torch
import os
# --- LANGCHAIN & RAG IMPORTS ---
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import FAISS
from langchain_core.embeddings import Embeddings
# --- ONNX & MODEL IMPORTS ---
from transformers import AutoTokenizer
from optimum.onnxruntime import ORTModelForFeatureExtraction, ORTModelForCausalLM
from huggingface_hub import snapshot_download
import onnxruntime as ort
# Check available hardware accelerators.
# onnxruntime returns providers in priority order, so PROVIDERS[0] is the
# best one available (e.g. CUDAExecutionProvider before CPUExecutionProvider);
# both model classes below pass it as their execution provider.
PROVIDERS = ort.get_available_providers()
print(f"⚡ Hardware Acceleration Providers: {PROVIDERS}")
# ---------------------------------------------------------
# 1. OPTIMIZED EMBEDDINGS (BGE-SMALL)
# ---------------------------------------------------------
class OnnxBgeEmbeddings(Embeddings):
    """LangChain-compatible embeddings backed by an ONNX export of BGE-small.

    Uses the Xenova mirror of bge-small-en-v1.5, which ships pre-exported
    ONNX weights; the official BAAI repo is PyTorch-only and would fail
    when loaded with export=False.
    """

    def __init__(self):
        model_name = "Xenova/bge-small-en-v1.5"
        print(f"🔄 Loading Embeddings: {model_name}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = ORTModelForFeatureExtraction.from_pretrained(
            model_name,
            export=False,  # safe: the Xenova repo already contains model.onnx
            provider=PROVIDERS[0],
        )

    def _process_batch(self, texts):
        """Tokenize, embed, CLS-pool and L2-normalize a list of strings."""
        encoded = self.tokenizer(
            texts, padding=True, truncation=True, max_length=512, return_tensors="pt"
        )
        target = self.model.device
        encoded = {name: tensor.to(target) for name, tensor in encoded.items()}
        with torch.no_grad():
            model_out = self.model(**encoded)
        # CLS-token pooling followed by unit-length normalization — the
        # recommended way to produce sentence vectors from BGE models.
        pooled = model_out.last_hidden_state[:, 0]
        unit_vectors = torch.nn.functional.normalize(pooled, p=2, dim=1)
        return unit_vectors.cpu().numpy().tolist()

    def embed_documents(self, texts):
        return self._process_batch(texts)

    def embed_query(self, text):
        # BGE models expect this instruction prefix on the query side only.
        prefixed = "Represent this sentence for searching relevant passages: " + text
        return self._process_batch([prefixed])[0]
# ---------------------------------------------------------
# 2. OPTIMIZED LLM (Qwen 2.5 - 0.5B)
# ---------------------------------------------------------
class LLMEvaluator:
    """Grades student answers with an ONNX-quantized Qwen2.5-0.5B-Instruct model.

    Downloads only the tokenizer/config files and the FP16 ONNX graph on first
    run, then loads the model with onnxruntime for fast CPU/GPU inference.
    """

    def __init__(self):
        # onnx-community hosts a pre-exported ONNX build of Qwen 2.5 Instruct.
        self.repo_id = "onnx-community/Qwen2.5-0.5B-Instruct"
        self.local_dir = "onnx_qwen_local"
        print(f"🔄 Preparing Ultra-Fast LLM: {self.repo_id}...")
        if not os.path.exists(self.local_dir):
            print(f"📥 Downloading FP16 model + data to {self.local_dir}...")
            # Restrict the snapshot to the tokenizer/config files and the FP16
            # graph (plus its external-data shards) rather than the whole repo.
            snapshot_download(
                repo_id=self.repo_id,
                local_dir=self.local_dir,
                allow_patterns=[
                    "config.json",
                    "generation_config.json",
                    "tokenizer*",
                    "special_tokens_map.json",
                    "*.jinja",
                    "onnx/model_fp16.onnx*",
                ],
            )
            print("✅ Download complete.")
        self.tokenizer = AutoTokenizer.from_pretrained(self.local_dir)
        # The ONNX weights live in the repo's "onnx" subfolder.
        self.model = ORTModelForCausalLM.from_pretrained(
            self.local_dir,
            subfolder="onnx",
            file_name="model_fp16.onnx",
            use_cache=True,
            use_io_binding=True,
            provider=PROVIDERS[0],
        )

    def evaluate(self, context, question, student_answer, max_marks):
        """Return the model's grading text for one question/answer pair.

        Args:
            context: Retrieved document text the answer is checked against.
            question: The exam question being graded.
            student_answer: The student's answer text.
            max_marks: Maximum score the answer can receive.

        Returns:
            The generated feedback string ('Score: X/N ...'), with the
            prompt tokens stripped from the decoded output.
        """
        messages = [
            {"role": "system", "content": "You are a strict academic grader. Verify the student answer against the context. Be harsh. Do not hallucinate."},
            {"role": "user", "content": f"""
CONTEXT: {context}
QUESTION: {question}
ANSWER: {student_answer}
TASK: Grade out of {max_marks}.
RULES:
1. If wrong, 0 marks.
2. Be strict.
3. Format: 'Score: X/{max_marks} \n Feedback: ...'
"""}
        ]
        input_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.tokenizer(input_text, return_tensors="pt")
        device = self.model.device
        inputs = {k: v.to(device) for k, v in inputs.items()}
        # BUG FIX: `inputs` is a plain dict after the .to(device) rebinding, so
        # the original `inputs.input_ids` raised AttributeError on every call.
        # Capture the prompt length via key access instead.
        prompt_len = inputs["input_ids"].shape[1]
        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=75,
                # Greedy decoding; a temperature has no effect (and triggers a
                # transformers warning) when do_sample=False, so it is omitted.
                do_sample=False,
            )
        # Decode only the newly generated tokens, not the echoed prompt.
        return self.tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
# ---------------------------------------------------------
# 3. Main Application Logic
# ---------------------------------------------------------
class VectorSystem:
    """End-to-end RAG pipeline: index a document, retrieve context, grade answers."""

    def __init__(self):
        self.vector_store = None          # FAISS index, built by process_file
        self.embeddings = OnnxBgeEmbeddings()
        self.llm = LLMEvaluator()
        self.all_chunks = []              # ordered chunk texts, indexed by metadata "id"
        self.total_chunks = 0

    def process_file(self, file_obj):
        """Extract text from an uploaded .pdf/.txt file and build the FAISS index.

        Args:
            file_obj: Gradio file object (has a .name path) or None.

        Returns:
            A human-readable status string for the UI.
        """
        if file_obj is None:
            return "No file uploaded."
        try:
            path = file_obj.name
            lower = path.lower()  # FIX: accept .PDF / .TXT uploads too
            if lower.endswith('.pdf'):
                text = ""
                # FIX: context manager closes the document, so the file
                # handle is not leaked on every upload.
                with fitz.open(path) as doc:
                    for page in doc:
                        text += page.get_text()
            elif lower.endswith('.txt'):
                with open(path, 'r', encoding='utf-8') as f:
                    text = f.read()
            else:
                return "❌ Error: Only .pdf and .txt supported."
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=800, chunk_overlap=100)
            self.all_chunks = text_splitter.split_text(text)
            self.total_chunks = len(self.all_chunks)
            if not self.all_chunks:
                return "File empty."
            # Record each chunk's position so neighbours can be fetched later.
            metadatas = [{"id": i} for i in range(self.total_chunks)]
            self.vector_store = FAISS.from_texts(self.all_chunks, self.embeddings, metadatas=metadatas)
            return f"✅ Indexed {self.total_chunks} chunks."
        except Exception as e:
            return f"Error: {str(e)}"

    def process_query(self, question, student_answer, max_marks):
        """Retrieve context around the best-matching chunk and optionally grade.

        Returns:
            (evidence_markdown, llm_feedback) strings for the two output panes.
        """
        if not self.vector_store:
            return "⚠️ Please upload a file first.", ""
        if not question:
            return "⚠️ Enter a question.", ""
        results = self.vector_store.similarity_search_with_score(question, k=1)
        top_doc, _score = results[0]
        center_id = top_doc.metadata['id']
        # Expand the hit by one chunk on each side (clamped to valid ids)
        # so the grader sees surrounding context, not just the best chunk.
        start_id = max(0, center_id - 1)
        end_id = min(self.total_chunks - 1, center_id + 1)
        expanded_context = ""
        for i in range(start_id, end_id + 1):
            expanded_context += self.all_chunks[i] + "\n"
        evidence_display = f"### 📚 Expanded Context (Chunks {start_id} to {end_id}):\n"
        evidence_display += f"> ... {expanded_context} ..."
        llm_feedback = "Please enter a student answer to grade."
        if student_answer:
            llm_feedback = self.llm.evaluate(expanded_context, question, student_answer, max_marks)
        return evidence_display, llm_feedback
# Build the singleton pipeline; this loads both ONNX models at startup.
system = VectorSystem()

with gr.Blocks(title="EduGenius AI Grader") as demo:
    gr.Markdown("# ⚡ EduGenius: Ultra-Fast RAG")
    gr.Markdown("Powered by **Qwen-2.5-0.5B** and **BGE-Small** (ONNX Optimized)")
    with gr.Row():
        with gr.Column(scale=1):
            # Left column: document upload and indexing controls.
            pdf_input = gr.File(label="1. Upload Chapter")
            upload_btn = gr.Button("Index Content", variant="primary")
            status_msg = gr.Textbox(label="Status", interactive=False)
        with gr.Column(scale=2):
            # Right column: question, max marks and the answer to grade.
            with gr.Row():
                q_input = gr.Textbox(label="Question", scale=2)
                max_marks = gr.Slider(minimum=1, maximum=20, value=5, step=1, label="Max Marks")
            a_input = gr.TextArea(label="Student Answer")
            run_btn = gr.Button("Retrieve & Grade", variant="secondary")
    with gr.Row():
        # Output panes: retrieved context on the left, grading result on the right.
        evidence_box = gr.Markdown(label="Context Used")
        grade_box = gr.Markdown(label="Grading Result")
    # Wire the buttons to the pipeline methods.
    upload_btn.click(system.process_file, inputs=[pdf_input], outputs=[status_msg])
    run_btn.click(system.process_query, inputs=[q_input, a_input, max_marks], outputs=[evidence_box, grade_box])

if __name__ == "__main__":
    demo.launch()