from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import List, Optional
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from peft import LoraConfig, get_peft_model, TaskType
from torch.utils.data import Dataset
from datasets import load_dataset
import gradio as gr
import logging
from datetime import datetime
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI app
app = FastAPI(title="Jain Framework AI API", version="1.0.0")
# Model and tokenizer setup
MODEL_NAME = "gpt2-medium"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
try:
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenizer.pad_token = tokenizer.eos_token  # GPT-2 ships without a pad token; reuse EOS
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
except Exception as e:
    logger.error(f"Failed to load model or tokenizer: {str(e)}")
    raise
# Apply LoRA for parameter-efficient fine-tuning
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["c_attn", "c_proj"]
)
model = get_peft_model(model, lora_config)
model.to(DEVICE)
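# Optional sanity check: PEFT's print_trainable_parameters() reports how few
# weights the LoRA adapters add relative to the frozen gpt2-medium base.
model.print_trainable_parameters()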
# Custom Dataset for philosophical dialogue
class PhilosophicalDialogueDataset(Dataset):
    def __init__(self, dataset, tokenizer, max_length=512):
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.texts = [example["text"] for example in dataset]

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt"
        )
        input_ids = encoding["input_ids"].squeeze()
        attention_mask = encoding["attention_mask"].squeeze()
        # Causal LM labels mirror the inputs, but padding positions are set to
        # -100 so they are ignored by the loss (pad token == EOS here).
        labels = input_ids.clone()
        labels[attention_mask == 0] = -100
        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": labels
        }
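# Illustrative check (kept commented out so it never runs at import time): each
# item should be a dict of three equal-shape 1-D tensors, which is what
# Trainer's default collator expects. The example text is arbitrary.
# _demo = PhilosophicalDialogueDataset([{"text": "Who am I?"}], tokenizer)
# assert _demo[0]["input_ids"].shape == (512,)
# assert (_demo[0]["labels"] == -100).any()  # padding is masked out of the loss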
# Load and prepare dataset
def load_philosophical_dataset():
    try:
        # Placeholder: Replace with actual dataset from Hugging Face or local source
        # Example: dataset = load_dataset("path/to/your/philosophical_dialogue_dataset")
        dataset = [{"text": "What is the nature of existence?"}] * 100  # Mock for demonstration
        return PhilosophicalDialogueDataset(dataset, tokenizer)
    except Exception as e:
        logger.error(f"Failed to load dataset: {str(e)}")
        raise
# Training setup
def setup_trainer(dataset):
    training_args = TrainingArguments(
        output_dir="./model_checkpoints",
        num_train_epochs=3,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=4,
        learning_rate=5e-5,
        warmup_steps=500,
        logging_steps=10,
        save_steps=500,
        save_total_limit=2,
        fp16=torch.cuda.is_available(),
        report_to="none"
    )
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=dataset
    )
    return trainer
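# Fine-tuning can also be kicked off outside the API, e.g. from a Python shell
# on the host (a sketch, assuming the mock dataset above):
#   trainer = setup_trainer(load_philosophical_dataset())
#   trainer.train()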
# Pydantic models for API
class GenerationRequest(BaseModel):
    prompt: str
    max_length: Optional[int] = 100
    temperature: Optional[float] = 0.8
    top_k: Optional[int] = 50
    top_p: Optional[float] = 0.9

class GenerationResponse(BaseModel):
    generated_text: str
    prompt: str
    generation_time: float

class BatchGenerationRequest(BaseModel):
    prompts: List[str]
    max_length: Optional[int] = 100
    temperature: Optional[float] = 0.8
    top_k: Optional[int] = 50
    top_p: Optional[float] = 0.9
# Optimized Generator
class OptimizedGenerator:
    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer
        self.device = next(model.parameters()).device
        self.model.eval()  # inference mode: disables dropout

    @torch.no_grad()
    def generate(self, prompt, max_length=100, temperature=0.8, top_k=50, top_p=0.9):
        inputs = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
        outputs = self.model.generate(
            input_ids=inputs,
            max_length=max_length,  # total length, prompt tokens included
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            no_repeat_ngram_size=2
        )
        return self.tokenizer.decode(outputs[0], skip_special_tokens=True)

    @torch.no_grad()
    def generate_batch(self, prompts, max_length=100, temperature=0.8, top_k=50, top_p=0.9):
        # Decoder-only models need left padding for batched generation; right
        # padding would leave pad tokens between each prompt and its continuation.
        original_side = self.tokenizer.padding_side
        self.tokenizer.padding_side = "left"
        encoded = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True
        ).to(self.device)
        self.tokenizer.padding_side = original_side  # restore for training paths
        outputs = self.model.generate(
            input_ids=encoded["input_ids"],
            attention_mask=encoded["attention_mask"],
            max_length=max_length,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            do_sample=True,
            pad_token_id=self.tokenizer.eos_token_id,
            no_repeat_ngram_size=2
        )
        return [self.tokenizer.decode(output, skip_special_tokens=True) for output in outputs]
# Initialize generator
generator = OptimizedGenerator(model, tokenizer)
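# Quick smoke test (commented out to keep startup fast; prompts are arbitrary):
# print(generator.generate("What is non-violence?", max_length=60))
# print(generator.generate_batch(["What is karma?", "What is the soul?"]))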
# API Routes
@app.post("/generate", response_model=GenerationResponse)
async def generate_text(request: GenerationRequest):
try:
start_time = datetime.now()
generated_text = generator.generate(
request.prompt,
max_length=request.max_length,
temperature=request.temperature,
top_k=request.top_k,
top_p=request.top_p
)
generation_time = (datetime.now() - start_time).total_seconds()
return GenerationResponse(
generated_text=generated_text,
prompt=request.prompt,
generation_time=generation_time
)
except Exception as e:
logger.error(f"Generation failed: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/generate_batch")
async def generate_batch(request: BatchGenerationRequest):
try:
generated_texts = generator.generate_batch(
request.prompts,
max_length=request.max_length,
temperature=request.temperature,
top_k=request.top_k,
top_p=request.top_p
)
return {
"results": [
{"prompt": prompt, "generated": generated}
for prompt, generated in zip(request.prompts, generated_texts)
]
}
except Exception as e:
logger.error(f"Batch generation failed: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
@app.get("/health")
async def health_check():
return {"status": "healthy", "model_loaded": True}
# Gradio Interface
def gradio_interface(prompt, max_length=100, temperature=0.8):
    try:
        # Gradio sliders emit floats; generate() expects an integer token budget
        generated = generator.generate(prompt, int(max_length), temperature)
        return generated
    except Exception as e:
        logger.error(f"Gradio generation failed: {str(e)}")
        return f"Error: {str(e)}"
# Training endpoint
@app.post("/train")
async def train_model():
try:
dataset = load_philosophical_dataset()
trainer = setup_trainer(dataset)
trainer.train()
trainer.save_model("./final_model")
tokenizer.save_pretrained("./final_model")
return {"status": "Training completed", "model_path": "./final_model"}
except Exception as e:
logger.error(f"Training failed: {str(e)}")
raise HTTPException(status_code=500, detail=str(e))
# Gradio app
gr_iface = gr.Interface(
    fn=gradio_interface,
    inputs=[
        gr.Textbox(lines=2, placeholder="Enter your philosophical question here..."),
        gr.Slider(minimum=50, maximum=500, value=100, step=1, label="Max Length"),
        gr.Slider(minimum=0.1, maximum=2.0, value=0.8, label="Temperature")
    ],
    outputs="text",
    title="Jain Framework: Philosophical AI Dialogue",
    description="Interact with an AI grounded in the Jain Framework, blending Eastern philosophy with advanced NLP."
)
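# Alternative wiring (an assumption, not used here): recent Gradio releases can
# mount the UI onto the FastAPI app directly, avoiding the extra thread below:
#   app = gr.mount_gradio_app(app, gr_iface, path="/ui")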
# Run both FastAPI and Gradio
if __name__ == "__main__":
    import uvicorn
    import threading

    # Run Gradio in a separate daemon thread so it exits with the main process
    def run_gradio():
        gr_iface.launch(share=False, server_name="0.0.0.0", server_port=7860)

    gradio_thread = threading.Thread(target=run_gradio, daemon=True)
    gradio_thread.start()

    # Run FastAPI
    uvicorn.run(app, host="0.0.0.0", port=8000)