heerjtdev committed on
Commit eab649a · verified · 1 Parent(s): ccdc2fe

Update app.py

Files changed (1)
  1. app.py +16 -16
app.py CHANGED
@@ -2,6 +2,10 @@ import gradio as gr
 import fitz # PyMuPDF
 import torch
 import os
+import onnxruntime as ort
+
+# --- IMPORT SESSION OPTIONS ---
+from onnxruntime import SessionOptions, GraphOptimizationLevel
 
 # --- LANGCHAIN & RAG IMPORTS ---
 from langchain_text_splitters import RecursiveCharacterTextSplitter
@@ -12,39 +16,31 @@ from langchain_core.embeddings import Embeddings
 from transformers import AutoTokenizer
 from optimum.onnxruntime import ORTModelForFeatureExtraction, ORTModelForCausalLM
 from huggingface_hub import snapshot_download
-import onnxruntime as ort
 
 # Check available hardware accelerators
 PROVIDERS = ort.get_available_providers()
 print(f"⚡ Hardware Acceleration Providers: {PROVIDERS}")
 
 # ---------------------------------------------------------
-# 1. OPTIMIZED EMBEDDINGS (BGE-SMALL)
+# 1. OPTIMIZED EMBEDDINGS (BGE-SMALL) - [KEEP THIS SAME]
 # ---------------------------------------------------------
 class OnnxBgeEmbeddings(Embeddings):
     def __init__(self):
-        # FIX 1: Use "Xenova/..." version which has pre-exported ONNX weights.
-        # The official "BAAI/..." repo is PyTorch-only and fails with export=False.
         model_name = "Xenova/bge-small-en-v1.5"
         print(f"🔄 Loading Embeddings: {model_name}...")
-
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-
         self.model = ORTModelForFeatureExtraction.from_pretrained(
             model_name,
-            export=False, # Now safe because Xenova repo has model.onnx
-            provider=PROVIDERS[0]
+            export=False,
+            provider=PROVIDERS[0]
         )
 
     def _process_batch(self, texts):
         inputs = self.tokenizer(texts, padding=True, truncation=True, max_length=512, return_tensors="pt")
-
         device = self.model.device
         inputs = {k: v.to(device) for k, v in inputs.items()}
-
         with torch.no_grad():
             outputs = self.model(**inputs)
-
         embeddings = outputs.last_hidden_state[:, 0]
         embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
         return embeddings.cpu().numpy().tolist()
@@ -56,19 +52,18 @@ class OnnxBgeEmbeddings(Embeddings):
         return self._process_batch(["Represent this sentence for searching relevant passages: " + text])[0]
 
 # ---------------------------------------------------------
-# 2. OPTIMIZED LLM (Qwen 2.5 - 0.5B)
+# 2. OPTIMIZED LLM (Qwen 2.5 - 0.5B) - [FIXED]
 # ---------------------------------------------------------
 class LLMEvaluator:
     def __init__(self):
-        # FIX 2: Correct Repo ID for Qwen 2.5 ONNX
         self.repo_id = "onnx-community/Qwen2.5-0.5B-Instruct"
         self.local_dir = "onnx_qwen_local"
 
         print(f"🔄 Preparing Ultra-Fast LLM: {self.repo_id}...")
 
+        # Download (same as before)
         if not os.path.exists(self.local_dir):
             print(f"📥 Downloading FP16 model + data to {self.local_dir}...")
-            # We download the 'onnx' subfolder specifically
             snapshot_download(
                 repo_id=self.repo_id,
                 local_dir=self.local_dir,
@@ -78,14 +73,19 @@ class LLMEvaluator:
 
         self.tokenizer = AutoTokenizer.from_pretrained(self.local_dir)
 
-        # FIX 3: Point to the 'onnx' subfolder inside the downloaded directory
+        # --- CRITICAL FIX: DISABLE GRAPH OPTIMIZATIONS ---
+        # The model is already optimized. Re-optimizing it at runtime causes the crash.
+        sess_options = SessionOptions()
+        sess_options.graph_optimization_level = GraphOptimizationLevel.ORT_DISABLE_ALL
+
         self.model = ORTModelForCausalLM.from_pretrained(
             self.local_dir,
             subfolder="onnx",
             file_name="model_fp16.onnx",
             use_cache=True,
             use_io_binding=True,
-            provider=PROVIDERS[0]
+            provider=PROVIDERS[0],
+            session_options=sess_options # <--- PASS THIS HERE
         )
 
     def evaluate(self, context, question, student_answer, max_marks):
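
For context, the new session_options lines map onto plain onnxruntime as follows. This is a minimal sketch, not part of the commit; the model path is an assumption based on the snapshot layout above (local_dir plus the onnx subfolder):

import onnxruntime as ort

# Build session options that skip onnxruntime's runtime graph-optimization
# pass, since the downloaded graph is already optimized.
so = ort.SessionOptions()
so.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL

# Hypothetical path, assuming the snapshot layout used in the commit.
session = ort.InferenceSession(
    "onnx_qwen_local/onnx/model_fp16.onnx",
    sess_options=so,
    providers=ort.get_available_providers(),
)
print(session.get_providers())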
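As I understand it, optimum's ORTModelForCausalLM constructs an InferenceSession of this kind internally, so passing session_options through from_pretrained reaches the same setting without dropping down to raw onnxruntime.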