# ChatSPE/app.py
import os
import torch
import gradio as gr
import tiktoken
from model import Config, DeepSeek_V3_Model, generate_text, clean_response
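
# model.py is expected to provide (as used below): Config (model hyperparameters,
# including pad_token_id), DeepSeek_V3_Model (an nn.Module whose weights are
# restored from chatspe.pt), generate_text (the sampling loop), and
# clean_response (post-processing of the raw generation).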

# UI Metadata
TITLE = "ChatSPE - Petroleum Engineering Assistant"
DESCRIPTION = (
    "**Note**: This model is for educational purposes only. "
    "It was pretrained on only two books and fine-tuned on just ten samples, "
    "so response quality may be **extremely limited**."
)
EXAMPLES = [
    ["What is porosity?"],
    ["Explain water saturation."],
    ["What is permeability in petroleum engineering?"],
]

# Device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer (tiktoken GPT-2)
tokenizer = tiktoken.get_encoding("gpt2")
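
# Sanity check (optional): tiktoken's GPT-2 BPE encoding round-trips plain text,
# e.g. tokenizer.decode(tokenizer.encode("What is porosity?")) returns the same
# string. encode/decode are standard tiktoken Encoding methods.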

# Load Model: instantiate from Config and restore weights from chatspe.pt
config = Config()
ChatSPE = DeepSeek_V3_Model(config)
model_path = os.path.join(".", "chatspe.pt")
ChatSPE.load_state_dict(torch.load(model_path, map_location=device))
ChatSPE.to(device)
ChatSPE.eval()  # inference mode: disables dropout and similar training-only layers
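
# Optional smoke test (a sketch reusing the same generate_text arguments as
# chat() below); uncomment to confirm the checkpoint loads and can generate:
# with torch.no_grad():
#     print(generate_text(model=ChatSPE, tokenizer=tokenizer,
#                         prompt="Prompt: What is porosity?\nResponse:",
#                         max_length=20, temperature=0.7, top_k=10,
#                         eos_id=config.pad_token_id, device=device))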

# Chat Function: wrap the user message in the training prompt format,
# generate, and clean the model output. `history` is required by
# gr.ChatInterface but is not used by this stateless model.
def chat(user_message, history):
    try:
        prompt = f"Prompt: {user_message}\nResponse:"
        # Generate model output
        generated = generate_text(
            model=ChatSPE,
            tokenizer=tokenizer,
            prompt=prompt,
            max_length=80,
            temperature=0.7,
            top_k=10,
            eos_id=config.pad_token_id,  # the pad token doubles as the stop token
            device=device,
        )
        # Ensure the output is a string before cleaning
        if isinstance(generated, bytes):
            generated = generated.decode("utf-8", errors="ignore")
        # Flatten any nested list or tuple from the model output
        if isinstance(generated, (list, tuple)):
            generated = " ".join(map(str, generated))
        reply = clean_response(str(generated))
        return reply
    except Exception as e:
        print(f"Error in chat function: {e}")
        return "Sorry, I encountered an error generating a response. Please try again."

# Gradio Chat UI - disable caching
demo = gr.ChatInterface(
    fn=chat,
    title=TITLE,
    description=DESCRIPTION,
    examples=EXAMPLES,
    cache_examples=False,  # run examples live instead of serving cached outputs
)
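
# Note: on CPU-only hosts generation can be slow; Gradio's demo.queue() can be
# called before launch() to serialize concurrent requests if that becomes an issue.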

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)