# Server-A / app.py
import json
from threading import Thread

import gradio as gr
import torch
from flask import Flask, Response, request
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer

# Initialize the Flask app used for the external /chat API
flask_app = Flask(__name__)

# Load the model and tokenizer (fp16 weights on CPU; consider torch.float32 if
# half-precision generation is unsupported or too slow on your hardware)
model_id = "Qwen/Qwen2-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="cpu")

def generate_response(message, history):
    """Stream a reply to `message`, given the prior chat `history` (list of role/content dicts)."""
    messages = [
        {
            "role": "system",
            "content": (
                "You are Vidyut, an Indian AI created by Rapnss Production Studio India, "
                "designed for logical reasoning and problem-solving. Provide clear, "
                "step-by-step reasoning for all questions, ensuring accurate and concise answers."
            ),
        }
    ] + history + [{"role": "user", "content": message}]

    inputs = tokenizer.apply_chat_template(
        messages, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    # Run model.generate() in a background thread and stream tokens as they are produced
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = {
        "inputs": inputs,
        "streamer": streamer,
        "max_new_tokens": 256,
        "do_sample": True,
        "top_p": 0.95,
        "temperature": 0.7,
    }
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Buffer small fragments and yield them once they hit punctuation or a minimum length
    chunk_buffer = ""
    min_chunk_length = 10
    punctuation_marks = [".", ",", "!", "?", ";", ":"]
    for new_text in streamer:
        chunk_buffer += new_text
        if any(p in chunk_buffer for p in punctuation_marks) or len(chunk_buffer) >= min_chunk_length:
            yield chunk_buffer
            chunk_buffer = ""
    if chunk_buffer:
        yield chunk_buffer
    thread.join()

# Flask API endpoint for external sources (Server-Sent Events stream)
@flask_app.route('/chat', methods=['POST'])
def chat():
    data = request.get_json()
    message = data.get('message', '')
    history = data.get('history', [])

    def stream():
        # Wrap each generated chunk as an SSE "data:" event
        for chunk in generate_response(message, history):
            yield f"data: {json.dumps({'text': chunk})}\n\n"

    return Response(stream(), mimetype='text/event-stream')
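
# A minimal usage sketch (not executed by this app): calling the /chat endpoint with curl.
# Assumes the Flask server is reachable on localhost:8000; adjust the host/port if your
# deployment maps them differently.
#
#   curl -N -X POST http://localhost:8000/chat \
#        -H "Content-Type: application/json" \
#        -d '{"message": "Why is the sky blue?", "history": []}'
#
# Each streamed event arrives as a line of the form:  data: {"text": "..."}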

# Gradio callback: add the user turn once, then stream chunks into a single assistant turn
def update_chatbot(message, history):
    display = history + [
        {"role": "user", "content": message},
        {"role": "assistant", "content": ""},
    ]
    for chunk in generate_response(message, history):
        display[-1]["content"] += chunk
        yield display, ""

# Gradio interface
with gr.Blocks(title="Vidyut Omega Reasoning Chatbot") as demo:
    gr.Markdown("# Vidyut Omega Reasoning Chatbot")
    gr.Markdown("Ask Vidyut reasoning or problem-solving questions, or use the /chat API for external access!")
    chatbot = gr.Chatbot(type="messages")
    msg = gr.Textbox(placeholder="Type your question here...", label="Message")
    submit = gr.Button("Send")
    clear = gr.Button("Clear")

    submit.click(
        fn=update_chatbot,
        inputs=[msg, chatbot],
        outputs=[chatbot, msg],
    )
    clear.click(fn=lambda: [], inputs=None, outputs=chatbot)

# Run Flask in a daemon thread so the API and the Gradio UI serve concurrently
def run_flask():
    flask_app.run(host="0.0.0.0", port=8000)

flask_thread = Thread(target=run_flask, daemon=True)
flask_thread.start()

# Launch Gradio
demo.launch(server_name="0.0.0.0", server_port=7860)
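
# A sketch of a streaming Python client (a hypothetical helper, not part of the app): it posts
# to /chat with the `requests` library and prints each SSE chunk as it arrives. Assumes the
# Flask server runs on localhost:8000 as configured above.
#
# import json
# import requests
#
# def stream_chat(message, history=None):
#     payload = {"message": message, "history": history or []}
#     with requests.post("http://localhost:8000/chat", json=payload, stream=True) as resp:
#         for line in resp.iter_lines(decode_unicode=True):
#             if line and line.startswith("data: "):
#                 chunk = json.loads(line[len("data: "):])
#                 print(chunk["text"], end="", flush=True)
#
# stream_chat("What is 17 * 23? Show your reasoning.")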