Spaces:
Sleeping
Sleeping
import os
import random
import re
from typing import List, Optional

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from huggingface_hub import InferenceClient
from datasets import load_dataset

from models import (
    GenerateRequest, GenerateResponse,
    LoadSeedsRequest, LoadSeedsResponse,
    HealthResponse, HFInferenceModel, HFInferenceProvider, HFJudgeModel,
    GeneratedRecord
)
# Hugging Face access token; an empty string makes auth-requiring endpoints
# return an explicit "HF_TOKEN not configured" error instead of failing later.
HF_TOKEN = os.environ.get("HF_TOKEN", "")

# Global seed storage
# In-process cache of seed examples loaded from the HF Hub dataset;
# populated by load_seeds() and lazily by generate().
loaded_seeds: List[dict] = []

app = FastAPI(
    title="Distilabel Generator API",
    description="Phase 1: Generate synthetic data with GPT-OSS-120B via Cerebras",
    version="1.0.0"
)

# NOTE(review): fully permissive CORS — acceptable for a demo Space,
# tighten allow_origins before any production use.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
def get_client(model: HFInferenceModel, provider: HFInferenceProvider):
    """Build an InferenceClient bound to the chosen generator model/provider."""
    model_id = model.value
    provider_id = provider.value
    return InferenceClient(model=model_id, token=HF_TOKEN, provider=provider_id)
def get_judge_client(model: HFJudgeModel, provider: HFInferenceProvider):
    """Build an InferenceClient bound to the chosen judge model/provider."""
    return InferenceClient(
        model=model.value,
        token=HF_TOKEN,
        provider=provider.value,
    )
# LLM Judge scoring prompt (UltraFeedback-style)
# Filled via .format() with `instruction` and `output` (both truncated by the
# caller); the judge is asked for a bare integer so the reply is easy to parse.
JUDGE_PROMPT = """Rate this Q&A pair quality (1-10 scale):
Question: {instruction}
Answer: {output}
Criteria:
- Instruction-following: Does the answer address the question?
- Truthfulness: Is the information accurate?
- Honesty: Does it avoid hallucination?
- Helpfulness: Is it useful for learning?
Output ONLY a single number 1-10:"""
def llm_judge_score(instruction: str, output: str, judge_model: HFJudgeModel, judge_provider: HFInferenceProvider) -> int:
    """Use LLM Judge to score quality (1-10 scale, UltraFeedback-style).

    Sends a truncated (instruction, output) pair to the judge model and parses
    an integer score from its reply.

    Returns:
        The judge's score clamped to 1-10, or 5 (neutral midpoint) when the
        reply contains no number or the API call fails.
    """
    try:
        client = get_judge_client(judge_model, judge_provider)
        # Truncate both fields to keep the judge prompt short and cheap.
        prompt = JUDGE_PROMPT.format(
            instruction=instruction[:500],
            output=output[:500]
        )
        result = client.chat_completion(
            [{"role": "user", "content": prompt}],
            max_tokens=10,
            temperature=0.3,
        )
        content = result.choices[0].message.content
        # Extract the first full number from the response. The previous
        # char-by-char scan returned the first *digit*, so a score of "10"
        # was silently recorded as 1.
        match = re.search(r"\d+", content or "")
        if match:
            return min(max(int(match.group()), 1), 10)  # clamp to 1-10
        return 5
    except Exception as e:
        # Best-effort scoring: never let a judge failure abort generation.
        print(f"Judge Error: {e}")
        return 5
def generate_record(seed_example: dict, model: HFInferenceModel, provider: HFInferenceProvider, temperature: float) -> Optional[GeneratedRecord]:
    """Generate one new Q&A record modeled on a seed example.

    Returns:
        A GeneratedRecord with a default quality_score of 7, or None when the
        API call fails or the reply lacks parseable QUESTION:/ANSWER: sections.
    """
    try:
        client = get_client(model, provider)
        # Seed fields are schema-flexible: fall back across common key names.
        seed_q = seed_example.get('instruction', seed_example.get('question', 'What is AI?'))
        seed_a = seed_example.get('output', seed_example.get('answer', 'AI is...'))
        seed_topic = seed_example.get('topic', 'machine learning')
        seed_difficulty = seed_example.get('difficulty', 'intermediate')
        prompt = f"""You are an expert at creating educational Q&A pairs. Generate a new Q&A pair in the same style and quality as this example:
Example:
Q: {seed_q}
A: {seed_a}
Generate a new Q&A pair on the topic: {seed_topic}
Difficulty level: {seed_difficulty}
Format your response as:
QUESTION: [your question]
ANSWER: [your detailed answer]"""
        reply = client.chat_completion(
            [{"role": "user", "content": prompt}],
            max_tokens=1024,
            temperature=temperature,
        )
        content = reply.choices[0].message.content
        # Pull the QUESTION/ANSWER sections out of the model's reply.
        question = ""
        answer = ""
        if "QUESTION:" in content and "ANSWER:" in content:
            pieces = content.split("ANSWER:")
            if len(pieces) > 1:
                answer = pieces[1].strip()
            q_pieces = pieces[0].split("QUESTION:")
            if len(q_pieces) > 1:
                question = q_pieces[1].strip()
        if question and answer:
            return GeneratedRecord(
                instruction=question,
                output=answer,
                topic=seed_example.get("topic", "unknown"),
                difficulty=seed_example.get("difficulty", "intermediate"),
                quality_score=7,  # default; overwritten by the judge when enabled
                model=model.value,
                provider=provider.value
            )
    except Exception as e:
        print(f"Generation error: {e}")
    # Unparseable reply or API failure: signal "no record" to the caller.
    return None
