from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from typing import Dict, Optional

# Global model variables
_model = None
_tokenizer = None
_model_name = "Qwen/Qwen3-0.6B"
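
# Loading strategy: the tokenizer is fetched eagerly when LLMService is
# instantiated (a small download); the 0.6B model weights are only loaded on
# the first generate_text() call.
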
def initialize_tokenizer():
"""Initialize tokenizer globally"""
global _tokenizer
if _tokenizer is None:
print("[LLMService] Loading tokenizer from Hugging Face Hub...")
_tokenizer = AutoTokenizer.from_pretrained(_model_name)
if _tokenizer.pad_token is None:
_tokenizer.pad_token = _tokenizer.eos_token
print("[LLMService] Tokenizer loaded successfully.")
return _tokenizer


def generate_text(prompt_text: str, max_tokens: int = 150):
"""Generate text using the model"""
global _model, _tokenizer
print("[LLMService] Starting text generation...")
# Initialize tokenizer if needed
if _tokenizer is None:
initialize_tokenizer()
# Load model if not already loaded
if _model is None:
print("[LLMService] Loading model...")
try:
_model = AutoModelForCausalLM.from_pretrained(
_model_name,
torch_dtype=torch.float32,
device_map="cpu",
trust_remote_code=True,
low_cpu_mem_usage=True
)
print("[LLMService] Model loaded successfully.")
except Exception as e:
print(f"[LLMService] Error loading model: {e}")
raise
# Tokenize input
inputs = _tokenizer(
prompt_text,
return_tensors="pt",
truncation=True,
max_length=512,
padding=True
)
# Generate response
with torch.no_grad():
outputs = _model.generate(
**inputs,
max_new_tokens=max_tokens,
temperature=0.7,
top_p=0.9,
do_sample=True,
pad_token_id=_tokenizer.eos_token_id,
eos_token_id=_tokenizer.eos_token_id,
repetition_penalty=1.1
)
    # Decode and return. Note that the decoded text still contains the prompt;
    # _extract_recommendation() strips it back out downstream.
    response = _tokenizer.decode(outputs[0], skip_special_tokens=True)
    print("[LLMService] Generation completed.")
    return response


class LLMService:
def __init__(self):
self.model_name = _model_name
print("[LLMService] Initializing LLM Service...")
# Initialize tokenizer immediately (lightweight operation)
try:
initialize_tokenizer()
except Exception as e:
print(f"[LLMService] Error loading tokenizer: {e}")
raise
print("[LLMService] Using CPU for inference")

    def generate_with_model(self, prompt_text: str, max_tokens: int = 150):
"""Use the text generation function"""
return generate_text(prompt_text, max_tokens)

    def _format_frequency_stats(self, stats: Dict) -> str:
"""Format frequency statistics into a readable string."""
move_dist = stats.get('move_distribution', {})
total_moves = sum(move_dist.values())
if total_moves == 0:
return "No game data available yet."
# Calculate percentages and find most common move
frequencies = {move: (count / total_moves) * 100 for move, count in move_dist.items()}
most_common_move = max(frequencies.items(), key=lambda x: x[1])[0]
# Counter moves
counter_moves = {"rock": "paper", "paper": "scissors", "scissors": "rock"}
suggested_counter = counter_moves.get(most_common_move, "rock")
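        # e.g. counts {"rock": 7, "paper": 2, "scissors": 1} (made-up numbers)
        # yield "Rock 70%, Paper 20%, Scissors 10%", most frequent Rock,
        # suggested counter Paper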
formatted = f"AI Statistics: Rock {frequencies.get('rock', 0):.0f}%, "
formatted += f"Paper {frequencies.get('paper', 0):.0f}%, "
formatted += f"Scissors {frequencies.get('scissors', 0):.0f}%. "
formatted += f"Most frequent: {most_common_move.title()} ({frequencies[most_common_move]:.0f}%). "
formatted += f"Best counter: {suggested_counter.title()}."
return formatted

    def _create_analysis_prompt(self, stats: Dict) -> str:
"""Create a focused prompt for game analysis."""
stats_text = self._format_frequency_stats(stats)
prompt = f"""Analyze Rock-Paper-Scissors game data and provide strategy advice:
{stats_text}
Give advice in this exact format:
>>> [Your strategic analysis in one sentence]
Recommendation: [Rock/Paper/Scissors]
Example:
>>> The AI heavily favors Rock (70% of plays), so consistently choose Paper to exploit this pattern.
Recommendation: Paper
Your analysis:
>>>"""
return prompt

    def _extract_recommendation(self, full_text: str, original_prompt: str) -> str:
        """Extract the recommendation from the generated text."""
        # Remove the original prompt from the response (generate_text returns
        # the prompt plus the completion)
        if original_prompt in full_text:
            generated_text = full_text.replace(original_prompt, "").strip()
            # The prompt itself ends with ">>>", so the model's continuation
            # usually arrives without that prefix; restore it so the parsing
            # below can match the expected format.
            if generated_text and not generated_text.startswith('>>>'):
                generated_text = '>>> ' + generated_text
        else:
            generated_text = full_text
        # Split into lines and clean
        lines = [line.strip() for line in generated_text.split('\n') if line.strip()]
# Look for our expected format starting with >>>
result_lines = []
capturing = False
        for line in lines:
            if line.startswith('>>>'):
                capturing = True
                result_lines.append(line)
            elif capturing and 'recommendation:' in line.lower():
                result_lines.append(line)
                break
            elif capturing and line.startswith(('User:', 'Assistant:')):
                # The model has drifted into a hallucinated chat turn; stop here.
                break
            elif capturing and line:
                # Keep lines that look like a continuation of the analysis
                result_lines.append(line)
if result_lines:
return '\n'.join(result_lines)
# Fallback: search for any recommendation line
for line in lines:
if 'recommendation:' in line.lower():
return f">>> {line}"
# Last resort
return ">>> Try to identify and counter the AI's most common move patterns!\nRecommendation: Rock"

    async def generate_response(self, prompt: str, stats: Optional[Dict] = None) -> str:
        """Generate a strategic response based on game statistics.

        Note: `prompt` is currently unused; the method builds its own analysis
        prompt from `stats`.
        """
        print("[LLMService] Starting response generation...")
try:
# Create appropriate prompt based on whether we have stats
if stats and 'move_distribution' in stats and sum(stats['move_distribution'].values()) > 0:
analysis_prompt = self._create_analysis_prompt(stats)
print("[LLMService] Using stats-based analysis prompt")
else:
analysis_prompt = """Give Rock-Paper-Scissors strategy advice for a new game.
Provide advice in this format:
>>> [Your general strategy advice]
Recommendation: [Rock/Paper/Scissors]
Your advice:
>>>"""
print("[LLMService] Using general strategy prompt")
# Use the text generation function
print("[LLMService] Calling text generation method...")
full_response = self.generate_with_model(analysis_prompt, max_tokens=100)
# Extract and clean the response
final_response = self._extract_recommendation(full_response, analysis_prompt)
print("[LLMService] Response generated successfully.")
return final_response
except Exception as e:
print(f"[LLMService] Error in generation: {e}")
# Provide intelligent fallback based on stats
if stats and 'move_distribution' in stats:
moves = stats['move_distribution']
if moves and sum(moves.values()) > 0:
most_common = max(moves.items(), key=lambda x: x[1])
counters = {"rock": "Paper", "paper": "Scissors", "scissors": "Rock"}
counter_move = counters.get(most_common[0], "Rock")
percentage = (most_common[1] / sum(moves.values())) * 100
return f">>> The AI plays {most_common[0].title()} most often ({percentage:.0f}% of the time), so counter with {counter_move}!\nRecommendation: {counter_move}"
# Default fallback
return ">>> Look for patterns in the AI's moves and try to counter them strategically!\nRecommendation: Rock"

    async def close(self):
        """Clean up resources."""
        global _model
        try:
            if _model is not None:
                # Rebind rather than `del` so a later generate_text() call can
                # check `_model is None` and reload without a NameError.
                _model = None
            print("[LLMService] Cleanup completed successfully.")
        except Exception as e:
            print(f"[LLMService] Error during cleanup: {e}")


# Create a singleton instance (note: this loads the tokenizer at import time)
llm_service_instance = LLMService()
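

# Example usage (illustrative sketch; the move counts below are made up, but
# the dict matches the 'move_distribution' shape generate_response() expects).
if __name__ == "__main__":
    import asyncio

    async def _demo():
        sample_stats = {"move_distribution": {"rock": 7, "paper": 2, "scissors": 1}}
        advice = await llm_service_instance.generate_response("", stats=sample_stats)
        print(advice)
        await llm_service_instance.close()

    asyncio.run(_demo())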