| | import os |
| | from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline |
| | from fastapi import FastAPI, HTTPException |
| | from pydantic import BaseModel |
| | import gradio as gr |
| | from sqlalchemy import create_engine, Column, Integer, String, DateTime |
| | from sqlalchemy.ext.declarative import declarative_base |
| | from sqlalchemy.orm import sessionmaker |
| | from datetime import datetime |
| | import threading, uvicorn |
| |
|
| | |
| | |
| | |
# --- Model loading (runs at import time) ---
# Loads the tokenizer and model from the current working directory;
# assumes the CodeGen 350M Mono weights were downloaded here — TODO confirm.
MODEL_DIR = "."
print("🚀 Carregando modelo CodeGen 350M Mono...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_DIR)
model = AutoModelForCausalLM.from_pretrained(MODEL_DIR)
# Single text-generation pipeline shared by both the FastAPI endpoint
# and the Gradio UI below.
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
print("✅ Modelo carregado com sucesso!")
| |
|
| | |
| | |
| | |
# --- Database setup (SQLite via SQLAlchemy) ---
DATABASE_URL = "sqlite:///./local.db"
# check_same_thread=False allows the SQLite connection to be used from
# threads other than the one that created it (FastAPI and Gradio both
# run handlers on worker threads).
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False})
Base = declarative_base()
# Factory for per-use sessions bound to the engine above.
SessionLocal = sessionmaker(bind=engine)
# NOTE(review): a single module-level Session shared by all request threads;
# SQLAlchemy Session objects are not thread-safe — prefer one session per
# request/call (create via SessionLocal() and close when done).
db = SessionLocal()
| |
|
class History(Base):
    """ORM row recording one prompt/response exchange."""

    __tablename__ = "history"

    id = Column(Integer, primary_key=True, index=True)
    # Prompt text submitted by the user.
    prompt = Column(String)
    # Generated text returned by the model.
    response = Column(String)
    # Row creation timestamp (naive UTC via datetime.utcnow).
    created_at = Column(DateTime, default=datetime.utcnow)
| |
|
# Create the "history" table if it does not already exist.
Base.metadata.create_all(bind=engine)
| |
|
| | |
| | |
| | |
# --- FastAPI application ---
app = FastAPI(title="💻 CodeGen 350M Mono API")
| |
|
class Prompt(BaseModel):
    """Request body for the /generate endpoint."""

    # Prompt text fed to the text-generation pipeline.
    text: str
| |
|
@app.post("/generate")
def generate_code(data: Prompt):
    """Generate text from the prompt and persist the exchange.

    Args:
        data: Request body carrying the prompt text.

    Returns:
        dict: ``{"response": <generated text>}``.
    """
    result = generator(data.text, max_new_tokens=256, num_return_sequences=1)[0]["generated_text"]
    # Use a dedicated session per request: SQLAlchemy sessions are not
    # thread-safe, and FastAPI may serve requests from multiple threads.
    # try/finally guarantees the session is closed even if commit fails.
    session = SessionLocal()
    try:
        session.add(History(prompt=data.text, response=result))
        session.commit()
    finally:
        session.close()
    return {"response": result}
| |
|
@app.post("/upgrade")
def upgrade_model(version: str = "main"):
    """Reload the CodeGen model at the given Hub revision.

    Args:
        version: Git revision (branch, tag, or commit) on the Hugging Face
            Hub to load. Defaults to "main".

    Returns:
        dict: status payload describing the outcome.

    Raises:
        HTTPException: 500 if downloading or loading the model fails.
    """
    global tokenizer, model, generator
    try:
        print(f"🚀 Atualizando modelo para a versão: {version}")
        # Plain strings (no placeholders) — the original used needless f-strings.
        tokenizer = AutoTokenizer.from_pretrained("your-username/codegen-350M-mono", revision=version)
        model = AutoModelForCausalLM.from_pretrained("your-username/codegen-350M-mono", revision=version)
        generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
        print("✅ Modelo atualizado com sucesso!")
        return {"status": "success", "message": f"Modelo atualizado para {version}"}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Erro ao atualizar modelo: {e}")
| |
|
| | |
| | |
| | |
def gradio_generate(prompt):
    """Generate text for the Gradio UI and persist the exchange.

    Args:
        prompt: Prompt text typed into the UI.

    Returns:
        str: The generated text.
    """
    result = generator(prompt, max_new_tokens=256, num_return_sequences=1)[0]["generated_text"]
    # Dedicated session per call — the Gradio handler runs on its own
    # thread, and SQLAlchemy sessions must not be shared across threads.
    session = SessionLocal()
    try:
        session.add(History(prompt=prompt, response=result))
        session.commit()
    finally:
        session.close()
    return result
| |
|
def gradio_upgrade(version: str = "main"):
    """Reload the model and report the outcome as UI text.

    Args:
        version: Model revision to load; defaults to "main" so existing
            callers (the Gradio button, which passes no arguments) keep
            working unchanged.

    Returns:
        str: Human-readable success or error message.
    """
    try:
        upgrade_model(version)
        return "✅ Modelo atualizado com sucesso!"
    except Exception as e:
        # upgrade_model raises HTTPException on failure; it is an Exception
        # subclass, so it is caught and rendered here instead of crashing the UI.
        return f"❌ Erro ao atualizar: {e}"
| |
|
# Dark-theme CSS applied to the Gradio Blocks UI below.
css = """
body { background-color: #0f111a; color: #fff; font-family: 'Arial', sans-serif; }
.gradio-container { background-color: #1c1f33; border-radius: 12px; padding: 20px; }
h1, h2 { color: #10b981; }
textarea { background-color: #12141f; color: #fff; border-radius: 8px; }
button { background-color: #10b981; color: #fff; border-radius: 6px; }
"""
| |
|
# --- Gradio UI ---
demo = gr.Blocks(css=css)

with demo:
    gr.Markdown("## 💻 CodeGen 350M Mono")
    # Layout reconstructed from garbled source indentation — presumed:
    # a row holding the prompt box next to a column of buttons; verify.
    with gr.Row():
        prompt_input = gr.Textbox(lines=5, placeholder="Digite seu código ou pergunta...")
        with gr.Column():
            send_btn = gr.Button("Gerar")
            upgrade_btn = gr.Button("Atualizar Modelo")
    output = gr.Textbox(label="Resposta do Modelo")

    # Wire buttons to the handler functions defined above.
    send_btn.click(fn=gradio_generate, inputs=prompt_input, outputs=output)
    upgrade_btn.click(fn=gradio_upgrade, inputs=None, outputs=output)
| |
|
| | |
| | |
| | |
if __name__ == "__main__":
    # Serve the Gradio UI (port 7860) and the FastAPI app (port 8000)
    # from one process: Gradio runs on a background thread, uvicorn on
    # the main thread.
    def run_gradio():
        demo.launch(server_name="0.0.0.0", server_port=7860, share=False)

    # daemon=True so the Gradio thread does not keep the process alive
    # after uvicorn shuts down (the original non-daemon thread blocked exit).
    threading.Thread(target=run_gradio, daemon=True).start()
    uvicorn.run(app, host="0.0.0.0", port=8000)
| |
|