async def root():
    """Health snapshot served at the root path.

    NOTE(review): the route decorator (e.g. @app.get("/")) is not visible
    here — it appears to have been lost in extraction; confirm upstream.
    """
    return HealthResponse(
        status="healthy",
        model=HFInferenceModel.QWEN_7B.value,
        provider=HFInferenceProvider.TOGETHER.value,
        judge_model=HFJudgeModel.QWEN_7B.value,
        judge_provider=HFInferenceProvider.TOGETHER.value,
        seeds_loaded=bool(loaded_seeds)
    )
async def health():
    """Report service health plus the default model/provider configuration."""
    snapshot = {
        "status": "healthy",
        "model": HFInferenceModel.QWEN_7B.value,
        "provider": HFInferenceProvider.TOGETHER.value,
        "judge_model": HFJudgeModel.QWEN_7B.value,
        "judge_provider": HFInferenceProvider.TOGETHER.value,
        "seeds_loaded": bool(loaded_seeds),
    }
    return HealthResponse(**snapshot)
async def load_seeds(request: LoadSeedsRequest):
    """Load the seed dataset's train split from the HF Hub into module cache."""
    global loaded_seeds
    # Fail fast with a clear message when no token is configured.
    if not HF_TOKEN:
        return LoadSeedsResponse(success=False, error="HF_TOKEN not configured")
    try:
        dataset = load_dataset(request.seed_dataset, token=HF_TOKEN)
        loaded_seeds = list(dataset["train"])
    except Exception as e:
        return LoadSeedsResponse(success=False, error=str(e))
    return LoadSeedsResponse(
        success=True,
        loaded_count=len(loaded_seeds),
        dataset_id=request.seed_dataset
    )
async def generate(request: GenerateRequest):
    """Generate synthetic data based on seed examples with LLM Judge scoring."""
    global loaded_seeds
    try:
        if not HF_TOKEN:
            return GenerateResponse(success=False, error="HF_TOKEN not configured")
        # Lazy-load seeds on first use.
        # NOTE(review): the cache is keyed only on emptiness — a later request
        # naming a *different* seed_dataset will silently reuse whatever was
        # loaded first. Confirm whether per-dataset reload is intended.
        if not loaded_seeds:
            dataset = load_dataset(request.seed_dataset, token=HF_TOKEN)
            loaded_seeds = list(dataset["train"])
        if not loaded_seeds:
            return GenerateResponse(success=False, error="No seed data available")
        records = []
        for _ in range(request.num_records):
            candidate = generate_record(
                random.choice(loaded_seeds),
                request.model,
                request.provider,
                request.temperature
            )
            if not candidate:
                continue  # generation failed or reply was unparseable
            # Replace the default score with the judge's verdict when enabled.
            if request.use_judge:
                candidate.quality_score = llm_judge_score(
                    candidate.instruction,
                    candidate.output,
                    request.judge_model,
                    request.judge_provider
                )
            records.append(candidate)
        return GenerateResponse(
            success=True,
            data=records,
            record_count=len(records),
            seeds_used=len(loaded_seeds)
        )
    except Exception as e:
        return GenerateResponse(success=False, error=str(e))
async def list_models():
    """List available generation and judge models."""
    generator_models = [
        {
            "id": "Qwen/Qwen2.5-72B-Instruct",
            "name": "Qwen2.5-72B",
            "provider": "cerebras",
            "description": "High-quality 72B model via Cerebras",
        },
        {
            "id": "openai/gpt-oss-120b",
            "name": "GPT-OSS-120B",
            "provider": "cerebras",
            "description": "Large powerful model via Cerebras",
        },
    ]
    judge_models = [
        {
            "id": "Qwen/Qwen2.5-7B-Instruct",
            "name": "Qwen2.5-7B",
            "provider": "together",
            "description": "Fast judge model via Together (UltraFeedback-style)",
        },
        {
            "id": "z-ai/glm-5",
            "name": "GLM-5",
            "provider": "together",
            "description": "High-quality judge via Together",
        },
    ]
    return {"generator_models": generator_models, "judge_models": judge_models}