# app.py — Distilabel Generator API (uploaded via huggingface_hub, rev cb6b39d)
import os
import random
import re
from typing import List, Optional

from datasets import load_dataset
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import InferenceClient

from models import (
    GenerateRequest, GenerateResponse,
    LoadSeedsRequest, LoadSeedsResponse,
    HealthResponse, HFInferenceModel, HFInferenceProvider, HFJudgeModel,
    GeneratedRecord
)
# Hugging Face API token; empty string when unset. Endpoints check this and
# report "HF_TOKEN not configured" instead of failing mid-call.
HF_TOKEN = os.environ.get("HF_TOKEN", "")
# Global seed storage: examples from the seed dataset, populated by
# /load-seeds (or lazily by /generate) and sampled during generation.
loaded_seeds: List[dict] = []
# FastAPI application instance; metadata shows up in the generated OpenAPI docs.
app = FastAPI(
    title="Distilabel Generator API",
    description="Phase 1: Generate synthetic data with GPT-OSS-120B via Cerebras",
    version="1.0.0"
)
# Allow browser frontends on any origin to call the API.
# NOTE(review): wildcard origins combined with allow_credentials=True is not
# honored for credentialed requests under the CORS spec — confirm whether
# credentials are actually required here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def get_client(model: HFInferenceModel, provider: HFInferenceProvider):
    """Build an InferenceClient for the given generation model and provider."""
    model_id = model.value
    provider_id = provider.value
    return InferenceClient(
        model=model_id,
        token=HF_TOKEN,
        provider=provider_id,
    )
def get_judge_client(model: HFJudgeModel, provider: HFInferenceProvider):
    """Build an InferenceClient for the given judge model and provider."""
    judge_id = model.value
    provider_id = provider.value
    return InferenceClient(
        model=judge_id,
        token=HF_TOKEN,
        provider=provider_id,
    )
# LLM Judge scoring prompt (UltraFeedback-style). Filled via str.format with
# `instruction` and `output`; the judge is instructed to reply with a bare
# number so the caller can parse a 1-10 score from the completion.
JUDGE_PROMPT = """Rate this Q&A pair quality (1-10 scale):
Question: {instruction}
Answer: {output}
Criteria:
- Instruction-following: Does the answer address the question?
- Truthfulness: Is the information accurate?
- Honesty: Does it avoid hallucination?
- Helpfulness: Is it useful for learning?
Output ONLY a single number 1-10:"""
def llm_judge_score(instruction: str, output: str, judge_model: HFJudgeModel, judge_provider: HFInferenceProvider) -> int:
    """Score a Q&A pair with an LLM judge on a 1-10 scale (UltraFeedback-style).

    Truncates both fields to 500 chars to keep the judge prompt small.
    Returns 5 (neutral midpoint) when the judge call fails or no number can
    be parsed from the reply.
    """
    try:
        client = get_judge_client(judge_model, judge_provider)
        prompt = JUDGE_PROMPT.format(
            instruction=instruction[:500],
            output=output[:500]
        )
        result = client.chat_completion(
            [{"role": "user", "content": prompt}],
            max_tokens=10,
            temperature=0.3,
        )
        content = result.choices[0].message.content
        # Parse the first integer in the reply. BUG FIX: the previous
        # char-by-char scan returned the first single digit, so a judge reply
        # of "10" was misread as 1.
        match = re.search(r"\d+", content)
        if match:
            return min(max(int(match.group()), 1), 10)  # Clamp to 1-10
        return 5
    except Exception as e:
        print(f"Judge Error: {e}")
        return 5
def _parse_qa(content: str) -> tuple:
    """Extract (question, answer) from a 'QUESTION: ... ANSWER: ...' reply.

    Returns empty strings for any field that cannot be located, so callers
    can treat falsy values as a failed parse.
    """
    question, answer = "", ""
    if "QUESTION:" in content and "ANSWER:" in content:
        parts = content.split("ANSWER:")
        answer = parts[1].strip() if len(parts) > 1 else ""
        q_part = parts[0].split("QUESTION:")
        question = q_part[1].strip() if len(q_part) > 1 else ""
    return question, answer


def generate_record(seed_example: dict, model: HFInferenceModel, provider: HFInferenceProvider, temperature: float) -> Optional[GeneratedRecord]:
    """Generate one new Q&A record modeled on a seed example.

    Builds a few-shot prompt from the seed, calls the generation model, and
    parses the QUESTION:/ANSWER: sections out of the completion. Returns
    None when the call fails or the reply cannot be parsed.

    # NOTE(review): assumes seeds use instruction/question, output/answer,
    # topic and difficulty keys — the .get fallbacks cover both naming styles.
    """
    try:
        client = get_client(model, provider)
        prompt = f"""You are an expert at creating educational Q&A pairs. Generate a new Q&A pair in the same style and quality as this example:
Example:
Q: {seed_example.get('instruction', seed_example.get('question', 'What is AI?'))}
A: {seed_example.get('output', seed_example.get('answer', 'AI is...'))}
Generate a new Q&A pair on the topic: {seed_example.get('topic', 'machine learning')}
Difficulty level: {seed_example.get('difficulty', 'intermediate')}
Format your response as:
QUESTION: [your question]
ANSWER: [your detailed answer]"""
        result = client.chat_completion(
            [{"role": "user", "content": prompt}],
            max_tokens=1024,
            temperature=temperature,
        )
        content = result.choices[0].message.content
        question, answer = _parse_qa(content)
        if question and answer:
            return GeneratedRecord(
                instruction=question,
                output=answer,
                topic=seed_example.get("topic", "unknown"),
                difficulty=seed_example.get("difficulty", "intermediate"),
                quality_score=7,  # default score; overwritten by the judge when enabled
                model=model.value,
                provider=provider.value
            )
    except Exception as e:
        print(f"Generation error: {e}")
    # Unparseable reply or generation failure — caller skips this record.
    return None
