from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model, TaskType
from torch.utils.data import Dataset
import gradio as gr
import logging
from datetime import datetime

# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI application
app = FastAPI(title="Jain Framework AI API", version="1.0.0")

# Load the base model and tokenizer
MODEL_NAME = "gpt2-medium"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no dedicated pad token
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
except Exception as e:
    logger.error(f"Failed to load model or tokenizer: {str(e)}")
    raise

# Attach LoRA adapters to the GPT-2 attention and projection layers
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["c_attn", "c_proj"]
)
model = get_peft_model(model, lora_config)
model.to(DEVICE)

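# Optional sanity check: PeftModel exposes print_trainable_parameters(), which
# logs how many parameters the LoRA adapters actually train versus the size of
# the frozen base model.
model.print_trainable_parameters()
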
# Dataset wrapping raw text examples for causal-LM fine-tuning
class PhilosophicalDialogueDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.texts = [example["text"] for example in dataset]

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt"
        )
        input_ids = encoding["input_ids"].squeeze(0)
        attention_mask = encoding["attention_mask"].squeeze(0)
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100  # exclude padding positions from the loss
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }

# Placeholder corpus; swap in a real philosophical-dialogue dataset for actual training
def load_philosophical_dataset():
    try:
        dataset = [{"text": "What is the nature of existence?"}] * 100
        return PhilosophicalDialogueDataset(dataset, tokenizer)
    except Exception as e:
        logger.error(f"Failed to load dataset: {str(e)}")
        raise

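# A minimal sketch of loading a real corpus instead of the placeholder above,
# assuming a local plain-text file with one dialogue per line (the file name is
# illustrative, not part of this project):
#
#   from datasets import load_dataset
#   raw = load_dataset("text", data_files="philosophical_dialogues.txt")["train"]
#   dataset = PhilosophicalDialogueDataset([{"text": t} for t in raw["text"]], tokenizer)
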
def setup_trainer(dataset):
    training_args = TrainingArguments(
        output_dir="./model_checkpoints",
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,  # effective batch size: 4 x 4 = 16 sequences per optimizer step (per device)
        learning_rate=5e-5,
        warmup_steps=500,
        logging_steps=10,
        save_steps=500,
        save_total_limit=2,
        fp16=torch.cuda.is_available(),
        report_to="none"
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset
    )
    return trainer

# Request/response schemas
class GenerationRequest(BaseModel):
    prompt: str
    max_length: Optional[int] = 100
    temperature: Optional[float] = 0.8
    top_k: Optional[int] = 50
    top_p: Optional[float] = 0.9


class GenerationResponse(BaseModel):
    generated_text: str
    prompt: str
    generation_time: float


class BatchGenerationRequest(BaseModel):
    prompts: List[str]
    max_length: Optional[int] = 100
    temperature: Optional[float] = 0.8
    top_k: Optional[int] = 50
    top_p: Optional[float] = 0.9

# Inference wrapper around the (LoRA-adapted) model
class OptimizedGenerator:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device
        self.model.eval()

    @torch.no_grad()
    def generate(self, prompt, max_length=100, temperature=0.8, top_k=50, top_p=0.9):
        inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        outputs = self.model.generate(
            inputs,
            # max_new_tokens counts only generated tokens, so long prompts
            # cannot exhaust the budget before generation starts
            max_new_tokens=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            no_repeat_ngram_size=2
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    @torch.no_grad()
    def generate_batch(self, prompts, max_length=100, temperature=0.8, top_k=50, top_p=0.9):
        # Left-pad the batch so every sequence ends at its prompt's last token
        original_padding_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"
        encoded = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.device)
        self.tokenizer.padding_side = original_padding_side

        outputs = self.model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_new_tokens=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            no_repeat_ngram_size=2
        )

        return [self.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]

generator = OptimizedGenerator(model, tokenizer)

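# Illustrative direct usage of the generator (the same calls the endpoints below make;
# prompts are examples only):
#
#   generator.generate("What is the nature of the soul?", max_length=80)
#   generator.generate_batch(["What is ahimsa?", "What is karma?"], max_length=60)
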
# REST endpoints
@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
    try:
        start_time = datetime.now()
        generated_text = generator.generate(
            request.prompt,
            max_length=request.max_length,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p
        )
        generation_time = (datetime.now() - start_time).total_seconds()
        return GenerationResponse(
            generated_text=generated_text,
            prompt=request.prompt,
            generation_time=generation_time
        )
    except Exception as e:
        logger.error(f"Generation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.post("/generate_batch")
async def generate_batch(request: BatchGenerationRequest):
    try:
        generated_texts = generator.generate_batch(
            request.prompts,
            max_length=request.max_length,
            temperature=request.temperature,
            top_k=request.top_k,
            top_p=request.top_p
        )
        return {
            "results": [
                {"prompt": prompt, "generated": generated}
                for prompt, generated in zip(request.prompts, generated_texts)
            ]
        }
    except Exception as e:
        logger.error(f"Batch generation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

@app.get("/health")
async def health_check():
    return {"status": "healthy", "model_loaded": True}

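# Example request against the /generate endpoint once the server is running
# (illustrative prompt; the API listens on port 8000, as configured below):
#
#   curl -X POST http://localhost:8000/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "What is the nature of the soul?", "max_length": 120}'
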
# Gradio callback
def gradio_interface(prompt, max_length=100, temperature=0.8):
    try:
        # Gradio sliders return floats, so cast the token budget to int
        generated = generator.generate(prompt, int(max_length), temperature)
        return generated
    except Exception as e:
        logger.error(f"Gradio generation failed: {str(e)}")
        return f"Error: {str(e)}"

# Fine-tuning endpoint; note that the request blocks until the full training run finishes
@app.post("/train")
async def train_model():
    try:
        dataset = load_philosophical_dataset()
        trainer = setup_trainer(dataset)
        trainer.train()
        trainer.save_model("./final_model")
        tokenizer.save_pretrained("./final_model")
        return {"status": "Training completed", "model_path": "./final_model"}
    except Exception as e:
        logger.error(f"Training failed: {str(e)}")
        raise HTTPException(status_code=500, detail=str(e))

# Gradio UI
gr_iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your philosophical question here..."),
        gr.Slider(minimum=50, maximum=500, value=100, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, label="Temperature")
    ],
    outputs="text",
    title="Jain Framework: Philosophical AI Dialogue",
    description="Interact with an AI grounded in the Jain Framework, blending Eastern philosophy with advanced NLP."
)

# Entry point: serve the Gradio UI in a background thread and the FastAPI app via uvicorn
if __name__ == "__main__":
    import uvicorn
    import threading

    def run_gradio():
        gr_iface.launch(share=False, server_name="0.0.0.0", server_port=7860)

    gradio_thread = threading.Thread(target=run_gradio)
    gradio_thread.daemon = True
    gradio_thread.start()

    uvicorn.run(app, host="0.0.0.0", port=8000)
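
# Running this file directly (file name assumed; adjust to the actual module name)
# starts the REST API on http://localhost:8000 and the Gradio UI on
# http://localhost:7860:
#
#   python app.py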