|
|
|
|
|
"""CUDA-optimized chat interface for Ursa Minor Smashed model"""
|
|
|
|
|
|
import torch
|
|
|
from inference_cuda import generate_direct, load_model_direct
|
|
|
|
|
|
def main():
    """Run an interactive chat loop against the Ursa Minor Smashed model on CUDA.

    Reads user turns from stdin, maintains a rolling "Human:/Assistant:"
    transcript as the generation prompt, and prints each model reply.
    Requires a CUDA device; prints an error and returns when none is
    available. Special inputs: 'quit' exits, 'reset' clears the transcript,
    blank input is ignored.
    """
    print("Ursa Minor Smashed Chat (CUDA)")
    print("Type 'quit' to exit, 'reset' to clear context")
    print("-" * 50)

    if not torch.cuda.is_available():
        print("ERROR: CUDA is not available. Use chat_cpu.py for CPU inference.")
        return

    print("Loading model on CUDA...")
    model = load_model_direct("model_optimized.pt")
    print("Model loaded! Ready to chat.\n")

    context = ""
    max_context_length = 800  # rolling window size, in whitespace-split words

    while True:
        # EOF (Ctrl-D, exhausted piped stdin) or Ctrl-C should end the
        # session cleanly rather than crash with a traceback.
        try:
            user_input = input("You: ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\nGoodbye!")
            break

        if user_input.lower() == 'quit':
            print("Goodbye!")
            break
        elif user_input.lower() == 'reset':
            context = ""
            print("Context cleared!")
            continue
        elif user_input == "":
            continue

        context = _append_turn(context, user_input)
        context = _trim_context(context, max_context_length)

        try:
            full_response = generate_direct(
                model,
                context,
                max_new_tokens=100,
                temperature=0.8,
                top_p=0.9,
                top_k=50,
                repetition_penalty=1.1
            )
            response = _extract_reply(full_response, context)
            print(f"Assistant: {response}")
            # Carry the model's full continuation forward as next-turn context.
            context = full_response
        except Exception as e:
            # Top-level REPL boundary: report the failure and keep the loop
            # alive so the user can 'reset' and continue.
            print(f"Error generating response: {e}")
            print("Try typing 'reset' to clear context and continue.")


def _append_turn(context, user_input):
    """Return *context* with a new Human turn and an 'Assistant:' cue appended."""
    if context:
        return context + f"\nHuman: {user_input}\nAssistant:"
    return f"Human: {user_input}\nAssistant:"


def _trim_context(context, max_words):
    """Keep only the trailing *max_words* whitespace-split words of *context*.

    NOTE: re-joining with single spaces flattens the newline structure of
    older turns; this matches the original script's behavior.
    """
    words = context.split()
    if len(words) > max_words:
        return " ".join(words[-max_words:])
    return context


def _extract_reply(full_response, prompt):
    """Strip the echoed prompt and any hallucinated follow-on 'Human:' turn.

    Assumes generate_direct returns prompt + completion as one string —
    TODO confirm against inference_cuda.
    """
    reply = full_response[len(prompt):].strip()
    if "Human:" in reply:
        reply = reply.split("Human:")[0].strip()
    return reply
|
|
|
|
|
|
# Standard script guard: start the interactive chat loop only when this
# file is executed directly, not when imported as a module.
if __name__ == "__main__":
    main()