Spaces:
Runtime error
Runtime error
namberino
committed on
Commit
·
45332bf
1
Parent(s):
c7d9bea
Switch to Cerebras API inference
Browse files
- generator.py +7 -2
- utils.py +5 -4
generator.py
CHANGED
|
@@ -34,7 +34,7 @@ class RAGMCQ:
|
|
| 34 |
def __init__(
|
| 35 |
self,
|
| 36 |
embedder_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
|
| 37 |
-
hf_model: str = "
|
| 38 |
qdrant_url: str = None,
|
| 39 |
qdrant_api_key: str = None,
|
| 40 |
qdrant_prefer_grpc: bool = False,
|
|
@@ -216,7 +216,12 @@ class RAGMCQ:
|
|
| 216 |
seed_idx = random.randrange(len(self.texts))
|
| 217 |
chunk = self.texts[seed_idx]
|
| 218 |
sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
|
| 219 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 220 |
query = f"Create questions about: {seed_sent}"
|
| 221 |
|
| 222 |
# retrieve top_k chunks
|
|
|
|
| 34 |
def __init__(
|
| 35 |
self,
|
| 36 |
embedder_model: str = "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
|
| 37 |
+
hf_model: str = "gpt-oss-120b",
|
| 38 |
qdrant_url: str = None,
|
| 39 |
qdrant_api_key: str = None,
|
| 40 |
qdrant_prefer_grpc: bool = False,
|
|
|
|
| 216 |
seed_idx = random.randrange(len(self.texts))
|
| 217 |
chunk = self.texts[seed_idx]
|
| 218 |
sents = re.split(r'(?<=[\.\?\!])\s+', chunk)
|
| 219 |
+
candidate = [s for s in sents if len(s.strip()) > 20]
|
| 220 |
+
if candidate:
|
| 221 |
+
seed_sent = random.choice(candidate)
|
| 222 |
+
else:
|
| 223 |
+
stripped = chunk.strip()
|
| 224 |
+
seed_sent = (stripped[:200] if stripped else "[no text available]")
|
| 225 |
query = f"Create questions about: {seed_sent}"
|
| 226 |
|
| 227 |
# retrieve top_k chunks
|
utils.py
CHANGED
|
@@ -4,9 +4,10 @@ from typing import Dict, Any
|
|
| 4 |
import requests
|
| 5 |
import os
|
| 6 |
|
| 7 |
-
API_URL = "https://
|
| 8 |
-
HF_KEY = os.environ['HF_API_KEY']
|
| 9 |
-
|
|
|
|
| 10 |
JSON_OBJ_RE = re.compile(r"(\{[\s\S]*\})", re.MULTILINE)
|
| 11 |
|
| 12 |
def _post_chat(messages: list, model: str, temperature: float = 0.2, timeout: int = 60) -> str:
|
|
@@ -48,7 +49,7 @@ def _safe_extract_json(text: str) -> dict:
|
|
| 48 |
def generate_mcqs_from_text(
|
| 49 |
source_text: str,
|
| 50 |
n: int = 3,
|
| 51 |
-
model: str = "
|
| 52 |
temperature: float = 0.2,
|
| 53 |
) -> Dict[str, Any]:
|
| 54 |
system_message = {
|
|
|
|
| 4 |
import requests
|
| 5 |
import os
|
| 6 |
|
| 7 |
+
API_URL = "https://api.cerebras.ai/v1/chat/completions"
|
| 8 |
+
# HF_KEY = os.environ['HF_API_KEY']
|
| 9 |
+
CEREBRAS_API_KEY = os.environ['CEREBRAS_API_KEY']
|
| 10 |
+
HEADERS = {"Authorization": f"Bearer {CEREBRAS_API_KEY}", "Content-Type": "application/json"}
|
| 11 |
JSON_OBJ_RE = re.compile(r"(\{[\s\S]*\})", re.MULTILINE)
|
| 12 |
|
| 13 |
def _post_chat(messages: list, model: str, temperature: float = 0.2, timeout: int = 60) -> str:
|
|
|
|
| 49 |
def generate_mcqs_from_text(
|
| 50 |
source_text: str,
|
| 51 |
n: int = 3,
|
| 52 |
+
model: str = "gpt-oss-120b",
|
| 53 |
temperature: float = 0.2,
|
| 54 |
) -> Dict[str, Any]:
|
| 55 |
system_message = {
|