BoostedJonP's picture
cpu quantization config
9939aee
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
from functools import lru_cache
import logging
# Hub repository id of the fine-tuned checkpoint this app serves.
MODEL_NAME = "BoostedJonP/powell-phi3-mini"

# Configure root logging at INFO, then announce start-up through the
# module-level logger that the rest of the file uses.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
logger.info("Starting Jerome Powell AI Assistant...")
@lru_cache(maxsize=1)
def load_model():
    """Load the fine-tuned Jerome Powell model and its tokenizer.

    The result is memoized with ``lru_cache``, so repeated calls return the
    same pair without loading again (this includes a cached failure).

    Returns:
        tuple: ``(model, tokenizer)`` on success, or ``(None, None)`` if any
        step of loading raised. The failure is logged with its traceback.
    """
    logger.info(f"Loading model: {MODEL_NAME}")

    # Arguments shared by every ``from_pretrained`` call below.
    # NOTE(review): trust_remote_code=True allows code shipped in the model
    # repository to run locally -- confirm the repository is trusted.
    shared_kwargs = {
        "trust_remote_code": True,
        "cache_dir": "/tmp/model_cache",
    }
    use_cuda = torch.cuda.is_available()

    try:
        config = AutoConfig.from_pretrained(MODEL_NAME, **shared_kwargs)
        tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, **shared_kwargs)

        if use_cuda:
            logger.info("CUDA available, loading with GPU optimizations")
            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                config=config,
                torch_dtype=torch.float16,
                device_map="auto",
                attn_implementation="eager",
                **shared_kwargs,
            )
        else:
            logger.info("CUDA not available, loading with CPU optimizations")
            cpu_config = config
            if hasattr(cpu_config, "quantization_config"):
                # Rebuild the config without its quantization settings so the
                # checkpoint is loaded unquantized on CPU.
                logger.info("Removing quantization config for CPU compatibility")
                config_dict = cpu_config.to_dict()
                config_dict.pop("quantization_config", None)
                cpu_config = cpu_config.__class__(**config_dict)

            model = AutoModelForCausalLM.from_pretrained(
                MODEL_NAME,
                config=cpu_config,
                torch_dtype=torch.float16,
                attn_implementation="eager",
                low_cpu_mem_usage=True,
                **shared_kwargs,
            )
            # Place the model on CPU and upcast to float32.
            # NOTE(review): the weights are requested as float16 and then
            # converted; loading directly as float32 may avoid the extra
            # cast -- confirm before changing.
            model = model.to(device="cpu", dtype=torch.float32)

        logger.info("Model loaded successfully!")
        first_param = next(model.parameters())
        logger.info(f"Model device: {first_param.device}")
        logger.info(f"Model dtype: {first_param.dtype}")
    except Exception as e:
        # logger.exception records the traceback as well as the message, so a
        # failed start-up can be diagnosed from the logs.
        logger.exception(f"Error loading model: {e}")
        logger.error(f"Error type: {type(e).__name__}")
        return None, None

    model.generation_config.use_cache = True
    # Reuse the end-of-sequence token as the padding token for generation.
    model.generation_config.pad_token_id = tokenizer.eos_token_id

    # Only compile on GPU, skip compilation on CPU to avoid compatibility issues
    if use_cuda:
        try:
            model = torch.compile(model, mode="reduce-overhead")
            logger.info("Model compiled successfully")
        except Exception as e:
            # Compilation is best-effort; fall back to the uncompiled model.
            logger.warning(f"Model compilation failed: {e}, continuing without compilation")

    return model, tokenizer
