File size: 3,344 Bytes
058c1a8
2cc725c
 
 
 
 
 
058c1a8
2cc725c
 
058c1a8
2cc725c
 
 
 
1154a37
2cc725c
 
 
 
dae8e37
2cc725c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1154a37
2cc725c
 
 
 
 
 
 
 
 
 
058c1a8
2cc725c
 
 
 
 
 
 
 
 
1154a37
2cc725c
 
 
 
 
 
 
8f3355d
2cc725c
 
 
 
 
 
 
 
 
 
058c1a8
2cc725c
 
 
b379e0d
2cc725c
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import gradio as gr
from flask import Flask, Response, request
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
import torch
from threading import Thread
import json
import threading

# Flask app that exposes the /chat SSE endpoint alongside the Gradio UI below.
flask_app = Flask(__name__)

# Hugging Face checkpoint used for generation; downloaded on first run.
# NOTE(review): float16 with device_map="cpu" is unusual — fp16 inference on
# CPU is slow and not supported by all torch ops; confirm this is intentional
# (bfloat16 or float32 is the common CPU choice).
model_id = "Qwen/Qwen2-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="cpu")

def generate_response(message, history):
    """Stream a model reply for *message*, yielding text chunks.

    Args:
        message: The new user utterance.
        history: Prior turns as a list of {"role", "content"} dicts
            (messages format); inserted between the system prompt and
            the new user turn.

    Yields:
        str: Buffered pieces of the generated reply. A piece is flushed
        as soon as it contains punctuation or reaches 10 characters, so
        downstream consumers receive reasonably sized chunks.
    """
    conversation = [
        {"role": "system", "content": "You are Vidyut, an Indian AI created by Rapnss Production Studio India, designed for logical reasoning and problem-solving. Provide clear, step-by-step reasoning for all questions, ensuring accurate and concise answers."},
        *history,
        {"role": "user", "content": message},
    ]

    prompt_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt").to(model.device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # generate() blocks, so it runs on a worker thread while we drain the
    # streamer on this one.
    worker = Thread(
        target=model.generate,
        kwargs=dict(
            inputs=prompt_ids,
            streamer=streamer,
            max_new_tokens=256,
            do_sample=True,
            top_p=0.95,
            temperature=0.7,
        ),
    )
    worker.start()

    flush_threshold = 10
    break_chars = (".", ",", "!", "?", ";", ":")
    pending = ""

    for piece in streamer:
        pending += piece
        # Flush on any punctuation mark or once the buffer is long enough.
        if any(ch in pending for ch in break_chars) or len(pending) >= flush_threshold:
            yield pending
            pending = ""

    # Emit whatever is left after the stream ends.
    if pending:
        yield pending

    worker.join()

# Flask API endpoint for external sources
@flask_app.route('/chat', methods=['POST'])
def chat():
    """Stream a model reply as Server-Sent Events.

    Expects a JSON body like {"message": str, "history": [...]}; both keys
    are optional. Each SSE event carries one generated chunk as
    ``data: {"text": "<chunk>"}``.
    """
    # Fix: get_json() returns None (or raises) for a missing/non-JSON body,
    # which made data.get(...) crash with a 500. silent=True plus the `or {}`
    # fallback degrades gracefully to an empty request instead.
    data = request.get_json(silent=True) or {}
    message = data.get('message', '')
    history = data.get('history', [])

    def stream():
        # Wrap each generated chunk in one SSE "data:" event.
        for chunk in generate_response(message, history):
            yield f"data: {json.dumps({'text': chunk})}\n\n"

    return Response(stream(), mimetype='text/event-stream')

# Gradio interface
def update_chatbot(message, history):
    """Stream the assistant's reply into the Gradio chatbot.

    Bug fix: the original appended a NEW user/assistant pair for every
    streamed chunk, duplicating the user turn and scattering the reply
    across many assistant bubbles. Now the user turn is added once and a
    single assistant message grows as chunks arrive.

    Args:
        message: Text from the input box.
        history: Current chatbot value in messages format.

    Yields:
        (updated_history, ""): the second element clears the textbox.
    """
    base = history + [{"role": "user", "content": message}]
    reply = ""
    for chunk in generate_response(message, history):
        reply += chunk
        yield base + [{"role": "assistant", "content": reply}], ""
    if not reply:
        # Model produced nothing: still show the user turn and clear the box.
        yield base, ""

# Build the Gradio UI: a messages-format chatbot, an input textbox, and
# Send/Clear buttons wired to the streaming handler above.
with gr.Blocks(title="Vidyut Omega Reasoning Chatbot") as demo:
    gr.Markdown("# Vidyut Omega Reasoning Chatbot")
    gr.Markdown("Ask Vidyut reasoning or problem-solving questions, or use the /chat API for external access!")
    # type="messages" -> the chatbot value is a list of {"role", "content"}
    # dicts, the same format generate_response consumes.
    chatbot = gr.Chatbot(type="messages")
    msg = gr.Textbox(placeholder="Type your question here...", label="Message")
    submit = gr.Button("Send")
    clear = gr.Button("Clear")
    
    # Send: update_chatbot is a generator, so the chatbot streams as it yields
    # (history, textbox_value) tuples.
    submit.click(
        fn=update_chatbot,
        inputs=[msg, chatbot],
        outputs=[chatbot, msg]
    )
    # Clear: reset the chatbot to an empty history.
    clear.click(fn=lambda: [], inputs=None, outputs=chatbot)

# Run Flask in a separate thread
def run_flask():
    """Serve the Flask API (the /chat endpoint) on 0.0.0.0:8000.

    Blocking call; intended to run on a background daemon thread so the
    Gradio server can own the main thread.
    """
    flask_app.run(host="0.0.0.0", port=8000)

# Start the Flask API on a daemon thread so it exits with the main process.
# Idiom: pass daemon=True to the constructor instead of mutating the
# attribute after construction — same behavior, one less race-prone step.
flask_thread = threading.Thread(target=run_flask, daemon=True)
flask_thread.start()

# Launch Gradio (blocking) on all interfaces, port 7860 — keeps the main
# thread alive, which in turn keeps the daemonized Flask thread running.
demo.launch(server_name="0.0.0.0", server_port=7860)