# Vidyut Omega: Gradio chat UI + Flask SSE API wrapping Qwen2-1.5B-Instruct.
import gradio as gr
from flask import Flask, Response, request
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
from threading import Thread
import json
import threading
# Initialize Flask app (serves the /chat SSE endpoint alongside the Gradio UI)
flask_app = Flask(__name__)
# Load model and tokenizer once at import time; downloads on first run.
model_id = "Qwen/Qwen2-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
# NOTE(review): float16 weights pinned to CPU — fp16 CPU inference is slow and
# not all CPU kernels support it; confirm this is intentional vs. float32/GPU.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="cpu")
def generate_response(message, history):
    """Stream the assistant's reply to *message* given prior *history*.

    Args:
        message: The user's new message text.
        history: Prior turns as a list of ``{"role", "content"}`` dicts.

    Yields:
        str: Incremental chunks of the reply, flushed at punctuation or
        roughly every 10 characters so the client sees steady progress
        rather than one yield per token.
    """
    messages = [
        {"role": "system", "content": "You are Vidyut, an Indian AI created by Rapnss Production Studio India, designed for logical reasoning and problem-solving. Provide clear, step-by-step reasoning for all questions, ensuring accurate and concise answers."}
    ] + history + [{"role": "user", "content": message}]
    inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt").to(model.device)
    # skip_prompt: don't re-emit the input; special tokens are dropped.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        "inputs": inputs,
        "streamer": streamer,
        "max_new_tokens": 256,
        "do_sample": True,
        "top_p": 0.95,
        "temperature": 0.7,
    }
    # generate() blocks, so it runs on a worker thread while we drain the streamer.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()
    chunk_buffer = ""
    min_chunk_length = 10
    punctuation_marks = (".", ",", "!", "?", ";", ":")
    try:
        for new_text in streamer:
            chunk_buffer += new_text
            # Flush at punctuation or once the buffer is long enough.
            if any(p in chunk_buffer for p in punctuation_marks) or len(chunk_buffer) >= min_chunk_length:
                yield chunk_buffer
                chunk_buffer = ""
        if chunk_buffer:
            yield chunk_buffer
    finally:
        # BUGFIX: join in a finally so the generation thread is reaped even
        # when the consumer abandons this generator early (e.g. a client
        # disconnect mid-stream); previously join() was unreachable then.
        thread.join()
# Flask API endpoint for external sources
@flask_app.route('/chat', methods=['POST'])
def chat():
    """POST /chat — stream a reply as Server-Sent Events.

    Expects a JSON body ``{"message": str, "history": [...]}`` and responds
    with ``text/event-stream`` frames of the form ``data: {"text": ...}``.
    """
    # BUGFIX: silent=True returns None (instead of aborting with 400) on a
    # missing or malformed JSON body; fall back to an empty payload so the
    # endpoint degrades to an empty message rather than erroring out.
    data = request.get_json(silent=True) or {}
    message = data.get('message', '')
    history = data.get('history', [])

    def stream():
        for chunk in generate_response(message, history):
            # One SSE frame per chunk: "data: <json>\n\n".
            yield f"data: {json.dumps({'text': chunk})}\n\n"

    return Response(stream(), mimetype='text/event-stream')
# Gradio interface
def update_chatbot(message, history):
    """Gradio handler: stream the assistant's reply into the chat transcript.

    BUGFIX: the original appended the user message AND a brand-new assistant
    entry on *every* streamed chunk, so the transcript filled up with
    duplicated user turns and disjoint partial replies. Now the user turn is
    added once and a single assistant entry is updated in place as chunks
    arrive.

    Args:
        message: Text from the input box.
        history: Current transcript as ``{"role", "content"}`` dicts.

    Yields:
        tuple: (updated history, "") — the empty string clears the textbox.
    """
    prior = list(history)  # conversation before this turn, passed to the model
    history = prior + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": ""},
    ]
    reply = ""
    for chunk in generate_response(message, prior):
        reply += chunk
        history[-1]["content"] = reply
        yield history, ""
    return history, ""
with gr.Blocks(title="Vidyut Omega Reasoning Chatbot") as demo:
    gr.Markdown("# Vidyut Omega Reasoning Chatbot")
    gr.Markdown("Ask Vidyut reasoning or problem-solving questions, or use the /chat API for external access!")
    # type="messages": transcript is a list of {"role", "content"} dicts,
    # the same shape update_chatbot and generate_response consume.
    chatbot = gr.Chatbot(type="messages")
    msg = gr.Textbox(placeholder="Type your question here...", label="Message")
    submit = gr.Button("Send")
    clear = gr.Button("Clear")
    # update_chatbot is a generator, so the chatbot re-renders on every
    # yielded (history, "") pair — the "" clears the textbox.
    submit.click(
        fn=update_chatbot,
        inputs=[msg, chatbot],
        outputs=[chatbot, msg]
    )
    # Reset the transcript to an empty message list.
    clear.click(fn=lambda: [], inputs=None, outputs=chatbot)
# Run Flask in a separate thread so it can serve alongside the Gradio server.
def run_flask():
    """Serve the Flask API (the /chat SSE endpoint) on port 8000."""
    flask_app.run(host="0.0.0.0", port=8000)

# Daemon thread: the process exits when the Gradio foreground server stops.
flask_thread = threading.Thread(target=run_flask, daemon=True)
flask_thread.start()

# Launch Gradio (blocking call) on port 7860.
# FIX: removed a stray trailing "|" artifact that made this line a syntax error.
demo.launch(server_name="0.0.0.0", server_port=7860)