# Load the model and tokenizer once at import time; both names are None when
# load_model() failed.
model, tokenizer = load_model()
def generate_powell_response(
    question,
    max_length=256,
    num_beams=1,
    temperature=0.3,
    max_prompt_tokens=1024,
):
    """Generate a response in Jerome Powell's style.

    Args:
        question: The user's question. ``None``, empty, or whitespace-only
            input returns a short hint instead of calling the model.
        max_length: Maximum number of new tokens to generate. Coerced to
            ``int`` because UI sliders may deliver floats.
        num_beams: Number of beams for generation. Coerced to ``int``.
        temperature: Sampling temperature. Coerced to ``float``.
        max_prompt_tokens: Token budget for the encoded prompt; longer
            prompts are truncated by the tokenizer. Kept separate from
            ``max_length`` so a short response budget does not cut the
            question off.

    Returns:
        str: The generated answer, or a user-facing message when the model is
        unavailable, the question is blank, nothing was generated, or
        generation raised.
    """
    if model is None or tokenizer is None:
        return "❌ Model failed to load. Please refresh the page and try again."

    question_text = (question or "").strip()
    if not question_text:
        return (
            "Please ask a question about monetary policy, economics, or Federal Reserve operations."
        )

    max_new_tokens = int(max_length)
    num_beams = int(num_beams)
    temperature = float(temperature)

    # Log device information for debugging
    device = next(model.parameters()).device
    logger.info(f"Generating response on device: {device}")

    system_prompt = """You are Jerome Powell, the Chairman of the Federal Reserve."""
    prompt = f"System: {system_prompt}\n\nQuestion: {question_text}\nAnswer:"

    try:
        inputs = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=int(max_prompt_tokens),
            padding=False,
        )
        # Move inputs to the same device as the model's first parameter.
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}
        prompt_token_count = inputs["input_ids"].shape[-1]

        generation_config = {
            "max_new_tokens": max_new_tokens,
            "num_beams": num_beams,
            "do_sample": True,
            "temperature": temperature,
            "repetition_penalty": 1.1,
            "use_cache": True,
            "output_scores": False,
            "return_dict_in_generate": False,
        }
        if num_beams > 1:
            # early_stopping only applies to beam search.
            generation_config["early_stopping"] = True

        with torch.no_grad():
            outputs = model.generate(**inputs, **generation_config)

        # The returned sequence starts with the prompt tokens; decode only
        # what was generated after them rather than splitting the decoded
        # text on "Answer:", which loses content if that marker reappears.
        generated_tokens = outputs[0][prompt_token_count:]
        answer = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

        return (
            answer
            if answer
            else "I apologize, but I couldn't generate a proper response. Please try rephrasing your question."
        )
    except Exception as e:
        # Keep the traceback in the logs; the user only sees the message.
        logger.exception(f"Error generating response: {e}")
        return f"❌ Error generating response: {str(e)}"
# Stylesheet passed to gr.Blocks(css=...) in create_interface(). It defines
# the page container plus the header, subtitle, example, footer, button and
# response classes.
custom_css = """
.gradio-container {
font-family: 'Segoe UI', Tahoma, Geneva, Verdana, sans-serif;
max-width: 1200px;
margin: 0 auto;
}
.header-text {
text-align: center;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
background-clip: text;
font-size: 2.5rem;
font-weight: bold;
margin-bottom: 0.5rem;
}
.subtitle-text {
text-align: center;
color: #666;
font-size: 1.2rem;
margin-bottom: 2rem;
}
.example-box {
background: #f8f9fa;
padding: 1rem;
border-radius: 8px;
border-left: 4px solid #667eea;
margin: 1rem 0;
}
.footer-text {
text-align: center;
color: #666;
font-size: 0.9rem;
margin-top: 2rem;
padding: 1rem;
border-top: 1px solid #eee;
}
/* Make buttons more prominent */
.primary-button {
background: linear-gradient(45deg, #667eea, #764ba2) !important;
border: none !important;
color: white !important;
}
/* Response styling */
.response-box {
background: #f8f9fa;
padding: 1.5rem;
border-radius: 12px;
border-left: 4px solid #28a745;
font-size: 1.1rem;
line-height: 1.6;
}
"""
def create_interface():
    """Build the Gradio Blocks UI for the Jerome Powell assistant.

    The page has a header banner, a question box with submit/clear buttons,
    an "Advanced Settings" accordion with three generation sliders, a
    response box, and a footer with model information. Both the submit button
    and pressing Enter in the question box call ``generate_powell_response``.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(
        css=custom_css,
        title="Jerome Powell AI Assistant | Federal Reserve Q&A",
        theme=gr.themes.Soft(),
    ) as demo:
        # Header banner: title and subtitle next to a portrait image.
        gr.HTML(
            """
