namberino commited on
Commit
45332bf
·
1 Parent(s): c7d9bea

Switch to cerebras api inference

Browse files
Files changed (2) hide show
  1. generator.py +7 -2
  2. utils.py +5 -4
generator.py CHANGED
@@ -34,7 +34,7 @@ class RAGMCQ:
34
  def __init__(
35
  self,
36
  embedder_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
37
- hf_model: str = "openai/gpt-oss-120b:cerebras",
38
  qdrant_url: str = None,
39
  qdrant_api_key: str = None,
40
  qdrant_prefer_grpc: bool = False,
@@ -216,7 +216,12 @@ class RAGMCQ:
216
  seed_idx = random.randrange(len(self.texts))
217
  chunk = self.texts[seed_idx]
218
  sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
219
- seed_sent = random.choice([s for s in sents if len(s.strip()) > 20]) if sents else chunk[:200]
 
 
 
 
 
220
  query = f"Create questions about: {seed_sent}"
221
 
222
  # retrieve top_k chunks
 
34
  def __init__(
35
  self,
36
  embedder_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
37
+ hf_model: str = "gpt-oss-120b",
38
  qdrant_url: str = None,
39
  qdrant_api_key: str = None,
40
  qdrant_prefer_grpc: bool = False,
 
216
  seed_idx = random.randrange(len(self.texts))
217
  chunk = self.texts[seed_idx]
218
  sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
219
+ candidate = [s for s in sents if len(s.strip()) > 20]
220
+ if candidate:
221
+ seed_sent = random.choice(candidate)
222
+ else:
223
+ stripped = chunk.strip()
224
+ seed_sent = (stripped[:200] if stripped else "[no text available]")
225
  query = f"Create questions about: {seed_sent}"
226
 
227
  # retrieve top_k chunks
utils.py CHANGED
@@ -4,9 +4,10 @@ from typing import Dict, Any
4
  import requests
5
  import os
6
 
7
- API_URL = "https://router.huggingface.co/v1/chat/completions"
8
- HF_KEY = os.environ['HF_API_KEY']
9
- HEADERS = {"Authorization": f"Bearer {HF_KEY}"}
 
10
  JSON_OBJ_RE = re.compile(r"(\{[\s\S]*\})", re.MULTILINE)
11
 
12
  def _post_chat(messages: list, model: str, temperature: float = 0.2, timeout: int = 60) -> str:
@@ -48,7 +49,7 @@ def _safe_extract_json(text: str) -> dict:
48
  def generate_mcqs_from_text(
49
  source_text: str,
50
  n: int = 3,
51
- model: str = "openai/gpt-oss-120b:cerebras",
52
  temperature: float = 0.2,
53
  ) -> Dict[str, Any]:
54
  system_message = {
 
4
  import requests
5
  import os
6
 
7
+ API_URL = "https://api.cerebras.ai/v1/chat/completions"
8
+ # HF_KEY = os.environ['HF_API_KEY']
9
+ CEREBRAS_API_KEY = os.environ['CEREBRAS_API_KEY']
10
+ HEADERS = {"Authorization": f"Bearer {CEREBRAS_API_KEY}", "Content-Type": "application/json"}
11
  JSON_OBJ_RE = re.compile(r"(\{[\s\S]*\})", re.MULTILINE)
12
 
13
  def _post_chat(messages: list, model: str, temperature: float = 0.2, timeout: int = 60) -> str:
 
49
  def generate_mcqs_from_text(
50
  source_text: str,
51
  n: int = 3,
52
+ model: str = "gpt-oss-120b",
53
  temperature: float = 0.2,
54
  ) -> Dict[str, Any]:
55
  system_message = {