@app.get("/", response_model=HealthResponse)
async def root():
    """Report service health and the default model/provider configuration."""
    payload = {
        "status": "healthy",
        "model": HFInferenceModel.QWEN_7B.value,
        "provider": HFInferenceProvider.TOGETHER.value,
        "judge_model": HFJudgeModel.QWEN_7B.value,
        "judge_provider": HFInferenceProvider.TOGETHER.value,
        "seeds_loaded": bool(loaded_seeds),
    }
    return HealthResponse(**payload)
@app.get("/health", response_model=HealthResponse)
async def health():
    """Health probe — same payload as the root endpoint."""
    payload = {
        "status": "healthy",
        "model": HFInferenceModel.QWEN_7B.value,
        "provider": HFInferenceProvider.TOGETHER.value,
        "judge_model": HFJudgeModel.QWEN_7B.value,
        "judge_provider": HFInferenceProvider.TOGETHER.value,
        "seeds_loaded": bool(loaded_seeds),
    }
    return HealthResponse(**payload)
@app.post("/load-seeds", response_model=LoadSeedsResponse)
async def load_seeds(request: LoadSeedsRequest):
    """Load the seed dataset's train split from the HF Hub into memory.

    Replaces any previously loaded seeds. Errors are reported in the
    response body (success=False) rather than raised, so the endpoint
    always returns HTTP 200.
    """
    global loaded_seeds
    try:
        if not HF_TOKEN:
            return LoadSeedsResponse(success=False, error="HF_TOKEN not configured")
        # Fetch only the train split instead of materializing every split.
        data = load_dataset(request.seed_dataset, split="train", token=HF_TOKEN)
        loaded_seeds = list(data)
        return LoadSeedsResponse(
            success=True,
            loaded_count=len(loaded_seeds),
            dataset_id=request.seed_dataset
        )
    except Exception as e:
        return LoadSeedsResponse(success=False, error=str(e))
@app.post("/generate", response_model=GenerateResponse)
async def generate(request: GenerateRequest):
    """Generate synthetic Q&A records from randomly chosen seed examples.

    Lazily loads the seed dataset on first use. When request.use_judge is
    set, each generated record is re-scored by the LLM judge. Failed
    generations are skipped, so record_count may be below num_records.
    """
    global loaded_seeds
    try:
        if not HF_TOKEN:
            return GenerateResponse(success=False, error="HF_TOKEN not configured")
        # Lazy-load seeds on first call. NOTE(review): seeds are cached across
        # requests, so a different request.seed_dataset will NOT trigger a
        # reload — call /load-seeds explicitly to switch datasets.
        if not loaded_seeds:
            data = load_dataset(request.seed_dataset, split="train", token=HF_TOKEN)
            loaded_seeds = list(data)
        if not loaded_seeds:
            return GenerateResponse(success=False, error="No seed data available")
        generated = []
        for _ in range(request.num_records):
            seed = random.choice(loaded_seeds)
            record = generate_record(
                seed,
                request.model,
                request.provider,
                request.temperature
            )
            if record is None:
                continue  # generation/parsing failed — skip this slot
            if request.use_judge:
                record.quality_score = llm_judge_score(
                    record.instruction,
                    record.output,
                    request.judge_model,
                    request.judge_provider
                )
            generated.append(record)
        return GenerateResponse(
            success=True,
            data=generated,
            record_count=len(generated),
            seeds_used=len(loaded_seeds)
        )
    except Exception as e:
        return GenerateResponse(success=False, error=str(e))
@app.get("/models")
async def list_models():
    """List available generation and judge models."""
    def entry(model_id, name, provider, description):
        # Shared shape for every model entry in the catalog.
        return {
            "id": model_id,
            "name": name,
            "provider": provider,
            "description": description,
        }

    return {
        "generator_models": [
            entry(
                "Qwen/Qwen2.5-72B-Instruct",
                "Qwen2.5-72B",
                "cerebras",
                "High-quality 72B model via Cerebras",
            ),
            entry(
                "openai/gpt-oss-120b",
                "GPT-OSS-120B",
                "cerebras",
                "Large powerful model via Cerebras",
            ),
        ],
        "judge_models": [
            entry(
                "Qwen/Qwen2.5-7B-Instruct",
                "Qwen2.5-7B",
                "together",
                "Fast judge model via Together (UltraFeedback-style)",
            ),
            entry(
                "z-ai/glm-5",
                "GLM-5",
                "together",
                "High-quality judge via Together",
            ),
        ],
    }