<div style="display: flex; align-items: center; justify-content: center; gap: 2rem; margin-bottom: 2rem;">
<div style="flex: 1; text-align: center;">
<div class="header-text">🏦 Jerome Powell AI Assistant</div>
<div class="subtitle-text">
Fine-tuned Phi3-Mini model trained on Federal Reserve Chairman Jerome Powell's Q&A sessions
</div>
</div>
<div style="flex-shrink: 0;">
<img src="https://storage.googleapis.com/kaggle-datasets-images/8130068/12853913/f5d1487cf839f69edc1fcde3d30a583a/dataset-cover.jpeg?t=2025-08-24-13-28-39"
alt="Jerome Powell"
style="width: 200px; height: 200px; border-radius: 50%; object-fit: cover; border: 4px solid #667eea; box-shadow: 0 4px 8px rgba(0,0,0,0.1);">
</div>
</div>
"""
        )

        with gr.Row():
            with gr.Column(scale=2):
                question_input = gr.Textbox(
                    label="💬 Ask about monetary policy, economics, or Federal Reserve operations",
                    placeholder="e.g., What factors influence Federal Reserve interest rate decisions?",
                    lines=3,
                    max_lines=5,
                )

                with gr.Row():
                    submit_btn = gr.Button("🎯 Ask Jerome Powell AI", variant="primary", scale=2)
                    clear_btn = gr.Button("🔄 Clear", scale=1)

                with gr.Accordion("⚙️ Advanced Settings", open=False):
                    with gr.Row():
                        max_length = gr.Slider(
                            minimum=64,
                            maximum=512,
                            value=256,
                            step=32,
                            label="Max Response Length",
                            info="Longer responses may be more detailed but take more time",
                        )
                        num_beams = gr.Slider(
                            minimum=1,
                            maximum=8,
                            value=1,
                            step=1,
                            label="Number of Beams",
                            info="Higher values = better quality but slower generation (beam search)",
                        )
                        temperature = gr.Slider(
                            minimum=0.1,
                            maximum=2.0,
                            value=0.3,
                            step=0.1,
                            label="Temperature",
                            info="Higher values = more creative/random responses (affects beam search diversity)",
                        )

                response_output = gr.Textbox(
                    label="💼 Jerome Powell AI Response",
                    lines=8,
                    max_lines=15,
                    show_copy_button=True,
                    container=True,
                )

        # Footer with model, dataset and repository links.
        # NOTE(review): the Author link points at github.com/BigJonPP while
        # the Repo link uses github.com/BigJonP -- confirm which account is
        # intended.
        gr.HTML(
            """
<div class="footer-text">
<h3>📊 Model Information</h3>
<p>
<strong>Base Model:</strong> Microsoft Phi3-Mini<br>
<strong>Fine-tuning:</strong> Specialized on Jerome Powell Q&A data<br>
<strong>Model Hub:</strong> <a href="https://huggingface.co/BoostedJonP/powell-phi3-mini" target="_blank">BoostedJonP/powell-phi3-mini</a><br>
<strong>Dataset:</strong> <a href="https://huggingface.co/datasets/BoostedJonP/JeromePowell-SFT" target="_blank">BoostedJonP/JeromePowell-SFT</a><br>
<strong>Repo:</strong> <a href="https://github.com/BigJonP/powell-phi3-sft" target="_blank">BigJonP/powell-phi3-sft</a><br>
<strong>Author:</strong> <a href="https://github.com/BigJonPP" target="_blank">Jonathan Paserman</a>
</p>
<p><em>⚠️ Disclaimer: This AI model provides educational insights based on training data and should not be considered as official Federal Reserve communication or financial advice. Always consult official Fed sources for authoritative information.</em></p>
</div>
"""
        )

        # The button click and the textbox Enter key run the same handler
        # with the same inputs and outputs, so the wiring is defined once.
        generate_event = {
            "fn": generate_powell_response,
            "inputs": [question_input, max_length, num_beams, temperature],
            "outputs": response_output,
            "show_progress": True,
        }
        submit_btn.click(**generate_event)
        question_input.submit(**generate_event)

        # Clearing resets both the question box and the response box.
        clear_btn.click(lambda: ("", ""), outputs=[question_input, response_output])

    return demo
# Build the UI at import time so `demo` is available as a module-level name.
demo = create_interface()

if __name__ == "__main__":
    # When run as a script, serve on all network interfaces at port 7860
    # with show_error enabled.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )