import json
from threading import Thread

import gradio as gr
import torch
from flask import Flask, Response, request
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Flask serves the streaming /chat API in a background thread, alongside the Gradio UI.
flask_app = Flask(__name__)

# Load the tokenizer and model once at startup. float32 is used because
# half-precision matmuls are unsupported on CPU and generation would fail.
model_id = "Qwen/Qwen2-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.float32, device_map="cpu"
)
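
# A minimal sketch for GPU inference instead, assuming a CUDA device with
# enough memory for the 1.5B model in fp16 (not exercised in this script):
#   model = AutoModelForCausalLM.from_pretrained(
#       model_id, torch_dtype=torch.float16, device_map="auto"
#   )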


def generate_response(message, history):
    """Stream the model's reply to `message`, chunked at natural break points."""
    messages = [
        {"role": "system", "content": "You are Vidyut, an Indian AI created by Rapnss Production Studio India, designed for logical reasoning and problem-solving. Provide clear, step-by-step reasoning for all questions, ensuring accurate and concise answers."}
    ] + history + [{"role": "user", "content": message}]

    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = {
        "inputs": inputs,
        "streamer": streamer,
        "max_new_tokens": 256,
        "do_sample": True,
        "top_p": 0.95,
        "temperature": 0.7,
    }

    # generate() blocks, so run it in a worker thread and consume the
    # streamer from this one.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Batch streamed tokens into chunks that end at punctuation or reach a
    # minimum length, so consumers are not flooded with single-token updates.
    chunk_buffer = ""
    min_chunk_length = 10
    punctuation_marks = [".", ",", "!", "?", ";", ":"]

    for new_text in streamer:
        chunk_buffer += new_text
        if any(p in chunk_buffer for p in punctuation_marks) or len(chunk_buffer) >= min_chunk_length:
            yield chunk_buffer
            chunk_buffer = ""

    # Flush whatever remains once generation ends.
    if chunk_buffer:
        yield chunk_buffer

    thread.join()
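
# Quick sanity check of the generator (hypothetical usage, not wired into
# the app):
#   for chunk in generate_response("What is 2 + 2?", []):
#       print(chunk, end="", flush=True)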


@flask_app.route('/chat', methods=['POST'])
def chat():
    # Tolerate a missing or invalid JSON body instead of raising a 400.
    data = request.get_json(silent=True) or {}
    message = data.get('message', '')
    history = data.get('history', [])

    def stream():
        # Server-sent events: one JSON payload per chunk.
        for chunk in generate_response(message, history):
            yield f"data: {json.dumps({'text': chunk})}\n\n"

    return Response(stream(), mimetype='text/event-stream')
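
# Example call against the endpoint above (assumes the defaults used in
# run_flask() below, i.e. port 8000; -N turns off curl's output buffering
# so chunks print as they arrive):
#   curl -N -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "What is 17 * 24?", "history": []}'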


def update_chatbot(message, history):
    # Append the user turn once, then grow a single assistant message in
    # place so the chat window shows one streaming reply.
    past = list(history)  # conversation before this turn, as the model sees it
    history = past + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": ""},
    ]
    for chunk in generate_response(message, past):
        history[-1]["content"] += chunk
        yield history, ""


with gr.Blocks(title="Vidyut Omega Reasoning Chatbot") as demo:
    gr.Markdown("# Vidyut Omega Reasoning Chatbot")
    gr.Markdown("Ask Vidyut reasoning or problem-solving questions, or use the /chat API for external access!")
    chatbot = gr.Chatbot(type="messages")
    msg = gr.Textbox(placeholder="Type your question here...", label="Message")
    submit = gr.Button("Send")
    clear = gr.Button("Clear")

    submit.click(
        fn=update_chatbot,
        inputs=[msg, chatbot],
        outputs=[chatbot, msg]
    )
    clear.click(fn=lambda: [], inputs=None, outputs=chatbot)
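    # Convenience binding: pressing Enter in the textbox submits like the
    # Send button (gr.Textbox.submit takes the same arguments as Button.click).
    msg.submit(
        fn=update_chatbot,
        inputs=[msg, chatbot],
        outputs=[chatbot, msg]
    )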


def run_flask():
    # Serve the API on its own port, separate from the Gradio UI.
    flask_app.run(host="0.0.0.0", port=8000)


# Daemon thread so the API exits with the main (Gradio) process.
flask_thread = Thread(target=run_flask, daemon=True)
flask_thread.start()

demo.launch(server_name="0.0.0.0", server_port=7860)